//===-- CUFDeviceGlobal.cpp -----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "flang/Optimizer/Transforms/CUFOpConversion.h" #include "flang/Optimizer/Builder/CUFCommon.h" #include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h" #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" #include "flang/Optimizer/CodeGen/TypeConverter.h" #include "flang/Optimizer/Dialect/CUF/CUFOps.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/Support/DataLayout.h" #include "flang/Runtime/CUDA/allocatable.h" #include "flang/Runtime/CUDA/common.h" #include "flang/Runtime/CUDA/descriptor.h" #include "flang/Runtime/CUDA/memory.h" #include "flang/Runtime/CUDA/pointer.h" #include "flang/Runtime/allocatable.h" #include "flang/Runtime/allocator-registry-consts.h" #include "flang/Support/Fortran.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace fir { #define GEN_PASS_DEF_CUFOPCONVERSION #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir using namespace fir; using namespace mlir; using namespace Fortran::runtime; using namespace Fortran::runtime::cuda; namespace { static inline unsigned getMemType(cuf::DataAttribute attr) { if (attr == cuf::DataAttribute::Device) return kMemTypeDevice; if (attr == cuf::DataAttribute::Managed) return kMemTypeManaged; if (attr == cuf::DataAttribute::Unified) return kMemTypeUnified; if (attr == cuf::DataAttribute::Pinned) return kMemTypePinned; llvm::report_fatal_error("unsupported memory type"); } template static bool isPinned(OpTy op) { if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned) return true; return false; } template static bool hasDoubleDescriptors(OpTy op) { if (auto declareOp = mlir::dyn_cast_or_null(op.getBox().getDefiningOp())) { if (mlir::isa_and_nonnull( declareOp.getMemref().getDefiningOp())) { if (isPinned(declareOp)) return false; return true; } } else if (auto declareOp = mlir::dyn_cast_or_null( op.getBox().getDefiningOp())) { if (mlir::isa_and_nonnull( declareOp.getMemref().getDefiningOp())) { if (isPinned(declareOp)) return false; return true; } } return false; } static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type toTy, mlir::Value val) { if (val.getType() != toTy) return fir::ConvertOp::create(rewriter, loc, toTy, val); return val; } template static mlir::LogicalResult convertOpToCall(OpTy op, mlir::PatternRewriter &rewriter, mlir::func::FuncOp func) { auto mod = op->template getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine; if constexpr (std::is_same_v) sourceLine = fir::factory::locationToLineNo( builder, loc, op.getSource() ? fTy.getInput(7) : fTy.getInput(6)); else sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true) : builder.createBool(loc, false); mlir::Value errmsg; if (op.getErrmsg()) { errmsg = op.getErrmsg(); } else { mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType()); errmsg = fir::AbsentOp::create(builder, loc, boxNoneTy).getResult(); } llvm::SmallVector args; if constexpr (std::is_same_v) { mlir::Value pinned = op.getPinned() ? op.getPinned() : builder.createNullConstant( loc, fir::ReferenceType::get( mlir::IntegerType::get(op.getContext(), 1))); if (op.getSource()) { mlir::Value stream = op.getStream() ? op.getStream() : builder.createNullConstant(loc, fTy.getInput(2)); args = fir::runtime::createArguments( builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned, hasStat, errmsg, sourceFile, sourceLine); } else { mlir::Value stream = op.getStream() ? op.getStream() : builder.createNullConstant(loc, fTy.getInput(1)); args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(), stream, pinned, hasStat, errmsg, sourceFile, sourceLine); } } else { args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat, errmsg, sourceFile, sourceLine); } auto callOp = fir::CallOp::create(builder, loc, func, args); rewriter.replaceOp(op, callOp); return mlir::success(); } struct CUFAllocateOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(cuf::AllocateOp op, mlir::PatternRewriter &rewriter) const override { auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); bool isPointer = false; if (auto declareOp = mlir::dyn_cast_or_null(op.getBox().getDefiningOp())) if (declareOp.getFortranAttrs() && bitEnumContainsAny(*declareOp.getFortranAttrs(), fir::FortranVariableFlagsEnum::pointer)) isPointer = true; if (hasDoubleDescriptors(op)) { // Allocation for module variable are done with custom runtime entry point // so the descriptors can be synchronized. mlir::func::FuncOp func; if (op.getSource()) { func = isPointer ? fir::runtime::getRuntimeFunc(loc, builder) : fir::runtime::getRuntimeFunc(loc, builder); } else { func = isPointer ? fir::runtime::getRuntimeFunc( loc, builder) : fir::runtime::getRuntimeFunc(loc, builder); } return convertOpToCall(op, rewriter, func); } mlir::func::FuncOp func; if (op.getSource()) { func = isPointer ? fir::runtime::getRuntimeFunc( loc, builder) : fir::runtime::getRuntimeFunc(loc, builder); } else { func = isPointer ? fir::runtime::getRuntimeFunc( loc, builder) : fir::runtime::getRuntimeFunc( loc, builder); } return convertOpToCall(op, rewriter, func); } }; struct CUFDeallocateOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(cuf::DeallocateOp op, mlir::PatternRewriter &rewriter) const override { auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); if (hasDoubleDescriptors(op)) { // Deallocation for module variable are done with custom runtime entry // point so the descriptors can be synchronized. mlir::func::FuncOp func = fir::runtime::getRuntimeFunc( loc, builder); return convertOpToCall(op, rewriter, func); } // Deallocation for local descriptor falls back on the standard runtime // AllocatableDeallocate as the dedicated deallocator is set in the // descriptor before the call. mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); return convertOpToCall(op, rewriter, func); } }; static bool inDeviceContext(mlir::Operation *op) { if (op->getParentOfType()) return true; if (auto funcOp = op->getParentOfType()) return true; if (auto funcOp = op->getParentOfType()) return true; if (auto funcOp = op->getParentOfType()) { if (auto cudaProcAttr = funcOp.getOperation()->getAttrOfType( cuf::getProcAttrName())) { return cudaProcAttr.getValue() != cuf::ProcAttribute::Host && cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice; } } return false; } static int computeWidth(mlir::Location loc, mlir::Type type, fir::KindMapping &kindMap) { auto eleTy = fir::unwrapSequenceType(type); if (auto t{mlir::dyn_cast(eleTy)}) return t.getWidth() / 8; if (auto t{mlir::dyn_cast(eleTy)}) return t.getWidth() / 8; if (eleTy.isInteger(1)) return 1; if (auto t{mlir::dyn_cast(eleTy)}) return kindMap.getLogicalBitsize(t.getFKind()) / 8; if (auto t{mlir::dyn_cast(eleTy)}) { int elemSize = mlir::cast(t.getElementType()).getWidth() / 8; return 2 * elemSize; } if (auto t{mlir::dyn_cast_or_null(eleTy)}) return kindMap.getCharacterBitsize(t.getFKind()) / 8; mlir::emitError(loc, "unsupported type"); return 0; } struct CUFAllocOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl, const fir::LLVMTypeConverter *typeConverter) : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {} mlir::LogicalResult matchAndRewrite(cuf::AllocOp op, mlir::PatternRewriter &rewriter) const override { mlir::Location loc = op.getLoc(); if (inDeviceContext(op.getOperation())) { // In device context just replace the cuf.alloc operation with a fir.alloc // the cuf.free will be removed. auto allocaOp = fir::AllocaOp::create(rewriter, loc, op.getInType(), op.getUniqName() ? *op.getUniqName() : "", op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(), op.getShape()); allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); rewriter.replaceOp(op, allocaOp); return mlir::success(); } auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); if (!mlir::dyn_cast_or_null(op.getInType())) { // Convert scalar and known size array allocations. mlir::Value bytes; fir::KindMapping kindMap{fir::getKindMapping(mod)}; if (fir::isa_trivial(op.getInType())) { int width = computeWidth(loc, op.getInType(), kindMap); bytes = builder.createIntegerConstant(loc, builder.getIndexType(), width); } else if (auto seqTy = mlir::dyn_cast_or_null( op.getInType())) { std::size_t size = 0; if (fir::isa_derived(seqTy.getEleTy())) { mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy()); size = dl->getTypeSizeInBits(structTy) / 8; } else { size = computeWidth(loc, seqTy.getEleTy(), kindMap); } mlir::Value width = builder.createIntegerConstant(loc, builder.getIndexType(), size); mlir::Value nbElem; if (fir::sequenceWithNonConstantShape(seqTy)) { assert(!op.getShape().empty() && "expect shape with dynamic arrays"); nbElem = builder.loadIfRef(loc, op.getShape()[0]); for (unsigned i = 1; i < op.getShape().size(); ++i) { nbElem = mlir::arith::MulIOp::create( rewriter, loc, nbElem, builder.loadIfRef(loc, op.getShape()[i])); } } else { nbElem = builder.createIntegerConstant(loc, builder.getIndexType(), seqTy.getConstantArraySize()); } bytes = mlir::arith::MulIOp::create(rewriter, loc, nbElem, width); } else if (fir::isa_derived(op.getInType())) { mlir::Type structTy = typeConverter->convertType(op.getInType()); std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8; bytes = builder.createIntegerConstant(loc, builder.getIndexType(), structSize); } else { mlir::emitError(loc, "unsupported type in cuf.alloc\n"); } mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); mlir::Value memTy = builder.createIntegerConstant( loc, builder.getI32Type(), getMemType(op.getDataAttr())); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)}; auto callOp = fir::CallOp::create(builder, loc, func, args); callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); auto convOp = builder.createConvert(loc, op.getResult().getType(), callOp.getResult(0)); rewriter.replaceOp(op, convOp); return mlir::success(); } // Convert descriptor allocations to function call. auto boxTy = mlir::dyn_cast_or_null(op.getInType()); mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy); std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; mlir::Value sizeInBytes = builder.createIntegerConstant(loc, builder.getIndexType(), boxSize); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)}; auto callOp = fir::CallOp::create(builder, loc, func, args); callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); auto convOp = builder.createConvert(loc, op.getResult().getType(), callOp.getResult(0)); rewriter.replaceOp(op, convOp); return mlir::success(); } private: mlir::DataLayout *dl; const fir::LLVMTypeConverter *typeConverter; }; struct CUFDeviceAddressOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; CUFDeviceAddressOpConversion(mlir::MLIRContext *context, const mlir::SymbolTable &symtab) : OpRewritePattern(context), symTab{symtab} {} mlir::LogicalResult matchAndRewrite(cuf::DeviceAddressOp op, mlir::PatternRewriter &rewriter) const override { if (auto global = symTab.lookup( op.getHostSymbol().getRootReference().getValue())) { auto mod = op->getParentOfType(); mlir::Location loc = op.getLoc(); auto hostAddr = fir::AddrOfOp::create( rewriter, loc, fir::ReferenceType::get(global.getType()), op.getHostSymbol()); fir::FirOpBuilder builder(rewriter, mod); mlir::func::FuncOp callee = fir::runtime::getRuntimeFunc(loc, builder); auto fTy = callee.getFunctionType(); mlir::Value conv = createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, conv, sourceFile, sourceLine)}; auto call = fir::CallOp::create(rewriter, loc, callee, args); mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(), call->getResult(0)); rewriter.replaceOp(op, addr.getDefiningOp()); return success(); } return failure(); } private: const mlir::SymbolTable &symTab; }; struct DeclareOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; DeclareOpConversion(mlir::MLIRContext *context, const mlir::SymbolTable &symtab) : OpRewritePattern(context), symTab{symtab} {} mlir::LogicalResult matchAndRewrite(fir::DeclareOp op, mlir::PatternRewriter &rewriter) const override { if (auto addrOfOp = op.getMemref().getDefiningOp()) { if (auto global = symTab.lookup( addrOfOp.getSymbol().getRootReference().getValue())) { if (cuf::isRegisteredDeviceGlobal(global)) { rewriter.setInsertionPointAfter(addrOfOp); mlir::Value devAddr = cuf::DeviceAddressOp::create( rewriter, op.getLoc(), addrOfOp.getType(), addrOfOp.getSymbol()); rewriter.startOpModification(op); op.getMemrefMutable().assign(devAddr); rewriter.finalizeOpModification(op); return success(); } } } return failure(); } private: const mlir::SymbolTable &symTab; }; struct CUFFreeOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(cuf::FreeOp op, mlir::PatternRewriter &rewriter) const override { if (inDeviceContext(op.getOperation())) { rewriter.eraseOp(op); return mlir::success(); } if (!mlir::isa(op.getDevptr().getType())) return failure(); auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); auto refTy = mlir::dyn_cast(op.getDevptr().getType()); if (!mlir::isa(refTy.getEleTy())) { mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); mlir::Value memTy = builder.createIntegerConstant( loc, builder.getI32Type(), getMemType(op.getDataAttr())); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)}; fir::CallOp::create(builder, loc, func, args); rewriter.eraseOp(op); return mlir::success(); } // Convert cuf.free on descriptors. mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)}; auto callOp = fir::CallOp::create(builder, loc, func, args); callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); rewriter.eraseOp(op); return mlir::success(); } }; static bool isDstGlobal(cuf::DataTransferOp op) { if (auto declareOp = op.getDst().getDefiningOp()) if (declareOp.getMemref().getDefiningOp()) return true; if (auto declareOp = op.getDst().getDefiningOp()) if (declareOp.getMemref().getDefiningOp()) return true; return false; } static mlir::Value getShapeFromDecl(mlir::Value src) { if (auto declareOp = src.getDefiningOp()) return declareOp.getShape(); if (auto declareOp = src.getDefiningOp()) return declareOp.getShape(); return mlir::Value{}; } static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter, cuf::DataTransferOp op, const mlir::SymbolTable &symtab, mlir::Type dstEleTy = nullptr) { auto mod = op->getParentOfType(); mlir::Location loc = op.getLoc(); fir::FirOpBuilder builder(rewriter, mod); mlir::Value addr; mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType()); if (fir::isa_trivial(srcTy) && mlir::matchPattern(op.getSrc().getDefiningOp(), mlir::m_Constant())) { mlir::Value src = op.getSrc(); if (srcTy.isInteger(1)) { // i1 is not a supported type in the descriptor and it is actually coming // from a LOGICAL constant. Store it as a fir.logical. srcTy = fir::LogicalType::get(rewriter.getContext(), 4); src = createConvertOp(rewriter, loc, srcTy, src); addr = builder.createTemporary(loc, srcTy); fir::StoreOp::create(builder, loc, src, addr); } else { if (dstEleTy && fir::isa_trivial(dstEleTy) && srcTy != dstEleTy) { // Use dstEleTy and convert to avoid assign mismatch. addr = builder.createTemporary(loc, dstEleTy); auto conv = fir::ConvertOp::create(builder, loc, dstEleTy, src); fir::StoreOp::create(builder, loc, conv, addr); srcTy = dstEleTy; } else { // Put constant in memory if it is not. addr = builder.createTemporary(loc, srcTy); fir::StoreOp::create(builder, loc, src, addr); } } } else { addr = op.getSrc(); } llvm::SmallVector lenParams; mlir::Type boxTy = fir::BoxType::get(srcTy); mlir::Value box = builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getSrc()), /*slice=*/nullptr, lenParams, /*tdesc=*/nullptr); mlir::Value src = builder.createTemporary(loc, box.getType()); fir::StoreOp::create(builder, loc, box, src); return src; } static mlir::Value emboxDst(mlir::PatternRewriter &rewriter, cuf::DataTransferOp op, const mlir::SymbolTable &symtab) { auto mod = op->getParentOfType(); mlir::Location loc = op.getLoc(); fir::FirOpBuilder builder(rewriter, mod); mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType()); mlir::Value dstAddr = op.getDst(); mlir::Type dstBoxTy = fir::BoxType::get(dstTy); llvm::SmallVector lenParams; mlir::Value dstBox = builder.createBox(loc, dstBoxTy, dstAddr, getShapeFromDecl(op.getDst()), /*slice=*/nullptr, lenParams, /*tdesc=*/nullptr); mlir::Value dst = builder.createTemporary(loc, dstBox.getType()); fir::StoreOp::create(builder, loc, dstBox, dst); return dst; } struct CUFDataTransferOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; CUFDataTransferOpConversion(mlir::MLIRContext *context, const mlir::SymbolTable &symtab, mlir::DataLayout *dl, const fir::LLVMTypeConverter *typeConverter) : OpRewritePattern(context), symtab{symtab}, dl{dl}, typeConverter{typeConverter} {} mlir::LogicalResult matchAndRewrite(cuf::DataTransferOp op, mlir::PatternRewriter &rewriter) const override { mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType()); mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType()); mlir::Location loc = op.getLoc(); unsigned mode = 0; if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) { mode = kHostToDevice; } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) { mode = kDeviceToHost; } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceDevice) { mode = kDeviceToDevice; } else { mlir::emitError(loc, "unsupported transfer kind\n"); } auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); fir::KindMapping kindMap{fir::getKindMapping(mod)}; mlir::Value modeValue = builder.createIntegerConstant(loc, builder.getI32Type(), mode); // Convert data transfer without any descriptor. if (!mlir::isa(srcTy) && !mlir::isa(dstTy)) { if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) { // Initialization of an array from a scalar value should be implemented // via a kernel launch. Use the flan runtime via the Assign function // until we have more infrastructure. mlir::Value src = emboxSrc(rewriter, op, symtab); mlir::Value dst = emboxDst(rewriter, op, symtab); mlir::func::FuncOp func = fir::runtime::getRuntimeFunc( loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)}; fir::CallOp::create(builder, loc, func, args); rewriter.eraseOp(op); return mlir::success(); } mlir::Type i64Ty = builder.getI64Type(); mlir::Value nbElement; if (op.getShape()) { llvm::SmallVector extents; if (auto shapeOp = mlir::dyn_cast(op.getShape().getDefiningOp())) { extents = shapeOp.getExtents(); } else if (auto shapeShiftOp = mlir::dyn_cast( op.getShape().getDefiningOp())) { for (auto i : llvm::enumerate(shapeShiftOp.getPairs())) if (i.index() & 1) extents.push_back(i.value()); } nbElement = fir::ConvertOp::create(rewriter, loc, i64Ty, extents[0]); for (unsigned i = 1; i < extents.size(); ++i) { auto operand = fir::ConvertOp::create(rewriter, loc, i64Ty, extents[i]); nbElement = mlir::arith::MulIOp::create(rewriter, loc, nbElement, operand); } } else { if (auto seqTy = mlir::dyn_cast_or_null(dstTy)) nbElement = builder.createIntegerConstant( loc, i64Ty, seqTy.getConstantArraySize()); } unsigned width = 0; if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) { mlir::Type structTy = typeConverter->convertType(fir::unwrapSequenceType(dstTy)); width = dl->getTypeSizeInBits(structTy) / 8; } else { width = computeWidth(loc, dstTy, kindMap); } mlir::Value widthValue = mlir::arith::ConstantOp::create( rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width)); mlir::Value bytes = nbElement ? mlir::arith::MulIOp::create( rewriter, loc, nbElement, widthValue) : widthValue; mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(5)); mlir::Value dst = op.getDst(); mlir::Value src = op.getSrc(); // Materialize the src if constant. if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) { mlir::Value temp = builder.createTemporary(loc, srcTy); fir::StoreOp::create(builder, loc, src, temp); src = temp; } llvm::SmallVector args{ fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes, modeValue, sourceFile, sourceLine)}; fir::CallOp::create(builder, loc, func, args); rewriter.eraseOp(op); return mlir::success(); } auto materializeBoxIfNeeded = [&](mlir::Value val) -> mlir::Value { if (mlir::isa(val.getDefiningOp())) { // Materialize the box to memory to be able to call the runtime. mlir::Value box = builder.createTemporary(loc, val.getType()); fir::StoreOp::create(builder, loc, val, box); return box; } return val; }; // Conversion of data transfer involving at least one descriptor. if (auto dstBoxTy = mlir::dyn_cast(dstTy)) { // Transfer to a descriptor. mlir::func::FuncOp func = isDstGlobal(op) ? fir::runtime::getRuntimeFunc(loc, builder) : fir::runtime::getRuntimeFunc( loc, builder); mlir::Value dst = op.getDst(); mlir::Value src = op.getSrc(); if (!mlir::isa(srcTy)) { mlir::Type dstEleTy = fir::unwrapInnerType(dstBoxTy.getEleTy()); src = emboxSrc(rewriter, op, symtab, dstEleTy); if (fir::isa_trivial(srcTy)) func = fir::runtime::getRuntimeFunc( loc, builder); } src = materializeBoxIfNeeded(src); dst = materializeBoxIfNeeded(dst); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)}; fir::CallOp::create(builder, loc, func, args); rewriter.eraseOp(op); } else { // Transfer from a descriptor. mlir::Value dst = emboxDst(rewriter, op, symtab); mlir::Value src = materializeBoxIfNeeded(op.getSrc()); mlir::func::FuncOp func = fir::runtime::getRuntimeFunc(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); llvm::SmallVector args{fir::runtime::createArguments( builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)}; fir::CallOp::create(builder, loc, func, args); rewriter.eraseOp(op); } return mlir::success(); } private: const mlir::SymbolTable &symtab; mlir::DataLayout *dl; const fir::LLVMTypeConverter *typeConverter; }; struct CUFLaunchOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; CUFLaunchOpConversion(mlir::MLIRContext *context, const mlir::SymbolTable &symTab) : OpRewritePattern(context), symTab{symTab} {} mlir::LogicalResult matchAndRewrite(cuf::KernelLaunchOp op, mlir::PatternRewriter &rewriter) const override { mlir::Location loc = op.getLoc(); auto idxTy = mlir::IndexType::get(op.getContext()); mlir::Value zero = mlir::arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0)); auto gridSizeX = mlir::arith::IndexCastOp::create(rewriter, loc, idxTy, op.getGridX()); auto gridSizeY = mlir::arith::IndexCastOp::create(rewriter, loc, idxTy, op.getGridY()); auto gridSizeZ = mlir::arith::IndexCastOp::create(rewriter, loc, idxTy, op.getGridZ()); auto blockSizeX = mlir::arith::IndexCastOp::create(rewriter, loc, idxTy, op.getBlockX()); auto blockSizeY = mlir::arith::IndexCastOp::create(rewriter, loc, idxTy, op.getBlockY()); auto blockSizeZ = mlir::arith::IndexCastOp::create(rewriter, loc, idxTy, op.getBlockZ()); auto kernelName = mlir::SymbolRefAttr::get( rewriter.getStringAttr(cudaDeviceModuleName), {mlir::SymbolRefAttr::get( rewriter.getContext(), op.getCallee().getLeafReference().getValue())}); mlir::Value clusterDimX, clusterDimY, clusterDimZ; cuf::ProcAttributeAttr procAttr; if (auto funcOp = symTab.lookup( op.getCallee().getLeafReference())) { if (auto clusterDimsAttr = funcOp->getAttrOfType( cuf::getClusterDimsAttrName())) { clusterDimX = mlir::arith::ConstantIndexOp::create( rewriter, loc, clusterDimsAttr.getX().getInt()); clusterDimY = mlir::arith::ConstantIndexOp::create( rewriter, loc, clusterDimsAttr.getY().getInt()); clusterDimZ = mlir::arith::ConstantIndexOp::create( rewriter, loc, clusterDimsAttr.getZ().getInt()); } procAttr = funcOp->getAttrOfType(cuf::getProcAttrName()); } llvm::SmallVector args; for (mlir::Value arg : op.getArgs()) { // If the argument is a global descriptor, make sure we pass the device // copy of this descriptor and not the host one. if (mlir::isa(fir::unwrapRefType(arg.getType()))) { if (auto declareOp = mlir::dyn_cast_or_null(arg.getDefiningOp())) { if (auto addrOfOp = mlir::dyn_cast_or_null( declareOp.getMemref().getDefiningOp())) { if (auto global = symTab.lookup( addrOfOp.getSymbol().getRootReference().getValue())) { if (cuf::isRegisteredDeviceGlobal(global)) { arg = rewriter .create(op.getLoc(), addrOfOp.getType(), addrOfOp.getSymbol()) .getResult(); } } } } } args.push_back(arg); } mlir::Value dynamicShmemSize = op.getBytes() ? op.getBytes() : zero; auto gpuLaunchOp = mlir::gpu::LaunchFuncOp::create( rewriter, loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ}, mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, dynamicShmemSize, args); if (clusterDimX && clusterDimY && clusterDimZ) { gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX); gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY); gpuLaunchOp.getClusterSizeZMutable().assign(clusterDimZ); } if (op.getStream()) { mlir::OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(gpuLaunchOp); mlir::Value stream = cuf::StreamCastOp::create(rewriter, loc, op.getStream()); gpuLaunchOp.getAsyncDependenciesMutable().append(stream); } if (procAttr) gpuLaunchOp->setAttr(cuf::getProcAttrName(), procAttr); else // Set default global attribute of the original was not found. gpuLaunchOp->setAttr(cuf::getProcAttrName(), cuf::ProcAttributeAttr::get( op.getContext(), cuf::ProcAttribute::Global)); rewriter.replaceOp(op, gpuLaunchOp); return mlir::success(); } private: const mlir::SymbolTable &symTab; }; struct CUFSyncDescriptorOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(cuf::SyncDescriptorOp op, mlir::PatternRewriter &rewriter) const override { auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); auto globalOp = mod.lookupSymbol(op.getGlobalName()); if (!globalOp) return mlir::failure(); auto hostAddr = fir::AddrOfOp::create( builder, loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName()); fir::runtime::cuda::genSyncGlobalDescriptor(builder, loc, hostAddr); op.erase(); return mlir::success(); } }; struct CUFSetAllocatorIndexOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(cuf::SetAllocatorIndexOp op, mlir::PatternRewriter &rewriter) const override { auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); int idx = kDefaultAllocator; if (op.getDataAttr() == cuf::DataAttribute::Device) { idx = kDeviceAllocatorPos; } else if (op.getDataAttr() == cuf::DataAttribute::Managed) { idx = kManagedAllocatorPos; } else if (op.getDataAttr() == cuf::DataAttribute::Unified) { idx = kUnifiedAllocatorPos; } else if (op.getDataAttr() == cuf::DataAttribute::Pinned) { idx = kPinnedAllocatorPos; } mlir::Value index = builder.createIntegerConstant(loc, builder.getI32Type(), idx); fir::runtime::cuda::genSetAllocatorIndex(builder, loc, op.getBox(), index); op.erase(); return mlir::success(); } }; class CUFOpConversion : public fir::impl::CUFOpConversionBase { public: void runOnOperation() override { auto *ctx = &getContext(); mlir::RewritePatternSet patterns(ctx); mlir::ConversionTarget target(*ctx); mlir::Operation *op = getOperation(); mlir::ModuleOp module = mlir::dyn_cast(op); if (!module) return signalPassFailure(); mlir::SymbolTable symtab(module); std::optional dl = fir::support::getOrSetMLIRDataLayout( module, /*allowDefaultLayout=*/false); fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false, /*forceUnifiedTBAATree=*/false, *dl); target.addLegalDialect(); target.addLegalOp(); cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab, patterns); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) { mlir::emitError(mlir::UnknownLoc::get(ctx), "error in CUF op conversion\n"); signalPassFailure(); } target.addDynamicallyLegalOp([&](fir::DeclareOp op) { if (inDeviceContext(op)) return true; if (auto addrOfOp = op.getMemref().getDefiningOp()) { if (auto global = symtab.lookup( addrOfOp.getSymbol().getRootReference().getValue())) { if (mlir::isa(fir::unwrapRefType(global.getType()))) return true; if (cuf::isRegisteredDeviceGlobal(global)) return false; } } return true; }); patterns.clear(); cuf::populateFIRCUFConversionPatterns(symtab, patterns); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) { mlir::emitError(mlir::UnknownLoc::get(ctx), "error in CUF op conversion\n"); signalPassFailure(); } } }; } // namespace void cuf::populateCUFToFIRConversionPatterns( const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl, const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) { patterns.insert(patterns.getContext(), &dl, &converter); patterns.insert(patterns.getContext()); patterns.insert(patterns.getContext(), symtab, &dl, &converter); patterns.insert( patterns.getContext(), symtab); } void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) { patterns.insert( patterns.getContext(), symtab); }