//===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "GPUOpsLowering.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/FormatVariadic.h" using namespace mlir; LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type) { LLVM::LLVMFuncOp ret; if (!(ret = moduleOp.template lookupSymbol(name))) { OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External); } return ret; } static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp, StringRef prefix) { // Get a unique global name. unsigned stringNumber = 0; SmallString<16> stringConstName; do { stringConstName.clear(); (prefix + Twine(stringNumber++)).toStringRef(stringConstName); } while (moduleOp.lookupSymbol(stringConstName)); return stringConstName; } LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8, StringRef namePrefix, StringRef str, uint64_t alignment, unsigned addrSpace) { llvm::SmallString<20> nullTermStr(str); nullTermStr.push_back('\0'); // Null terminate for C auto globalType = LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes()); StringAttr attr = b.getStringAttr(nullTermStr); // Try to find existing global. 
for (auto globalOp : moduleOp.getOps()) if (globalOp.getGlobalType() == globalType && globalOp.getConstant() && globalOp.getValueAttr() == attr && globalOp.getAlignment().value_or(0) == alignment && globalOp.getAddrSpace() == addrSpace) return globalOp; // Not found: create new global. OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); return LLVM::GlobalOp::create(b, loc, globalType, /*isConstant=*/true, LLVM::Linkage::Internal, name, attr, alignment, addrSpace); } LogicalResult GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = gpuFuncOp.getLoc(); SmallVector workgroupBuffers; if (encodeWorkgroupAttributionsAsArguments) { // Append an `llvm.ptr` argument to the function signature to encode // workgroup attributions. ArrayRef workgroupAttributions = gpuFuncOp.getWorkgroupAttributions(); size_t numAttributions = workgroupAttributions.size(); // Insert all arguments at the end. 
unsigned index = gpuFuncOp.getNumArguments(); SmallVector argIndices(numAttributions, index); // New arguments will simply be `llvm.ptr` with the correct address space Type workgroupPtrType = rewriter.getType(workgroupAddrSpace); SmallVector argTypes(numAttributions, workgroupPtrType); // Attributes: noalias, llvm.mlir.workgroup_attribution(, ) std::array attrs{ rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(), rewriter.getUnitAttr()), rewriter.getNamedAttr( getDialect().getWorkgroupAttributionAttrHelper().getName(), rewriter.getUnitAttr()), }; SmallVector argAttrs; for (BlockArgument attribution : workgroupAttributions) { auto attributionType = cast(attribution.getType()); IntegerAttr numElements = rewriter.getI64IntegerAttr(attributionType.getNumElements()); Type llvmElementType = getTypeConverter()->convertType(attributionType.getElementType()); if (!llvmElementType) return failure(); TypeAttr type = TypeAttr::get(llvmElementType); attrs.back().setValue( rewriter.getAttr(numElements, type)); argAttrs.push_back(rewriter.getDictionaryAttr(attrs)); } // Location match function location SmallVector argLocs(numAttributions, gpuFuncOp.getLoc()); // Perform signature modification rewriter.modifyOpInPlace( gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() { LogicalResult inserted = static_cast(gpuFuncOp).insertArguments( argIndices, argTypes, argAttrs, argLocs); (void)inserted; assert(succeeded(inserted) && "expected GPU funcs to support inserting any argument"); }); } else { workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); for (auto [idx, attribution] : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { auto type = dyn_cast(attribution.getType()); assert(type && type.hasStaticShape() && "unexpected type in attribution"); uint64_t numElements = type.getNumElements(); auto elementType = cast(typeConverter->convertType(type.getElementType())); auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); 
std::string name = std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx)); uint64_t alignment = 0; if (auto alignAttr = dyn_cast_or_null( gpuFuncOp.getWorkgroupAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); auto globalOp = LLVM::GlobalOp::create( rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment, workgroupAddrSpace); workgroupBuffers.push_back(globalOp); } } // Remap proper input types. TypeConverter::SignatureConversion signatureConversion( gpuFuncOp.front().getNumArguments()); Type funcType = getTypeConverter()->convertFunctionSignature( gpuFuncOp.getFunctionType(), /*isVariadic=*/false, getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion); if (!funcType) { return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) { diag << "failed to convert function signature type for: " << gpuFuncOp.getFunctionType(); }); } // Create the new function operation. Only copy those attributes that are // not specific to function modeling. 
SmallVector attributes; ArrayAttr argAttrs; for (const auto &attr : gpuFuncOp->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || attr.getName() == gpuFuncOp.getFunctionTypeAttrName() || attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() || attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() || attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() || attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() || attr.getName() == gpuFuncOp.getKnownGridSizeAttrName()) continue; if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) { argAttrs = gpuFuncOp.getArgAttrsAttr(); continue; } attributes.push_back(attr); } DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr(); DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr(); // Ensure we don't lose information if the function is lowered before its // surrounding context. auto *gpuDialect = cast(gpuFuncOp->getDialect()); if (knownBlockSize) attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(), knownBlockSize); if (knownGridSize) attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(), knownGridSize); // Add a dialect specific kernel attribute in addition to GPU kernel // attribute. The former is necessary for further translation while the // latter is expected by gpu.launch_func. if (gpuFuncOp.isKernel()) { if (kernelAttributeName) attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); // Set the dialect-specific block size attribute if there is one. if (kernelBlockSizeAttributeName && knownBlockSize) { attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize); } } LLVM::CConv callingConvention = gpuFuncOp.isKernel() ? 
kernelCallingConvention : nonKernelCallingConvention; auto llvmFuncOp = LLVM::LLVMFuncOp::create( rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention, /*comdat=*/nullptr, attributes); { // Insert operations that correspond to converted workgroup and private // memory attributions to the body of the function. This must operate on // the original function, before the body region is inlined in the new // function to maintain the relation between block arguments and the // parent operation that assigns their semantics. OpBuilder::InsertionGuard guard(rewriter); // Rewrite workgroup memory attributions to addresses of global buffers. rewriter.setInsertionPointToStart(&gpuFuncOp.front()); unsigned numProperArguments = gpuFuncOp.getNumArguments(); if (encodeWorkgroupAttributionsAsArguments) { // Build a MemRefDescriptor with each of the arguments added above. unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions(); assert(numProperArguments >= numAttributions && "Expecting attributions to be encoded as arguments already"); // Arguments encoding workgroup attributions will be in positions // [numProperArguments, numProperArguments+numAttributions) ArrayRef attributionArguments = gpuFuncOp.getArguments().slice(numProperArguments - numAttributions, numAttributions); for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal( gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) { auto [attribution, arg] = vals; auto type = cast(attribution.getType()); // Arguments are of llvm.ptr type and attributions are of memref type: // we need to wrap them in memref descriptors. 
Value descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, arg); // And remap the arguments signatureConversion.remapInput(numProperArguments + idx, descr); } } else { for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) { auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()); Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType, global.getSymNameAttr()); Value memory = LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(), address, ArrayRef{0, 0}); // Build a memref descriptor pointing to the buffer to plug with the // existing memref infrastructure. This may use more registers than // otherwise necessary given that memref sizes are fixed, but we can try // and canonicalize that away later. Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx]; auto type = cast(attribution.getType()); Value descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, memory); signatureConversion.remapInput(numProperArguments + idx, descr); } } // Rewrite private memory attributions to alloca'ed buffers. unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); auto int64Ty = IntegerType::get(rewriter.getContext(), 64); for (const auto [idx, attribution] : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { auto type = cast(attribution.getType()); assert(type && type.hasStaticShape() && "unexpected type in attribution"); // Explicitly drop memory space when lowering private memory // attributions since NVVM models it as `alloca`s in the default // memory space and does not support `alloca`s with addrspace(5). 
Type elementType = typeConverter->convertType(type.getElementType()); auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace); Value numElements = LLVM::ConstantOp::create( rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); uint64_t alignment = 0; if (auto alignAttr = dyn_cast_or_null(gpuFuncOp.getPrivateAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); Value allocated = LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); Value descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, allocated); signatureConversion.remapInput( numProperArguments + numWorkgroupAttributions + idx, descr); } } // Move the region to the new function, update the entry block signature. rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), llvmFuncOp.end()); if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter, &signatureConversion))) return failure(); // Get memref type from function arguments and set the noalias to // pointer arguments. for (const auto [idx, argTy] : llvm::enumerate(gpuFuncOp.getArgumentTypes())) { auto remapping = signatureConversion.getInputMapping(idx); NamedAttrList argAttr = argAttrs ? 
cast(argAttrs[idx]) : NamedAttrList(); auto copyAttribute = [&](StringRef attrName) { Attribute attr = argAttr.erase(attrName); if (!attr) return; for (size_t i = 0, e = remapping->size; i < e; ++i) llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); }; auto copyPointerAttribute = [&](StringRef attrName) { Attribute attr = argAttr.erase(attrName); if (!attr) return; if (remapping->size > 1 && attrName == LLVM::LLVMDialect::getNoAliasAttrName()) { emitWarning(llvmFuncOp.getLoc(), "Cannot copy noalias with non-bare pointers.\n"); return; } for (size_t i = 0, e = remapping->size; i < e; ++i) { if (isa( llvmFuncOp.getArgument(remapping->inputNo + i).getType())) { llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); } } }; if (argAttr.empty()) continue; copyAttribute(LLVM::LLVMDialect::getReturnedAttrName()); copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName()); copyAttribute(LLVM::LLVMDialect::getInRegAttrName()); bool lowersToPointer = false; for (size_t i = 0, e = remapping->size; i < e; ++i) { lowersToPointer |= isa( llvmFuncOp.getArgument(remapping->inputNo + i).getType()); } if (lowersToPointer) { copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName()); copyPointerAttribute( LLVM::LLVMDialect::getDereferenceableOrNullAttrName()); copyPointerAttribute( LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr()); } } rewriter.eraseOp(gpuFuncOp); return success(); } LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( 
gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = gpuPrintfOp->getLoc(); mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type()); auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type()); mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type()); // Note: this is the GPUModule op, not the ModuleOp that surrounds it // This ensures that global constants and declarations are placed within // the device code, not the host code auto moduleOp = gpuPrintfOp->getParentOfType(); auto ocklBegin = getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin", LLVM::LLVMFunctionType::get(llvmI64, {llvmI64})); LLVM::LLVMFuncOp ocklAppendArgs; if (!adaptor.getArgs().empty()) { ocklAppendArgs = getOrDefineFunction( moduleOp, loc, rewriter, "__ockl_printf_append_args", LLVM::LLVMFunctionType::get( llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64, llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32})); } auto ocklAppendStringN = getOrDefineFunction( moduleOp, loc, rewriter, "__ockl_printf_append_string_n", LLVM::LLVMFunctionType::get( llvmI64, {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); /// Start the printf hostcall Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0); auto printfBeginCall = LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64); Value printfDesc = printfBeginCall.getResult(); // Create the global op or find an existing one. 
LLVM::GlobalOp global = getOrCreateStringConstant( rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element and pass it to printf() Value globalPtr = LLVM::AddressOfOp::create( rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), globalPtr, ArrayRef{0, 0}); Value stringLen = LLVM::ConstantOp::create( rewriter, loc, llvmI64, cast(global.getValueAttr()).size()); Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1); Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0); auto appendFormatCall = LLVM::CallOp::create( rewriter, loc, ocklAppendStringN, ValueRange{printfDesc, stringStart, stringLen, adaptor.getArgs().empty() ? oneI32 : zeroI32}); printfDesc = appendFormatCall.getResult(); // __ockl_printf_append_args takes 7 values per append call constexpr size_t argsPerAppend = 7; size_t nArgs = adaptor.getArgs().size(); for (size_t group = 0; group < nArgs; group += argsPerAppend) { size_t bound = std::min(group + argsPerAppend, nArgs); size_t numArgsThisCall = bound - group; SmallVector arguments; arguments.push_back(printfDesc); arguments.push_back( LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall)); for (size_t i = group; i < bound; ++i) { Value arg = adaptor.getArgs()[i]; if (auto floatType = dyn_cast(arg.getType())) { if (!floatType.isF64()) arg = LLVM::FPExtOp::create( rewriter, loc, typeConverter->convertType(rewriter.getF64Type()), arg); arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg); } if (arg.getType().getIntOrFloatBitWidth() != 64) arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg); arguments.push_back(arg); } // Pad out to 7 arguments since the hostcall always needs 7 for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) { arguments.push_back(zeroI64); } auto isLast = 
(bound == nArgs) ? oneI32 : zeroI32; arguments.push_back(isLast); auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments); printfDesc = call.getResult(); } rewriter.eraseOp(gpuPrintfOp); return success(); } LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = gpuPrintfOp->getLoc(); mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); // Note: this is the GPUModule op, not the ModuleOp that surrounds it // This ensures that global constants and declarations are placed within // the device code, not the host code auto moduleOp = gpuPrintfOp->getParentOfType(); auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType}, /*isVarArg=*/true); LLVM::LLVMFuncOp printfDecl = getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); // Create the global op or find an existing one. 
LLVM::GlobalOp global = getOrCreateStringConstant( rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(), /*alignment=*/0, addressSpace); // Get a pointer to the format string's first element Value globalPtr = LLVM::AddressOfOp::create( rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), globalPtr, ArrayRef{0, 0}); // Construct arguments and function call auto argsRange = adaptor.getArgs(); SmallVector printfArgs; printfArgs.reserve(argsRange.size() + 1); printfArgs.push_back(stringStart); printfArgs.append(argsRange.begin(), argsRange.end()); LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = gpuPrintfOp->getLoc(); mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); // Note: this is the GPUModule op, not the ModuleOp that surrounds it // This ensures that global constants and declarations are placed within // the device code, not the host code auto moduleOp = gpuPrintfOp->getParentOfType(); // Create a valid global location removing any metadata attached to the // location as debug info metadata inside of a function cannot be used outside // of that function. Location globalLoc = loc->findInstanceOfOrUnknown(); auto vprintfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType}); LLVM::LLVMFuncOp vprintfDecl = getOrDefineFunction( moduleOp, globalLoc, rewriter, "vprintf", vprintfType); // Create the global op or find an existing one. 
LLVM::GlobalOp global = getOrCreateStringConstant(rewriter, globalLoc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global); Value stringStart = LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), globalPtr, ArrayRef{0, 0}); SmallVector types; SmallVector args; // Promote and pack the arguments into a stack allocation. for (Value arg : adaptor.getArgs()) { Type type = arg.getType(); Value promotedArg = arg; assert(type.isIntOrFloat()); if (isa(type)) { type = rewriter.getF64Type(); promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg); } types.push_back(type); args.push_back(promotedArg); } Type structType = LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types); Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), rewriter.getIndexAttr(1)); Value tempAlloc = LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one, /*alignment=*/0); for (auto [index, arg] : llvm::enumerate(args)) { Value ptr = LLVM::GEPOp::create( rewriter, loc, ptrType, structType, tempAlloc, ArrayRef{0, static_cast(index)}); LLVM::StoreOp::create(rewriter, loc, arg, ptr); } std::array printfArgs = {stringStart, tempAlloc}; LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } /// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements. /// Used either directly (for ops on 1D vectors) or as the callback passed to /// detail::handleMultidimensionalVectors (for ops on higher-rank vectors). 
/// Scalarizes an op working on 1-D LLVM vectors: emits one scalar clone of
/// `op` per vector element (extracting each operand element when the operand
/// is itself a vector) and reassembles the results into a vector via
/// insertelement. Returns the resulting vector value.
static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
                                     Type llvm1DVectorTy,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &converter) {
  TypeRange operandTypes(operands);
  VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
  Location loc = op->getLoc();
  // Start from poison and fill in each lane below.
  Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType);
  Type indexType = converter.convertType(rewriter.getIndexType());
  StringAttr name = op->getName().getIdentifier();
  Type elementType = vectorType.getElementType();

  for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
    Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i);
    // Scalar (non-vector) operands are broadcast unchanged to every lane.
    auto extractElement = [&](Value operand) -> Value {
      if (!isa<VectorType>(operand.getType()))
        return operand;
      return LLVM::ExtractElementOp::create(rewriter, loc, operand, index);
    };
    auto scalarOperands = llvm::map_to_vector(operands, extractElement);
    Operation *scalarOp =
        rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
    result = LLVM::InsertElementOp::create(rewriter, loc, result,
                                           scalarOp->getResult(0), index);
  }
  return result;
}

/// Unrolls op to array/vector elements.
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter) { TypeRange operandTypes(operands); if (llvm::any_of(operandTypes, llvm::IsaPred)) { VectorType vectorType = cast(converter.convertType(op->getResultTypes()[0])); rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType, rewriter, converter)); return success(); } if (llvm::any_of(operandTypes, llvm::IsaPred)) { return LLVM::detail::handleMultidimensionalVectors( op, operands, converter, [&](Type llvm1DVectorTy, ValueRange operands) -> Value { return scalarizeVectorOpHelper(op, operands, llvm1DVectorTy, rewriter, converter); }, rewriter); } return rewriter.notifyMatchFailure(op, "no llvm.array or vector to unroll"); } static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) { return IntegerAttr::get(IntegerType::get(ctx, 64), space); } /// Generates a symbol with 0-sized array type for dynamic shared memory usage, /// or uses existing symbol. LLVM::GlobalOp getDynamicSharedMemorySymbol( ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp, gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter, MemRefType memrefType, unsigned alignmentBit) { uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth(); FailureOr addressSpace = typeConverter->getMemRefAddressSpace(memrefType); if (failed(addressSpace)) { op->emitError() << "conversion of memref memory space " << memrefType.getMemorySpace() << " to integer address space " "failed. Consider adding memory space conversions."; } // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of // LLVM::GlobalOp is suitable for shared memory, return it. 
llvm::StringSet<> existingGlobalNames; for (auto globalOp : moduleOp.getBody()->getOps()) { existingGlobalNames.insert(globalOp.getSymName()); if (auto arrayType = dyn_cast(globalOp.getType())) { if (globalOp.getAddrSpace() == addressSpace.value() && arrayType.getNumElements() == 0 && globalOp.getAlignment().value_or(0) == alignmentByte) { return globalOp; } } } // Step 2. Find a unique symbol name unsigned uniquingCounter = 0; SmallString<128> symName = SymbolTable::generateSymbolName<128>( "__dynamic_shmem_", [&](StringRef candidate) { return existingGlobalNames.contains(candidate); }, uniquingCounter); // Step 3. Generate a global op OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); auto zeroSizedArrayType = LLVM::LLVMArrayType::get( typeConverter->convertType(memrefType.getElementType()), 0); return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType, /*isConstant=*/false, LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte, addressSpace.value()); } LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); MemRefType memrefType = op.getResultMemref().getType(); Type elementType = typeConverter->convertType(memrefType.getElementType()); // Step 1: Generate a memref<0xi8> type MemRefLayoutAttrInterface layout = {}; auto memrefType0sz = MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace()); // Step 2: Generate a global symbol or existing for the dynamic shared // memory with memref<0xi8> type auto moduleOp = op->getParentOfType(); LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol( rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit); // Step 3. 
Get address of the global symbol OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp); Type baseType = basePtr->getResultTypes().front(); // Step 4. Generate GEP using offsets SmallVector gepArgs = {0}; Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType, basePtr, gepArgs); // Step 5. Create a memref descriptor SmallVector shape, strides; Value sizeBytes; getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides, sizeBytes); auto memRefDescriptor = this->createMemRefDescriptor( loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter); // Step 5. Replace the op with memref descriptor rewriter.replaceOp(op, {memRefDescriptor}); return success(); } LogicalResult GPUReturnOpLowering::matchAndRewrite( gpu::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); unsigned numArguments = op.getNumOperands(); SmallVector updatedOperands; bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv; if (useBarePtrCallConv) { // For the bare-ptr calling convention, extract the aligned pointer to // be returned from the memref descriptor. for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { Type oldTy = std::get<0>(it).getType(); Value newOperand = std::get<1>(it); if (isa(oldTy) && getTypeConverter()->canConvertToBarePtr( cast(oldTy))) { MemRefDescriptor memrefDesc(newOperand); newOperand = memrefDesc.allocatedPtr(rewriter, loc); } else if (isa(oldTy)) { // Unranked memref is not supported in the bare pointer calling // convention. return failure(); } updatedOperands.push_back(newOperand); } } else { updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), updatedOperands, /*toDynamic=*/true); } // If ReturnOp has 0 or 1 operand, create it and return immediately. 
if (numArguments <= 1) { rewriter.replaceOpWithNewOp( op, TypeRange(), updatedOperands, op->getAttrs()); return success(); } // Otherwise, we need to pack the arguments into an LLVM struct type before // returning. auto packedType = getTypeConverter()->packFunctionResults( op.getOperandTypes(), useBarePtrCallConv); if (!packedType) { return rewriter.notifyMatchFailure(op, "could not convert result types"); } Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType); for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx); } rewriter.replaceOpWithNewOp(op, TypeRange(), packed, op->getAttrs()); return success(); } void mlir::populateGpuMemorySpaceAttributeConversions( TypeConverter &typeConverter, const MemorySpaceMapping &mapping) { typeConverter.addTypeAttributeConversion( [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) { gpu::AddressSpace memorySpace = memorySpaceAttr.getValue(); unsigned addressSpace = mapping(memorySpace); return wrapNumericMemorySpace(memorySpaceAttr.getContext(), addressSpace); }); }