diff options
Diffstat (limited to 'flang/lib/Optimizer')
96 files changed, 11253 insertions, 3149 deletions
diff --git a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp index 73ddd1f..0eb00e2 100644 --- a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp +++ b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "flang/Optimizer/Analysis/AliasAnalysis.h" +#include "flang/Optimizer/Dialect/CUF/CUFOps.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Optimizer/Dialect/FIRType.h" @@ -21,12 +22,38 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" using namespace mlir; #define DEBUG_TYPE "fir-alias-analysis" +llvm::cl::opt<bool> supportCrayPointers( + "unsafe-cray-pointers", + llvm::cl::desc("Support Cray POINTERs that ALIAS with non-TARGET data"), + llvm::cl::init(false)); + +// Inspect for value-scoped Allocate effects and determine whether +// 'candidate' is a new allocation. 
Returns SourceKind::Allocate if a +// MemAlloc effect is attached +static fir::AliasAnalysis::SourceKind +classifyAllocateFromEffects(mlir::Operation *op, mlir::Value candidate) { + if (!op) + return fir::AliasAnalysis::SourceKind::Unknown; + auto interface = llvm::dyn_cast<mlir::MemoryEffectOpInterface>(op); + if (!interface) + return fir::AliasAnalysis::SourceKind::Unknown; + llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 4> effects; + interface.getEffects(effects); + for (mlir::MemoryEffects::EffectInstance &e : effects) { + if (mlir::isa<mlir::MemoryEffects::Allocate>(e.getEffect()) && + e.getValue() && e.getValue() == candidate) + return fir::AliasAnalysis::SourceKind::Allocate; + } + return fir::AliasAnalysis::SourceKind::Unknown; +} + //===----------------------------------------------------------------------===// // AliasAnalysis: alias //===----------------------------------------------------------------------===// @@ -40,15 +67,28 @@ getAttrsFromVariable(fir::FortranVariableOpInterface var) { attrs.set(fir::AliasAnalysis::Attribute::Pointer); if (var.isIntentIn()) attrs.set(fir::AliasAnalysis::Attribute::IntentIn); + if (var.isCrayPointer()) + attrs.set(fir::AliasAnalysis::Attribute::CrayPointer); + if (var.isCrayPointee()) + attrs.set(fir::AliasAnalysis::Attribute::CrayPointee); return attrs; } -static bool hasGlobalOpTargetAttr(mlir::Value v, fir::AddrOfOp op) { - auto globalOpName = - mlir::OperationName(fir::GlobalOp::getOperationName(), op->getContext()); - return fir::valueHasFirAttribute( - v, fir::GlobalOp::getTargetAttrName(globalOpName)); +bool fir::AliasAnalysis::symbolMayHaveTargetAttr(mlir::SymbolRefAttr symbol, + mlir::Operation *from) { + assert(from); + + // If we cannot find the nearest SymbolTable assume the worst. 
+ const mlir::SymbolTable *symTab = getNearestSymbolTable(from); + if (!symTab) + return true; + + if (auto globalOp = symTab->lookup<fir::GlobalOp>(symbol.getLeafReference())) + return globalOp.getTarget().value_or(false); + + // If the symbol is not defined by fir.global assume the worst. + return true; } static bool isEvaluateInMemoryBlockArg(mlir::Value v) { @@ -118,6 +158,18 @@ bool AliasAnalysis::Source::isPointer() const { return attributes.test(Attribute::Pointer); } +bool AliasAnalysis::Source::isCrayPointee() const { + return attributes.test(Attribute::CrayPointee); +} + +bool AliasAnalysis::Source::isCrayPointer() const { + return attributes.test(Attribute::CrayPointer); +} + +bool AliasAnalysis::Source::isCrayPointerOrPointee() const { + return isCrayPointer() || isCrayPointee(); +} + bool AliasAnalysis::Source::isDummyArgument() const { if (auto v = origin.u.dyn_cast<mlir::Value>()) { return fir::isDummyArgument(v); @@ -175,6 +227,34 @@ bool AliasAnalysis::Source::mayBeActualArgWithPtr( return false; } +// Return true if the two locations cannot alias based +// on the access data type, e.g. an address of a descriptor +// cannot alias with an address of data (unless the data +// may contain a descriptor). +static bool noAliasBasedOnType(mlir::Value lhs, mlir::Value rhs) { + mlir::Type lhsType = lhs.getType(); + mlir::Type rhsType = rhs.getType(); + if (!fir::isa_ref_type(lhsType) || !fir::isa_ref_type(rhsType)) + return false; + mlir::Type lhsElemType = fir::unwrapRefType(lhsType); + mlir::Type rhsElemType = fir::unwrapRefType(rhsType); + if (mlir::isa<fir::BaseBoxType>(lhsElemType) != + mlir::isa<fir::BaseBoxType>(rhsElemType)) { + // One of the types is fir.box and another is not. 
+ mlir::Type nonBoxType; + if (mlir::isa<fir::BaseBoxType>(lhsElemType)) + nonBoxType = rhsElemType; + else + nonBoxType = lhsElemType; + + if (!fir::isRecordWithDescriptorMember(nonBoxType)) { + LLVM_DEBUG(llvm::dbgs() << " no alias based on the access types\n"); + return true; + } + } + return false; +} + AliasResult AliasAnalysis::alias(mlir::Value lhs, mlir::Value rhs) { // A wrapper around alias(Source lhsSrc, Source rhsSrc, mlir::Value lhs, // mlir::Value rhs) This allows a user to provide Source that may be obtained @@ -196,6 +276,10 @@ AliasResult AliasAnalysis::alias(Source lhsSrc, Source rhsSrc, mlir::Value lhs, llvm::dbgs() << " rhs: " << rhs << "\n"; llvm::dbgs() << " rhsSrc: " << rhsSrc << "\n";); + // Disambiguate data and descriptors addresses. + if (noAliasBasedOnType(lhs, rhs)) + return AliasResult::NoAlias; + // Indirect case currently not handled. Conservatively assume // it aliases with everything if (lhsSrc.kind >= SourceKind::Indirect || @@ -204,6 +288,15 @@ AliasResult AliasAnalysis::alias(Source lhsSrc, Source rhsSrc, mlir::Value lhs, return AliasResult::MayAlias; } + // Cray pointers/pointees can alias with anything via LOC. + if (supportCrayPointers) { + if (lhsSrc.isCrayPointerOrPointee() || rhsSrc.isCrayPointerOrPointee()) { + LLVM_DEBUG(llvm::dbgs() + << " aliasing because of Cray pointer/pointee\n"); + return AliasResult::MayAlias; + } + } + if (lhsSrc.kind == rhsSrc.kind) { // If the kinds and origins are the same, then lhs and rhs must alias unless // either source is approximate. Approximate sources are for parts of the @@ -214,6 +307,17 @@ AliasResult AliasAnalysis::alias(Source lhsSrc, Source rhsSrc, mlir::Value lhs, << " aliasing because same source kind and origin\n"); if (approximateSource) return AliasResult::MayAlias; + // One should be careful about relying on MustAlias. + // The LLVM definition implies that the two MustAlias + // memory objects start at exactly the same location. 
+ // With Fortran array slices two objects may have + // the same starting location, but otherwise represent + // partially overlapping memory locations, e.g.: + // integer :: a(10) + // ... a(5:1:-1) ! starts at a(5) and addresses a(5), ..., a(1) + // ... a(5:10:1) ! starts at a(5) and addresses a(5), ..., a(10) + // The current implementation of FIR alias analysis will always + // return MayAlias for such cases. return AliasResult::MustAlias; } // If one value is the address of a composite, and if the other value is the @@ -287,6 +391,12 @@ AliasResult AliasAnalysis::alias(Source lhsSrc, Source rhsSrc, mlir::Value lhs, // of non-data is included below. if (src1->isTargetOrPointer() && src2->isTargetOrPointer() && src1->isData() && src2->isData()) { + // Two distinct TARGET globals may not alias. + if (!src1->isPointer() && !src2->isPointer() && + src1->kind == SourceKind::Global && src2->kind == SourceKind::Global && + src1->origin.u != src2->origin.u) { + return AliasResult::NoAlias; + } LLVM_DEBUG(llvm::dbgs() << " aliasing because of target or pointer\n"); return AliasResult::MayAlias; } @@ -400,7 +510,8 @@ static ModRefResult getCallModRef(fir::CallOp call, mlir::Value var) { // TODO: limit to Fortran functions?? // 1. Detect variables that can be accessed indirectly. fir::AliasAnalysis aliasAnalysis; - fir::AliasAnalysis::Source varSrc = aliasAnalysis.getSource(var); + fir::AliasAnalysis::Source varSrc = + aliasAnalysis.getSource(var, /*getLastInstantiationPoint=*/true); // If the variable is not a user variable, we cannot safely assume that // Fortran semantics apply (e.g., a bare alloca/allocmem result may very well // be placed in an allocatable/pointer descriptor and escape). @@ -430,6 +541,7 @@ static ModRefResult getCallModRef(fir::CallOp call, mlir::Value var) { // At that stage, it has been ruled out that local (including the saved ones) // and dummy cannot be indirectly accessed in the call. 
if (varSrc.kind != fir::AliasAnalysis::SourceKind::Allocate && + varSrc.kind != fir::AliasAnalysis::SourceKind::Argument && !varSrc.isDummyArgument()) { if (varSrc.kind != fir::AliasAnalysis::SourceKind::Global || !isSavedLocal(varSrc)) @@ -450,25 +562,43 @@ static ModRefResult getCallModRef(fir::CallOp call, mlir::Value var) { return ModRefResult::getNoModRef(); } -/// This is mostly inspired by MLIR::LocalAliasAnalysis with 2 notable -/// differences 1) Regions are not handled here but will be handled by a data -/// flow analysis to come 2) Allocate and Free effects are considered -/// modifying +/// This is mostly inspired by MLIR::LocalAliasAnalysis, except that +/// fir.call's are handled in a special way. ModRefResult AliasAnalysis::getModRef(Operation *op, Value location) { - MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); - if (!interface) { - if (auto call = llvm::dyn_cast<fir::CallOp>(op)) - return getCallModRef(call, location); - return ModRefResult::getModAndRef(); - } + if (auto call = llvm::dyn_cast<fir::CallOp>(op)) + return getCallModRef(call, location); // Build a ModRefResult by merging the behavior of the effects of this // operation. + ModRefResult result = ModRefResult::getNoModRef(); + MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); + if (op->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) { + for (mlir::Region ®ion : op->getRegions()) { + result = result.merge(getModRef(region, location)); + if (result.isModAndRef()) + break; + } + + // In MLIR, RecursiveMemoryEffects can be combined with + // MemoryEffectOpInterface to describe extra effects on top of the + // effects of the nested operations. However, the presence of + // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface + // implies the operation has no other memory effects than the one of its + // nested operations. 
+ if (!interface) + return result; + } + + if (!interface || result.isModAndRef()) + return ModRefResult::getModAndRef(); + SmallVector<MemoryEffects::EffectInstance> effects; interface.getEffects(effects); - ModRefResult result = ModRefResult::getNoModRef(); for (const MemoryEffects::EffectInstance &effect : effects) { + // MemAlloc and MemFree are not mod-ref effects. + if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect())) + continue; // Check for an alias between the effect and our memory location. AliasResult aliasResult = AliasResult::MayAlias; @@ -495,22 +625,6 @@ ModRefResult AliasAnalysis::getModRef(mlir::Region ®ion, mlir::Value location) { ModRefResult result = ModRefResult::getNoModRef(); for (mlir::Operation &op : region.getOps()) { - if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) { - for (mlir::Region &subRegion : op.getRegions()) { - result = result.merge(getModRef(subRegion, location)); - // Fast return is already mod and ref. - if (result.isModAndRef()) - return result; - } - // In MLIR, RecursiveMemoryEffects can be combined with - // MemoryEffectOpInterface to describe extra effects on top of the - // effects of the nested operations. However, the presence of - // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface - // implies the operation has no other memory effects than the one of its - // nested operations. - if (!mlir::isa<mlir::MemoryEffectOpInterface>(op)) - continue; - } result = result.merge(getModRef(&op, location)); if (result.isModAndRef()) return result; @@ -534,13 +648,28 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v, Source::Attributes attributes; mlir::Operation *instantiationPoint{nullptr}; while (defOp && !breakFromLoop) { - ty = defOp->getResultTypes()[0]; + // Value-scoped allocation detection via effects. 
+ if (classifyAllocateFromEffects(defOp, v) == SourceKind::Allocate) { + type = SourceKind::Allocate; + break; + } + // Operations may have multiple results, so we need to analyze + // the result for which the source is queried. + auto opResult = mlir::cast<OpResult>(v); + assert(opResult.getOwner() == defOp && "v must be a result of defOp"); + ty = opResult.getType(); llvm::TypeSwitch<Operation *>(defOp) - .Case<hlfir::AsExprOp>([&](auto op) { + .Case([&](hlfir::AsExprOp op) { + // TODO: we should probably always report hlfir.as_expr + // as a unique source, and let the codegen decide whether + // to use the original buffer or create a copy. v = op.getVar(); defOp = v.getDefiningOp(); }) - .Case<hlfir::AssociateOp>([&](auto op) { + .Case([&](hlfir::AssociateOp op) { + assert(opResult != op.getMustFreeStrorageFlag() && + "MustFreeStorageFlag result is not an aliasing candidate"); + mlir::Value source = op.getSource(); if (fir::isa_trivial(source.getType())) { // Trivial values will always use distinct temp memory, @@ -554,17 +683,7 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v, defOp = v.getDefiningOp(); } }) - .Case<fir::AllocaOp, fir::AllocMemOp>([&](auto op) { - // Unique memory allocation. - type = SourceKind::Allocate; - breakFromLoop = true; - }) - .Case<fir::ConvertOp>([&](auto op) { - // Skip ConvertOp's and track further through the operand. - v = op->getOperand(0); - defOp = v.getDefiningOp(); - }) - .Case<fir::PackArrayOp>([&](auto op) { + .Case([&](fir::PackArrayOp op) { // The packed array is not distinguishable from the original // array, so skip PackArrayOp and track further through // the array operand. 
@@ -572,29 +691,7 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v, defOp = v.getDefiningOp(); approximateSource = true; }) - .Case<fir::BoxAddrOp>([&](auto op) { - v = op->getOperand(0); - defOp = v.getDefiningOp(); - if (mlir::isa<fir::BaseBoxType>(v.getType())) - followBoxData = true; - }) - .Case<fir::ArrayCoorOp, fir::CoordinateOp>([&](auto op) { - if (isPointerReference(ty)) - attributes.set(Attribute::Pointer); - v = op->getOperand(0); - defOp = v.getDefiningOp(); - if (mlir::isa<fir::BaseBoxType>(v.getType())) - followBoxData = true; - approximateSource = true; - }) - .Case<fir::EmboxOp, fir::ReboxOp>([&](auto op) { - if (followBoxData) { - v = op->getOperand(0); - defOp = v.getDefiningOp(); - } else - breakFromLoop = true; - }) - .Case<fir::LoadOp>([&](auto op) { + .Case([&](fir::LoadOp op) { // If load is inside target and it points to mapped item, // continue tracking. Operation *loadMemrefOp = op.getMemref().getDefiningOp(); @@ -623,21 +720,35 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v, isCapturedInInternalProcedure |= boxSrc.isCapturedInInternalProcedure; + if (getLastInstantiationPoint) { + if (!instantiationPoint) + instantiationPoint = boxSrc.origin.instantiationPoint; + } else { + instantiationPoint = boxSrc.origin.instantiationPoint; + } + global = llvm::dyn_cast<mlir::SymbolRefAttr>(boxSrc.origin.u); if (global) { type = SourceKind::Global; } else { auto def = llvm::cast<mlir::Value>(boxSrc.origin.u); - // TODO: Add support to fir.allocmem - if (auto allocOp = def.template getDefiningOp<fir::AllocaOp>()) { - v = def; - defOp = v.getDefiningOp(); - type = SourceKind::Allocate; - } else if (isDummyArgument(def)) { - defOp = nullptr; - v = def; - } else { - type = SourceKind::Indirect; + bool classified = false; + if (auto defDefOp = def.getDefiningOp()) { + if (classifyAllocateFromEffects(defDefOp, def) == + SourceKind::Allocate) { + v = def; + defOp = defDefOp; + type = SourceKind::Allocate; + classified = true; 
+ } + } + if (!classified) { + if (isDummyArgument(def)) { + defOp = nullptr; + v = def; + } else { + type = SourceKind::Indirect; + } } } breakFromLoop = true; @@ -647,28 +758,39 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v, type = SourceKind::Indirect; breakFromLoop = true; }) - .Case<fir::AddrOfOp>([&](auto op) { + .Case<fir::AddrOfOp, cuf::DeviceAddressOp>([&](auto op) { // Address of a global scope object. ty = v.getType(); type = SourceKind::Global; - - if (hasGlobalOpTargetAttr(v, op)) - attributes.set(Attribute::Target); - // TODO: Take followBoxData into account when setting the pointer // attribute if (isPointerReference(ty)) attributes.set(Attribute::Pointer); - global = llvm::cast<fir::AddrOfOp>(op).getSymbol(); + + if constexpr (std::is_same_v<std::decay_t<decltype(op)>, + fir::AddrOfOp>) + global = op.getSymbol(); + else if constexpr (std::is_same_v<std::decay_t<decltype(op)>, + cuf::DeviceAddressOp>) + global = op.getHostSymbol(); + else + llvm_unreachable("unexpected operation"); + + if (symbolMayHaveTargetAttr(global, op)) + attributes.set(Attribute::Target); + breakFromLoop = true; }) .Case<hlfir::DeclareOp, fir::DeclareOp>([&](auto op) { + // The declare operations support FortranObjectViewOpInterface, + // but their handling is more complex. Maybe we can find better + // abstractions to handle them in a general fashion. bool isPrivateItem = false; if (omp::BlockArgOpenMPOpInterface argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(op->getParentOp())) { Value ompValArg; llvm::TypeSwitch<Operation *>(op->getParentOp()) - .template Case<omp::TargetOp>([&](auto targetOp) { + .Case([&](omp::TargetOp targetOp) { // If declare operation is inside omp target region, // continue alias analysis outside the target region for (auto [opArg, blockArg] : llvm::zip_equal( @@ -713,7 +835,7 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v, // currently provide any useful information. 
The host associated // access will end up dereferencing the host association tuple, // so we may as well stop right now. - v = defOp->getResult(0); + v = opResult; // TODO: if the host associated variable is a dummy argument // of the host, I think, we can treat it as SourceKind::Argument // for the purpose of alias analysis inside the internal procedure. @@ -748,21 +870,45 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v, v = op.getMemref(); defOp = v.getDefiningOp(); }) - .Case<hlfir::DesignateOp>([&](auto op) { - auto varIf = llvm::cast<fir::FortranVariableOpInterface>(defOp); - attributes |= getAttrsFromVariable(varIf); - // Track further through the memory indexed into - // => if the source arrays/structures don't alias then nor do the - // results of hlfir.designate - v = op.getMemref(); + .Case([&](fir::FortranObjectViewOpInterface op) { + // This case must be located after the cases for concrete + // operations that support FortraObjectViewOpInterface, + // so that their special handling kicks in. + + // fir.embox/rebox case: this is the only case where we check + // for followBoxData. + // TODO: it looks like we do not have LIT tests that fail + // upon removal of the followBoxData code. We should come up + // with a test or remove this code. + if (!followBoxData && + (mlir::isa<fir::EmboxOp>(op) || mlir::isa<fir::ReboxOp>(op))) { + breakFromLoop = true; + return; + } + + // Collect attributes from FortranVariableOpInterface operations. + if (auto varIf = + mlir::dyn_cast<fir::FortranVariableOpInterface>(defOp)) + attributes |= getAttrsFromVariable(varIf); + // Set Pointer attribute based on the reference type. + if (isPointerReference(ty)) + attributes.set(Attribute::Pointer); + + // Update v to point to the operand that represents the object + // referenced by the operation's result. 
+ v = op.getViewSource(opResult); defOp = v.getDefiningOp(); - // TODO: there will be some cases which provably don't alias if one - // takes into account the component or indices, which are currently - // ignored here - leading to false positives - // because of this limitation, we need to make sure we never return - // MustAlias after going through a designate operation - approximateSource = true; - if (mlir::isa<fir::BaseBoxType>(v.getType())) + // If the input the resulting object references are offsetted, + // then set approximateSource. + auto offset = op.getViewOffset(opResult); + if (!offset || *offset != 0) + approximateSource = true; + + // If the source is a box, and the result is not a box, + // then this is one of the box "unpacking" operations, + // so we should set followBoxData. + if (mlir::isa<fir::BaseBoxType>(v.getType()) && + !mlir::isa<fir::BaseBoxType>(ty)) followBoxData = true; }) .Default([&](auto op) { @@ -803,4 +949,16 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v, isCapturedInInternalProcedure}; } +const mlir::SymbolTable * +fir::AliasAnalysis::getNearestSymbolTable(mlir::Operation *from) { + assert(from); + Operation *symTabOp = mlir::SymbolTable::getNearestSymbolTable(from); + if (!symTabOp) + return nullptr; + auto it = symTabMap.find(symTabOp); + if (it != symTabMap.end()) + return &it->second; + return &symTabMap.try_emplace(symTabOp, symTabOp).first->second; +} + } // namespace fir diff --git a/flang/lib/Optimizer/Analysis/ArraySectionAnalyzer.cpp b/flang/lib/Optimizer/Analysis/ArraySectionAnalyzer.cpp new file mode 100644 index 0000000..f5ee298 --- /dev/null +++ b/flang/lib/Optimizer/Analysis/ArraySectionAnalyzer.cpp @@ -0,0 +1,300 @@ +//===- ArraySectionAnalyzer.cpp - Analyze array sections ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Analysis/ArraySectionAnalyzer.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "array-section-analyzer" + +using namespace fir; + +ArraySectionAnalyzer::SectionDesc::SectionDesc(mlir::Value lb, mlir::Value ub, + mlir::Value stride) + : lb(lb), ub(ub), stride(stride) { + assert(lb && "lower bound or index must be specified"); + normalize(); +} + +void ArraySectionAnalyzer::SectionDesc::normalize() { + if (!ub) + ub = lb; + if (lb == ub) + stride = nullptr; + if (stride) + if (auto val = fir::getIntIfConstant(stride)) + if (*val == 1) + stride = nullptr; +} + +bool ArraySectionAnalyzer::SectionDesc::operator==( + const SectionDesc &other) const { + return lb == other.lb && ub == other.ub && stride == other.stride; +} + +ArraySectionAnalyzer::SectionDesc +ArraySectionAnalyzer::readSectionDesc(mlir::Operation::operand_iterator &it, + bool isTriplet) { + if (isTriplet) + return {*it++, *it++, *it++}; + return {*it++, nullptr, nullptr}; +} + +std::pair<mlir::Value, mlir::Value> +ArraySectionAnalyzer::getOrderedBounds(const SectionDesc &desc) { + mlir::Value stride = desc.stride; + // Null stride means stride=1. + if (!stride) + return {desc.lb, desc.ub}; + // Reverse the bounds, if stride is negative. 
+ if (auto val = fir::getIntIfConstant(stride)) { + if (*val >= 0) + return {desc.lb, desc.ub}; + else + return {desc.ub, desc.lb}; + } + + return {nullptr, nullptr}; +} + +bool ArraySectionAnalyzer::areDisjointSections(const SectionDesc &desc1, + const SectionDesc &desc2) { + auto [lb1, ub1] = getOrderedBounds(desc1); + auto [lb2, ub2] = getOrderedBounds(desc2); + if (!lb1 || !lb2) + return false; + // Note that this comparison must be made on the ordered bounds, + // otherwise 'a(x:y:1) = a(z:x-1:-1) + 1' may be incorrectly treated + // as not overlapping (x=2, y=10, z=9). + if (isLess(ub1, lb2) || isLess(ub2, lb1)) + return true; + return false; +} + +bool ArraySectionAnalyzer::areIdenticalSections(const SectionDesc &desc1, + const SectionDesc &desc2) { + if (desc1 == desc2) + return true; + return false; +} + +ArraySectionAnalyzer::SlicesOverlapKind +ArraySectionAnalyzer::analyze(mlir::Value ref1, mlir::Value ref2) { + if (ref1 == ref2) + return SlicesOverlapKind::DefinitelyIdentical; + + auto des1 = ref1.getDefiningOp<hlfir::DesignateOp>(); + auto des2 = ref2.getDefiningOp<hlfir::DesignateOp>(); + // We only support a pair of designators right now. + if (!des1 || !des2) + return SlicesOverlapKind::Unknown; + + if (des1.getMemref() != des2.getMemref()) { + // If the bases are different, then there is unknown overlap. + LLVM_DEBUG(llvm::dbgs() << "No identical base for:\n" + << des1 << "and:\n" + << des2 << "\n"); + return SlicesOverlapKind::Unknown; + } + + // Require all components of the designators to be the same. + // It might be too strict, e.g. we may probably allow for + // different type parameters. 
+ if (des1.getComponent() != des2.getComponent() || + des1.getComponentShape() != des2.getComponentShape() || + des1.getSubstring() != des2.getSubstring() || + des1.getComplexPart() != des2.getComplexPart() || + des1.getTypeparams() != des2.getTypeparams()) { + LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n" + << des1 << "and:\n" + << des2 << "\n"); + return SlicesOverlapKind::Unknown; + } + + // Analyze the subscripts. + auto des1It = des1.getIndices().begin(); + auto des2It = des2.getIndices().begin(); + bool identicalTriplets = true; + bool identicalIndices = true; + for (auto [isTriplet1, isTriplet2] : + llvm::zip(des1.getIsTriplet(), des2.getIsTriplet())) { + SectionDesc desc1 = readSectionDesc(des1It, isTriplet1); + SectionDesc desc2 = readSectionDesc(des2It, isTriplet2); + + // See if we can prove that any of the sections do not overlap. + // This is mostly a Polyhedron/nf performance hack that looks for + // particular relations between the lower and upper bounds + // of the array sections, e.g. for any positive constant C: + // X:Y does not overlap with (Y+C):Z + // X:Y does not overlap with Z:(X-C) + if (areDisjointSections(desc1, desc2)) + return SlicesOverlapKind::DefinitelyDisjoint; + + if (!areIdenticalSections(desc1, desc2)) { + if (isTriplet1 || isTriplet2) { + // For example: + // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0) + // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1) + // + // If all the triplets (section speficiers) are the same, then + // we do not care if %0 is equal to %1 - the slices are either + // identical or completely disjoint. 
+ // + // Also, treat these as identical sections: + // hlfir.designate %6#0 (%c2:%c2:%c1) + // hlfir.designate %6#0 (%c2) + identicalTriplets = false; + LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n" + << des1 << "and:\n" + << des2 << "\n"); + } else { + identicalIndices = false; + LLVM_DEBUG(llvm::dbgs() << "Indices mismatch for:\n" + << des1 << "and:\n" + << des2 << "\n"); + } + } + } + + if (identicalTriplets) { + if (identicalIndices) + return SlicesOverlapKind::DefinitelyIdentical; + else + return SlicesOverlapKind::EitherIdenticalOrDisjoint; + } + + LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n" + << des1 << "and:\n" + << des2 << "\n"); + return SlicesOverlapKind::Unknown; +} + +bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) { + auto removeConvert = [](mlir::Value v) -> mlir::Operation * { + auto *op = v.getDefiningOp(); + while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op)) + op = conv.getValue().getDefiningOp(); + return op; + }; + + auto isPositiveConstant = [](mlir::Value v) -> bool { + if (auto val = fir::getIntIfConstant(v)) + return *val > 0; + return false; + }; + + auto *op1 = removeConvert(v1); + auto *op2 = removeConvert(v2); + if (!op1 || !op2) + return false; + + // Check if they are both constants. 
+ if (auto val1 = fir::getIntIfConstant(op1->getResult(0))) + if (auto val2 = fir::getIntIfConstant(op2->getResult(0))) + return *val1 < *val2; + + // Handle some variable cases (C > 0): + // v2 = v1 + C + // v2 = C + v1 + // v1 = v2 - C + if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2)) + if ((addi.getLhs().getDefiningOp() == op1 && + isPositiveConstant(addi.getRhs())) || + (addi.getRhs().getDefiningOp() == op1 && + isPositiveConstant(addi.getLhs()))) + return true; + if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1)) + if (subi.getLhs().getDefiningOp() == op2 && + isPositiveConstant(subi.getRhs())) + return true; + return false; +} + +/// Returns the array indices for the given hlfir.designate. +/// It recognizes the computations used to transform the one-based indices +/// into the array's lb-based indices, and returns the one-based indices +/// in these cases. +static llvm::SmallVector<mlir::Value> +getDesignatorIndices(hlfir::DesignateOp designate) { + mlir::Value memref = designate.getMemref(); + + // If the object is a box, then the indices may be adjusted + // according to the box's lower bound(s). Scan through + // the computations to try to find the one-based indices. + if (mlir::isa<fir::BaseBoxType>(memref.getType())) { + // Look for the following pattern: + // %13 = fir.load %12 : !fir.ref<!fir.box<...> + // %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ... + // %17 = arith.subi %14#0, %c1 : index + // %18 = arith.addi %arg2, %17 : index + // %19 = hlfir.designate %13 (%18) : (!fir.box<...>, index) -> ... + // + // %arg2 is a one-based index. + + auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) { + // Return true, if v and dim are such that: + // %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ... + // %17 = arith.subi %14#0, %c1 : index + // %19 = hlfir.designate %13 (...) : (!fir.box<...>, index) -> ... 
+ if (auto subOp = + mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) { + auto cst = fir::getIntIfConstant(subOp.getRhs()); + if (!cst || *cst != 1) + return false; + if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>( + subOp.getLhs().getDefiningOp())) { + if (memref != dimsOp.getVal() || + dimsOp.getResult(0) != subOp.getLhs()) + return false; + auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim()); + return dimsOpDim && dimsOpDim == dim; + } + } + return false; + }; + + llvm::SmallVector<mlir::Value> newIndices; + for (auto index : llvm::enumerate(designate.getIndices())) { + if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>( + index.value().getDefiningOp())) { + for (unsigned opNum = 0; opNum < 2; ++opNum) + if (isNormalizedLb(addOp->getOperand(opNum), index.index())) { + newIndices.push_back(addOp->getOperand((opNum + 1) % 2)); + break; + } + + // If new one-based index was not added, exit early. + if (newIndices.size() <= index.index()) + break; + } + } + + // If any of the indices is not adjusted to the array's lb, + // then return the original designator indices. 
+ if (newIndices.size() != designate.getIndices().size()) + return designate.getIndices(); + + return newIndices; + } + + return designate.getIndices(); +} + +bool fir::ArraySectionAnalyzer::isDesignatingArrayInOrder( + hlfir::DesignateOp designate, hlfir::ElementalOpInterface elemental) { + + auto indices = getDesignatorIndices(designate); + auto elementalIndices = elemental.getIndices(); + if (indices.size() == elementalIndices.size()) + return std::equal(indices.begin(), indices.end(), elementalIndices.begin(), + elementalIndices.end()); + return false; +} diff --git a/flang/lib/Optimizer/Analysis/CMakeLists.txt b/flang/lib/Optimizer/Analysis/CMakeLists.txt index 4d4ad88..398a6d3 100644 --- a/flang/lib/Optimizer/Analysis/CMakeLists.txt +++ b/flang/lib/Optimizer/Analysis/CMakeLists.txt @@ -1,14 +1,16 @@ add_flang_library(FIRAnalysis AliasAnalysis.cpp + ArraySectionAnalyzer.cpp TBAAForest.cpp DEPENDS + CUFDialect FIRDialect FIRSupport HLFIRDialect LINK_LIBS - FIRBuilder + CUFDialect FIRDialect FIRSupport HLFIRDialect diff --git a/flang/lib/Optimizer/Analysis/TBAAForest.cpp b/flang/lib/Optimizer/Analysis/TBAAForest.cpp index 44a0348..7154785 100644 --- a/flang/lib/Optimizer/Analysis/TBAAForest.cpp +++ b/flang/lib/Optimizer/Analysis/TBAAForest.cpp @@ -66,12 +66,9 @@ fir::TBAATree::TBAATree(mlir::LLVM::TBAATypeDescriptorAttr anyAccess, mlir::LLVM::TBAATypeDescriptorAttr dataRoot, mlir::LLVM::TBAATypeDescriptorAttr boxMemberTypeDesc) : targetDataTree(dataRoot.getContext(), "target data", dataRoot), - globalDataTree(dataRoot.getContext(), "global data", - targetDataTree.getRoot()), - allocatedDataTree(dataRoot.getContext(), "allocated data", - targetDataTree.getRoot()), + globalDataTree(dataRoot.getContext(), "global data", dataRoot), + allocatedDataTree(dataRoot.getContext(), "allocated data", dataRoot), dummyArgDataTree(dataRoot.getContext(), "dummy arg data", dataRoot), - directDataTree(dataRoot.getContext(), "direct data", - targetDataTree.getRoot()), + 
directDataTree(dataRoot.getContext(), "direct data", dataRoot),
       anyAccessDesc(anyAccess), boxMemberTypeDesc(boxMemberTypeDesc),
       anyDataTypeDesc(dataRoot) {}
diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt
index 1f95259..d966c52 100644
--- a/flang/lib/Optimizer/Builder/CMakeLists.txt
+++ b/flang/lib/Optimizer/Builder/CMakeLists.txt
@@ -5,6 +5,7 @@ add_flang_library(FIRBuilder
   BoxValue.cpp
   Character.cpp
   Complex.cpp
+  CUDAIntrinsicCall.cpp
   CUFCommon.cpp
   DoLoopHelper.cpp
   FIRBuilder.cpp
@@ -46,6 +47,7 @@ add_flang_library(FIRBuilder
   LINK_LIBS
   CUFAttrs
   CUFDialect
+  FIRAnalysis
   FIRDialect
   FIRDialectSupport
   FIRSupport
diff --git a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
new file mode 100644
index 0000000..fe2db46
--- /dev/null
+++ b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
@@ -0,0 +1,1722 @@
+//===-- CUDAIntrinsicCall.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Helper routines for constructing the FIR dialect of MLIR for CUDA
+// intrinsics. Extensive use of MLIR interfaces and MLIR's coding style
+// (https://mlir.llvm.org/getting_started/DeveloperGuide/) is used in this
+// module.
+// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/CUDAIntrinsicCall.h" +#include "flang/Evaluate/common.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/MutableBox.h" +#include "flang/Optimizer/Dialect/CUF/CUFOps.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +namespace fir { + +using CI = CUDAIntrinsicLibrary; + +static const char __ldca_i4x4[] = "__ldca_i4x4_"; +static const char __ldca_i8x2[] = "__ldca_i8x2_"; +static const char __ldca_r2x2[] = "__ldca_r2x2_"; +static const char __ldca_r4x4[] = "__ldca_r4x4_"; +static const char __ldca_r8x2[] = "__ldca_r8x2_"; +static const char __ldcg_i4x4[] = "__ldcg_i4x4_"; +static const char __ldcg_i8x2[] = "__ldcg_i8x2_"; +static const char __ldcg_r2x2[] = "__ldcg_r2x2_"; +static const char __ldcg_r4x4[] = "__ldcg_r4x4_"; +static const char __ldcg_r8x2[] = "__ldcg_r8x2_"; +static const char __ldcs_i4x4[] = "__ldcs_i4x4_"; +static const char __ldcs_i8x2[] = "__ldcs_i8x2_"; +static const char __ldcs_r2x2[] = "__ldcs_r2x2_"; +static const char __ldcs_r4x4[] = "__ldcs_r4x4_"; +static const char __ldcs_r8x2[] = "__ldcs_r8x2_"; +static const char __ldcv_i4x4[] = "__ldcv_i4x4_"; +static const char __ldcv_i8x2[] = "__ldcv_i8x2_"; +static const char __ldcv_r2x2[] = "__ldcv_r2x2_"; +static const char __ldcv_r4x4[] = "__ldcv_r4x4_"; +static const char __ldcv_r8x2[] = "__ldcv_r8x2_"; +static const char __ldlu_i4x4[] = "__ldlu_i4x4_"; +static const char __ldlu_i8x2[] = "__ldlu_i8x2_"; +static const char __ldlu_r2x2[] = "__ldlu_r2x2_"; +static const char __ldlu_r4x4[] = "__ldlu_r4x4_"; +static const char __ldlu_r8x2[] = "__ldlu_r8x2_"; + +static constexpr unsigned kTMAAlignment = 16; + +// CUDA specific intrinsic handlers. 
+static constexpr IntrinsicHandler cudaHandlers[]{ + {"__ldca_i4x4", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldca_i4x4, 4>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldca_i8x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldca_i8x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldca_r2x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldca_r2x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldca_r4x4", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldca_r4x4, 4>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldca_r8x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldca_r8x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcg_i4x4", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcg_i4x4, 4>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcg_i8x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcg_i8x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcg_r2x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcg_r2x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcg_r4x4", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcg_r4x4, 4>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcg_r8x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcg_r8x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcs_i4x4", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcs_i4x4, 4>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcs_i8x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcs_i8x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcs_r2x2", + 
static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcs_r2x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcs_r4x4", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcs_r4x4, 4>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcs_r8x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcs_r8x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcv_i4x4", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcv_i4x4, 4>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcv_i8x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcv_i8x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcv_r2x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcv_r2x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcv_r4x4", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcv_r4x4, 4>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldcv_r8x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldcv_r8x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldlu_i4x4", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldlu_i4x4, 4>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldlu_i8x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldlu_i8x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldlu_r2x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldlu_r2x2, 2>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldlu_r4x4", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldlu_r4x4, 4>), + {{{"a", asAddr}}}, + /*isElemental=*/false}, + {"__ldlu_r8x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genLDXXFunc<__ldlu_r8x2, 2>), + {{{"a", 
asAddr}}}, + /*isElemental=*/false}, + {"all_sync", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genVoteSync<mlir::NVVM::VoteSyncKind::all>), + {{{"mask", asValue}, {"pred", asValue}}}, + /*isElemental=*/false}, + {"any_sync", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genVoteSync<mlir::NVVM::VoteSyncKind::any>), + {{{"mask", asValue}, {"pred", asValue}}}, + /*isElemental=*/false}, + {"atomicadd_r4x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genAtomicAddVector<2>), + {{{"a", asAddr}, {"v", asAddr}}}, + false}, + {"atomicadd_r4x4", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genAtomicAddVector4x4), + {{{"a", asAddr}, {"v", asAddr}}}, + false}, + {"atomicaddd", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicAdd), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicaddf", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicAdd), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicaddi", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicAdd), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicaddl", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicAdd), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicaddr2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicAddR2), + {{{"a", asAddr}, {"v", asAddr}}}, + false}, + {"atomicaddvector_r2x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genAtomicAddVector<2>), + {{{"a", asAddr}, {"v", asAddr}}}, + false}, + {"atomicaddvector_r4x2", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>( + &CI::genAtomicAddVector<2>), + {{{"a", asAddr}, {"v", asAddr}}}, + false}, + {"atomicandi", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicAnd), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomiccasd", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicCas), + {{{"a", 
asAddr}, {"v1", asValue}, {"v2", asValue}}}, + false}, + {"atomiccasf", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicCas), + {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}}, + false}, + {"atomiccasi", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicCas), + {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}}, + false}, + {"atomiccasul", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicCas), + {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}}, + false}, + {"atomicdeci", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicDec), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicexchd", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicExch), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicexchf", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicExch), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicexchi", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicExch), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicexchul", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicExch), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicinci", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicInc), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicmaxd", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMax), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicmaxf", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMax), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicmaxi", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMax), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicmaxl", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMax), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicmind", + 
static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMin), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicminf", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMin), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicmini", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMin), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicminl", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMin), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicori", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicOr), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicsubd", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicSub), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicsubf", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicSub), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicsubi", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicSub), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicsubl", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicSub), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicxori", + static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicXor), + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"ballot_sync", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genVoteSync<mlir::NVVM::VoteSyncKind::ballot>), + {{{"mask", asValue}, {"pred", asValue}}}, + /*isElemental=*/false}, + {"barrier_arrive", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genBarrierArrive), + {{{"barrier", asAddr}}}, + /*isElemental=*/false}, + {"barrier_arrive_cnt", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genBarrierArriveCnt), + {{{"barrier", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"barrier_init", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + 
&CI::genBarrierInit), + {{{"barrier", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"barrier_try_wait", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genBarrierTryWait), + {{{"barrier", asAddr}, {"token", asValue}}}, + /*isElemental=*/false}, + {"barrier_try_wait_sleep", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genBarrierTryWaitSleep), + {{{"barrier", asAddr}, {"token", asValue}, {"ns", asValue}}}, + /*isElemental=*/false}, + {"clock", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genNVVMTime<mlir::NVVM::ClockOp>), + {}, + /*isElemental=*/false}, + {"clock64", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genNVVMTime<mlir::NVVM::Clock64Op>), + {}, + /*isElemental=*/false}, + {"cluster_block_index", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genClusterBlockIndex), + {}, + /*isElemental=*/false}, + {"cluster_dim_blocks", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genClusterDimBlocks), + {}, + /*isElemental=*/false}, + {"fence_proxy_async", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genFenceProxyAsync), + {}, + /*isElemental=*/false}, + {"globaltimer", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genNVVMTime<mlir::NVVM::GlobalTimerOp>), + {}, + /*isElemental=*/false}, + {"match_all_syncjd", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genMatchAllSync), + {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, + /*isElemental=*/false}, + {"match_all_syncjf", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genMatchAllSync), + {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, + /*isElemental=*/false}, + {"match_all_syncjj", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genMatchAllSync), + {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, + /*isElemental=*/false}, + {"match_all_syncjx", + 
static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genMatchAllSync), + {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, + /*isElemental=*/false}, + {"match_any_syncjd", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genMatchAnySync), + {{{"mask", asValue}, {"value", asValue}}}, + /*isElemental=*/false}, + {"match_any_syncjf", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genMatchAnySync), + {{{"mask", asValue}, {"value", asValue}}}, + /*isElemental=*/false}, + {"match_any_syncjj", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genMatchAnySync), + {{{"mask", asValue}, {"value", asValue}}}, + /*isElemental=*/false}, + {"match_any_syncjx", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genMatchAnySync), + {{{"mask", asValue}, {"value", asValue}}}, + /*isElemental=*/false}, + {"syncthreads", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genSyncThreads), + {}, + /*isElemental=*/false}, + {"syncthreads_and_i4", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genSyncThreadsAnd), + {}, + /*isElemental=*/false}, + {"syncthreads_and_l4", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genSyncThreadsAnd), + {}, + /*isElemental=*/false}, + {"syncthreads_count_i4", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genSyncThreadsCount), + {}, + /*isElemental=*/false}, + {"syncthreads_count_l4", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genSyncThreadsCount), + {}, + /*isElemental=*/false}, + {"syncthreads_or_i4", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genSyncThreadsOr), + {}, + /*isElemental=*/false}, + {"syncthreads_or_l4", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genSyncThreadsOr), + {}, + /*isElemental=*/false}, + {"syncwarp", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(&CI::genSyncWarp), + {}, + /*isElemental=*/false}, + {"this_cluster", + 
static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genThisCluster), + {}, + /*isElemental=*/false}, + {"this_grid", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genThisGrid), + {}, + /*isElemental=*/false}, + {"this_thread_block", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>( + &CI::genThisThreadBlock), + {}, + /*isElemental=*/false}, + {"this_warp", + static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genThisWarp), + {}, + /*isElemental=*/false}, + {"threadfence", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genThreadFence<mlir::NVVM::MemScopeKind::GPU>), + {}, + /*isElemental=*/false}, + {"threadfence_block", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genThreadFence<mlir::NVVM::MemScopeKind::CTA>), + {}, + /*isElemental=*/false}, + {"threadfence_system", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genThreadFence<mlir::NVVM::MemScopeKind::SYS>), + {}, + /*isElemental=*/false}, + {"tma_bulk_commit_group", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkCommitGroup), + {{}}, + /*isElemental=*/false}, + {"tma_bulk_g2s", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(&CI::genTMABulkG2S), + {{{"barrier", asAddr}, + {"src", asAddr}, + {"dst", asAddr}, + {"nbytes", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_ldc4", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkLoadC4), + {{{"barrier", asAddr}, + {"src", asAddr}, + {"dst", asAddr}, + {"nelems", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_ldc8", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkLoadC8), + {{{"barrier", asAddr}, + {"src", asAddr}, + {"dst", asAddr}, + {"nelems", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_ldi4", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkLoadI4), + {{{"barrier", asAddr}, + {"src", asAddr}, + {"dst", asAddr}, + {"nelems", asValue}}}, + 
/*isElemental=*/false}, + {"tma_bulk_ldi8", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkLoadI8), + {{{"barrier", asAddr}, + {"src", asAddr}, + {"dst", asAddr}, + {"nelems", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_ldr2", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkLoadR2), + {{{"barrier", asAddr}, + {"src", asAddr}, + {"dst", asAddr}, + {"nelems", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_ldr4", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkLoadR4), + {{{"barrier", asAddr}, + {"src", asAddr}, + {"dst", asAddr}, + {"nelems", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_ldr8", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkLoadR8), + {{{"barrier", asAddr}, + {"src", asAddr}, + {"dst", asAddr}, + {"nelems", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_s2g", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(&CI::genTMABulkS2G), + {{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_store_c4", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkStoreC4), + {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_store_c8", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkStoreC8), + {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_store_i4", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkStoreI4), + {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_store_i8", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkStoreI8), + {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_store_r2", + static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>( + &CI::genTMABulkStoreR2), + {{{"src", asAddr}, {"dst", 
asAddr}, {"count", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_store_r4",
+     static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+         &CI::genTMABulkStoreR4),
+     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_store_r8",
+     static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+         &CI::genTMABulkStoreR8),
+     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_wait_group",
+     static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+         &CI::genTMABulkWaitGroup),
+     {{}},
+     /*isElemental=*/false},
+};
+
+template <std::size_t N>
+static constexpr bool isSorted(const IntrinsicHandler (&array)[N]) {
+  // Replace by std::is_sorted when C++20 is default (will be constexpr).
+  const IntrinsicHandler *lastSeen{nullptr};
+  bool isSorted{true};
+  for (const auto &x : array) {
+    if (lastSeen)
+      isSorted &= std::string_view{lastSeen->name} < std::string_view{x.name};
+    lastSeen = &x;
+  }
+  return isSorted;
+}
+static_assert(isSorted(cudaHandlers) && "map must be sorted");
+
+const IntrinsicHandler *findCUDAIntrinsicHandler(llvm::StringRef name) {
+  auto compare = [](const IntrinsicHandler &cudaHandler, llvm::StringRef name) {
+    return name.compare(cudaHandler.name) > 0;
+  };
+  auto result = llvm::lower_bound(cudaHandlers, name, compare);
+  return result != std::end(cudaHandlers) && result->name == name ? 
result + : nullptr; +} + +static mlir::Value convertPtrToNVVMSpace(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Value barrier, + mlir::NVVM::NVVMMemorySpace space) { + mlir::Value llvmPtr = fir::ConvertOp::create( + builder, loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()), + barrier); + mlir::Value addrCast = mlir::LLVM::AddrSpaceCastOp::create( + builder, loc, + mlir::LLVM::LLVMPointerType::get(builder.getContext(), + static_cast<unsigned>(space)), + llvmPtr); + return addrCast; +} + +static mlir::Value genAtomBinOp(fir::FirOpBuilder &builder, mlir::Location &loc, + mlir::LLVM::AtomicBinOp binOp, mlir::Value arg0, + mlir::Value arg1) { + auto llvmPointerType = mlir::LLVM::LLVMPointerType::get(builder.getContext()); + arg0 = builder.createConvert(loc, llvmPointerType, arg0); + return mlir::LLVM::AtomicRMWOp::create(builder, loc, binOp, arg0, arg1, + mlir::LLVM::AtomicOrdering::seq_cst); +} + +// ATOMICADD +mlir::Value +CUDAIntrinsicLibrary::genAtomicAdd(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + mlir::LLVM::AtomicBinOp binOp = + mlir::isa<mlir::IntegerType>(args[1].getType()) + ? 
mlir::LLVM::AtomicBinOp::add + : mlir::LLVM::AtomicBinOp::fadd; + return genAtomBinOp(builder, loc, binOp, args[0], args[1]); +} + +fir::ExtendedValue +CUDAIntrinsicLibrary::genAtomicAddR2(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 2); + + mlir::Value a = fir::getBase(args[0]); + + if (mlir::isa<fir::BaseBoxType>(a.getType())) { + a = fir::BoxAddrOp::create(builder, loc, a); + } + + auto loc = builder.getUnknownLoc(); + auto f16Ty = builder.getF16Type(); + auto i32Ty = builder.getI32Type(); + auto vecF16Ty = mlir::VectorType::get({2}, f16Ty); + mlir::Type idxTy = builder.getIndexType(); + auto f16RefTy = fir::ReferenceType::get(f16Ty); + auto zero = builder.createIntegerConstant(loc, idxTy, 0); + auto one = builder.createIntegerConstant(loc, idxTy, 1); + auto v1Coord = fir::CoordinateOp::create(builder, loc, f16RefTy, + fir::getBase(args[1]), zero); + auto v2Coord = fir::CoordinateOp::create(builder, loc, f16RefTy, + fir::getBase(args[1]), one); + auto v1 = fir::LoadOp::create(builder, loc, v1Coord); + auto v2 = fir::LoadOp::create(builder, loc, v2Coord); + mlir::Value undef = mlir::LLVM::UndefOp::create(builder, loc, vecF16Ty); + mlir::Value vec1 = mlir::LLVM::InsertElementOp::create( + builder, loc, undef, v1, builder.createIntegerConstant(loc, i32Ty, 0)); + mlir::Value vec2 = mlir::LLVM::InsertElementOp::create( + builder, loc, vec1, v2, builder.createIntegerConstant(loc, i32Ty, 1)); + auto res = genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fadd, a, vec2); + auto i32VecTy = mlir::VectorType::get({1}, i32Ty); + mlir::Value vecI32 = + mlir::vector::BitCastOp::create(builder, loc, i32VecTy, res); + return mlir::vector::ExtractOp::create(builder, loc, vecI32, + mlir::ArrayRef<int64_t>{0}); +} + +// ATOMICADDVECTOR +template <int extent> +fir::ExtendedValue CUDAIntrinsicLibrary::genAtomicAddVector( + mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 2); + mlir::Value res = 
fir::AllocaOp::create( + builder, loc, fir::SequenceType::get({extent}, resultType)); + mlir::Value a = fir::getBase(args[0]); + if (mlir::isa<fir::BaseBoxType>(a.getType())) { + a = fir::BoxAddrOp::create(builder, loc, a); + } + auto vecTy = mlir::VectorType::get({extent}, resultType); + auto refTy = fir::ReferenceType::get(resultType); + mlir::Type i32Ty = builder.getI32Type(); + mlir::Type idxTy = builder.getIndexType(); + + // Extract the values from the array. + llvm::SmallVector<mlir::Value> values; + for (unsigned i = 0; i < extent; ++i) { + mlir::Value pos = builder.createIntegerConstant(loc, idxTy, i); + mlir::Value coord = fir::CoordinateOp::create(builder, loc, refTy, + fir::getBase(args[1]), pos); + mlir::Value value = fir::LoadOp::create(builder, loc, coord); + values.push_back(value); + } + // Pack extracted values into a vector to call the atomic add. + mlir::Value undef = mlir::LLVM::UndefOp::create(builder, loc, vecTy); + for (unsigned i = 0; i < extent; ++i) { + mlir::Value insert = mlir::LLVM::InsertElementOp::create( + builder, loc, undef, values[i], + builder.createIntegerConstant(loc, i32Ty, i)); + undef = insert; + } + // Atomic operation with a vector of values. + mlir::Value add = + genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fadd, a, undef); + // Store results in the result array. 
+ for (unsigned i = 0; i < extent; ++i) { + mlir::Value r = mlir::LLVM::ExtractElementOp::create( + builder, loc, add, builder.createIntegerConstant(loc, i32Ty, i)); + mlir::Value c = fir::CoordinateOp::create( + builder, loc, refTy, res, builder.createIntegerConstant(loc, idxTy, i)); + fir::StoreOp::create(builder, loc, r, c); + } + mlir::Value ext = builder.createIntegerConstant(loc, idxTy, extent); + return fir::ArrayBoxValue(res, {ext}); +} + +// ATOMICADDVECTOR4x4 +fir::ExtendedValue CUDAIntrinsicLibrary::genAtomicAddVector4x4( + mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 2); + mlir::Value a = fir::getBase(args[0]); + if (mlir::isa<fir::BaseBoxType>(a.getType())) + a = fir::BoxAddrOp::create(builder, loc, a); + + const unsigned extent = 4; + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext()); + mlir::Value ptr = builder.createConvert(loc, llvmPtrTy, a); + mlir::Type f32Ty = builder.getF32Type(); + mlir::Type idxTy = builder.getIndexType(); + mlir::Type refTy = fir::ReferenceType::get(f32Ty); + llvm::SmallVector<mlir::Value> values; + for (unsigned i = 0; i < extent; ++i) { + mlir::Value pos = builder.createIntegerConstant(loc, idxTy, i); + mlir::Value coord = fir::CoordinateOp::create(builder, loc, refTy, + fir::getBase(args[1]), pos); + mlir::Value value = fir::LoadOp::create(builder, loc, coord); + values.push_back(value); + } + + auto inlinePtx = mlir::NVVM::InlinePtxOp::create( + builder, loc, {f32Ty, f32Ty, f32Ty, f32Ty}, + {ptr, values[0], values[1], values[2], values[3]}, {}, + "atom.add.v4.f32 {%0, %1, %2, %3}, [%4], {%5, %6, %7, %8};", {}); + + llvm::SmallVector<mlir::Value> results; + results.push_back(inlinePtx.getResult(0)); + results.push_back(inlinePtx.getResult(1)); + results.push_back(inlinePtx.getResult(2)); + results.push_back(inlinePtx.getResult(3)); + + mlir::Type vecF32Ty = mlir::VectorType::get({extent}, f32Ty); + mlir::Value undef = mlir::LLVM::UndefOp::create(builder, 
loc, vecF32Ty); + mlir::Type i32Ty = builder.getI32Type(); + for (unsigned i = 0; i < extent; ++i) + undef = mlir::LLVM::InsertElementOp::create( + builder, loc, undef, results[i], + builder.createIntegerConstant(loc, i32Ty, i)); + + auto i128Ty = builder.getIntegerType(128); + auto i128VecTy = mlir::VectorType::get({1}, i128Ty); + mlir::Value vec128 = + mlir::vector::BitCastOp::create(builder, loc, i128VecTy, undef); + return mlir::vector::ExtractOp::create(builder, loc, vec128, + mlir::ArrayRef<int64_t>{0}); +} + +mlir::Value +CUDAIntrinsicLibrary::genAtomicAnd(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + assert(mlir::isa<mlir::IntegerType>(args[1].getType())); + + mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::_and; + return genAtomBinOp(builder, loc, binOp, args[0], args[1]); +} + +mlir::Value +CUDAIntrinsicLibrary::genAtomicOr(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + assert(mlir::isa<mlir::IntegerType>(args[1].getType())); + + mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::_or; + return genAtomBinOp(builder, loc, binOp, args[0], args[1]); +} + +// ATOMICCAS +fir::ExtendedValue +CUDAIntrinsicLibrary::genAtomicCas(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 3); + auto successOrdering = mlir::LLVM::AtomicOrdering::acq_rel; + auto failureOrdering = mlir::LLVM::AtomicOrdering::monotonic; + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(resultType.getContext()); + + mlir::Value arg0 = fir::getBase(args[0]); + mlir::Value arg1 = fir::getBase(args[1]); + mlir::Value arg2 = fir::getBase(args[2]); + + auto bitCastFloat = [&](mlir::Value arg) -> mlir::Value { + if (mlir::isa<mlir::Float32Type>(arg.getType())) + return mlir::LLVM::BitcastOp::create(builder, loc, builder.getI32Type(), + arg); + if (mlir::isa<mlir::Float64Type>(arg.getType())) + return mlir::LLVM::BitcastOp::create(builder, loc, 
builder.getI64Type(), + arg); + return arg; + }; + + arg1 = bitCastFloat(arg1); + arg2 = bitCastFloat(arg2); + + if (arg1.getType() != arg2.getType()) { + // arg1 and arg2 need to have the same type in AtomicCmpXchgOp. + arg2 = builder.createConvert(loc, arg1.getType(), arg2); + } + + auto address = + mlir::UnrealizedConversionCastOp::create(builder, loc, llvmPtrTy, arg0) + .getResult(0); + auto cmpxchg = mlir::LLVM::AtomicCmpXchgOp::create( + builder, loc, address, arg1, arg2, successOrdering, failureOrdering); + mlir::Value boolResult = + mlir::LLVM::ExtractValueOp::create(builder, loc, cmpxchg, 1); + return builder.createConvert(loc, resultType, boolResult); +} + +mlir::Value +CUDAIntrinsicLibrary::genAtomicDec(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + assert(mlir::isa<mlir::IntegerType>(args[1].getType())); + + mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::udec_wrap; + return genAtomBinOp(builder, loc, binOp, args[0], args[1]); +} + +// ATOMICEXCH +fir::ExtendedValue +CUDAIntrinsicLibrary::genAtomicExch(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 2); + mlir::Value arg0 = fir::getBase(args[0]); + mlir::Value arg1 = fir::getBase(args[1]); + assert(arg1.getType().isIntOrFloat()); + + mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::xchg; + return genAtomBinOp(builder, loc, binOp, arg0, arg1); +} + +mlir::Value +CUDAIntrinsicLibrary::genAtomicInc(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + assert(mlir::isa<mlir::IntegerType>(args[1].getType())); + + mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::uinc_wrap; + return genAtomBinOp(builder, loc, binOp, args[0], args[1]); +} + +mlir::Value +CUDAIntrinsicLibrary::genAtomicMax(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + + mlir::LLVM::AtomicBinOp binOp = + mlir::isa<mlir::IntegerType>(args[1].getType()) + ? 
mlir::LLVM::AtomicBinOp::max + : mlir::LLVM::AtomicBinOp::fmax; + return genAtomBinOp(builder, loc, binOp, args[0], args[1]); +} + +mlir::Value +CUDAIntrinsicLibrary::genAtomicMin(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + + mlir::LLVM::AtomicBinOp binOp = + mlir::isa<mlir::IntegerType>(args[1].getType()) + ? mlir::LLVM::AtomicBinOp::min + : mlir::LLVM::AtomicBinOp::fmin; + return genAtomBinOp(builder, loc, binOp, args[0], args[1]); +} + +// ATOMICSUB +mlir::Value +CUDAIntrinsicLibrary::genAtomicSub(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + mlir::LLVM::AtomicBinOp binOp = + mlir::isa<mlir::IntegerType>(args[1].getType()) + ? mlir::LLVM::AtomicBinOp::sub + : mlir::LLVM::AtomicBinOp::fsub; + return genAtomBinOp(builder, loc, binOp, args[0], args[1]); +} + +// ATOMICXOR +fir::ExtendedValue +CUDAIntrinsicLibrary::genAtomicXor(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 2); + mlir::Value arg0 = fir::getBase(args[0]); + mlir::Value arg1 = fir::getBase(args[1]); + return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::_xor, arg0, arg1); +} + +// BARRIER_ARRIVE +mlir::Value +CUDAIntrinsicLibrary::genBarrierArrive(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 1); + mlir::Value barrier = convertPtrToNVVMSpace( + builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared); + return mlir::NVVM::MBarrierArriveOp::create(builder, loc, resultType, barrier) + .getResult(0); +} + +// BARRIER_ARRIBVE_CNT +mlir::Value +CUDAIntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + mlir::Value barrier = convertPtrToNVVMSpace( + builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared); + return mlir::NVVM::InlinePtxOp::create(builder, loc, {resultType}, + {barrier, args[1]}, {}, + "mbarrier.arrive.expect_tx.release." 
+ "cta.shared::cta.b64 %0, [%1], %2;", + {}) + .getResult(0); +} + +// BARRIER_INIT +void CUDAIntrinsicLibrary::genBarrierInit( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 2); + mlir::Value barrier = convertPtrToNVVMSpace( + builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared); + mlir::NVVM::MBarrierInitOp::create(builder, loc, barrier, + fir::getBase(args[1]), {}); + auto kind = mlir::NVVM::ProxyKindAttr::get( + builder.getContext(), mlir::NVVM::ProxyKind::async_shared); + auto space = mlir::NVVM::SharedSpaceAttr::get( + builder.getContext(), mlir::NVVM::SharedSpace::shared_cta); + mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space); +} + +// BARRIER_TRY_WAIT +mlir::Value +CUDAIntrinsicLibrary::genBarrierTryWait(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); + mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0); + fir::StoreOp::create(builder, loc, zero, res); + mlir::Value ns = + builder.createIntegerConstant(loc, builder.getI32Type(), 1000000); + mlir::Value load = fir::LoadOp::create(builder, loc, res); + auto whileOp = mlir::scf::WhileOp::create( + builder, loc, mlir::TypeRange{resultType}, mlir::ValueRange{load}); + mlir::Block *beforeBlock = builder.createBlock(&whileOp.getBefore()); + mlir::Value beforeArg = beforeBlock->addArgument(resultType, loc); + builder.setInsertionPointToStart(beforeBlock); + mlir::Value condition = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::eq, beforeArg, zero); + mlir::scf::ConditionOp::create(builder, loc, condition, beforeArg); + mlir::Block *afterBlock = builder.createBlock(&whileOp.getAfter()); + afterBlock->addArgument(resultType, loc); + builder.setInsertionPointToStart(afterBlock); + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext()); + auto barrier = builder.createConvert(loc, llvmPtrTy, 
args[0]); + mlir::Value ret = mlir::NVVM::InlinePtxOp::create( + builder, loc, {resultType}, {barrier, args[1], ns}, {}, + "{\n" + " .reg .pred p;\n" + " mbarrier.try_wait.shared.b64 p, [%1], %2, %3;\n" + " selp.b32 %0, 1, 0, p;\n" + "}", + {}) + .getResult(0); + mlir::scf::YieldOp::create(builder, loc, ret); + builder.setInsertionPointAfter(whileOp); + return whileOp.getResult(0); +} + +// BARRIER_TRY_WAIT_SLEEP +mlir::Value +CUDAIntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 3); + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext()); + auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]); + return mlir::NVVM::InlinePtxOp::create( + builder, loc, {resultType}, {barrier, args[1], args[2]}, {}, + "{\n" + " .reg .pred p;\n" + " mbarrier.try_wait.shared.b64 p, [%1], %2, %3;\n" + " selp.b32 %0, 1, 0, p;\n" + "}", + {}) + .getResult(0); +} + +static void insertValueAtPos(fir::FirOpBuilder &builder, mlir::Location loc, + fir::RecordType recTy, mlir::Value base, + mlir::Value dim, unsigned fieldPos) { + auto fieldName = recTy.getTypeList()[fieldPos].first; + mlir::Type fieldTy = recTy.getTypeList()[fieldPos].second; + mlir::Type fieldIndexType = fir::FieldType::get(base.getContext()); + mlir::Value fieldIndex = + fir::FieldIndexOp::create(builder, loc, fieldIndexType, fieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value coord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(fieldTy), base, fieldIndex); + fir::StoreOp::create(builder, loc, dim, coord); +} + +// CLUSTER_BLOCK_INDEX +mlir::Value +CUDAIntrinsicLibrary::genClusterBlockIndex(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 0); + auto recTy = mlir::cast<fir::RecordType>(resultType); + assert(recTy && "RecordType expepected"); + mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); + mlir::Type i32Ty = builder.getI32Type(); + 
mlir::Value x = mlir::NVVM::BlockInClusterIdXOp::create(builder, loc, i32Ty); + mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1); + x = mlir::arith::AddIOp::create(builder, loc, x, one); + insertValueAtPos(builder, loc, recTy, res, x, 0); + mlir::Value y = mlir::NVVM::BlockInClusterIdYOp::create(builder, loc, i32Ty); + y = mlir::arith::AddIOp::create(builder, loc, y, one); + insertValueAtPos(builder, loc, recTy, res, y, 1); + mlir::Value z = mlir::NVVM::BlockInClusterIdZOp::create(builder, loc, i32Ty); + z = mlir::arith::AddIOp::create(builder, loc, z, one); + insertValueAtPos(builder, loc, recTy, res, z, 2); + return res; +} + +// CLUSTER_DIM_BLOCKS +mlir::Value +CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 0); + auto recTy = mlir::cast<fir::RecordType>(resultType); + assert(recTy && "RecordType expepected"); + mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); + mlir::Type i32Ty = builder.getI32Type(); + mlir::Value x = mlir::NVVM::ClusterDimBlocksXOp::create(builder, loc, i32Ty); + insertValueAtPos(builder, loc, recTy, res, x, 0); + mlir::Value y = mlir::NVVM::ClusterDimBlocksYOp::create(builder, loc, i32Ty); + insertValueAtPos(builder, loc, recTy, res, y, 1); + mlir::Value z = mlir::NVVM::ClusterDimBlocksZOp::create(builder, loc, i32Ty); + insertValueAtPos(builder, loc, recTy, res, z, 2); + return res; +} + +// FENCE_PROXY_ASYNC +void CUDAIntrinsicLibrary::genFenceProxyAsync( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 0); + auto kind = mlir::NVVM::ProxyKindAttr::get( + builder.getContext(), mlir::NVVM::ProxyKind::async_shared); + auto space = mlir::NVVM::SharedSpaceAttr::get( + builder.getContext(), mlir::NVVM::SharedSpace::shared_cta); + mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space); +} + +// __LDCA, __LDCS, __LDLU, __LDCV +template <const char *fctName, int extent> +fir::ExtendedValue 
+CUDAIntrinsicLibrary::genLDXXFunc(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 1); + mlir::Type resTy = fir::SequenceType::get(extent, resultType); + mlir::Value arg = fir::getBase(args[0]); + mlir::Value res = fir::AllocaOp::create(builder, loc, resTy); + if (mlir::isa<fir::BaseBoxType>(arg.getType())) + arg = fir::BoxAddrOp::create(builder, loc, arg); + mlir::Type refResTy = fir::ReferenceType::get(resTy); + mlir::FunctionType ftype = + mlir::FunctionType::get(arg.getContext(), {refResTy, refResTy}, {}); + auto funcOp = builder.createFunction(loc, fctName, ftype); + llvm::SmallVector<mlir::Value> funcArgs; + funcArgs.push_back(res); + funcArgs.push_back(arg); + fir::CallOp::create(builder, loc, funcOp, funcArgs); + mlir::Value ext = + builder.createIntegerConstant(loc, builder.getIndexType(), extent); + return fir::ArrayBoxValue(res, {ext}); +} + +// CLOCK, CLOCK64, GLOBALTIMER +template <typename OpTy> +mlir::Value +CUDAIntrinsicLibrary::genNVVMTime(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 0 && "expect no arguments"); + return OpTy::create(builder, loc, resultType).getResult(); +} + +// MATCH_ALL_SYNC +mlir::Value +CUDAIntrinsicLibrary::genMatchAllSync(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 3); + bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32(); + + mlir::Type i1Ty = builder.getI1Type(); + mlir::MLIRContext *context = builder.getContext(); + + mlir::Value arg1 = args[1]; + if (arg1.getType().isF32() || arg1.getType().isF64()) + arg1 = fir::ConvertOp::create( + builder, loc, is32 ? 
builder.getI32Type() : builder.getI64Type(), arg1); + + mlir::Type retTy = + mlir::LLVM::LLVMStructType::getLiteral(context, {resultType, i1Ty}); + auto match = + mlir::NVVM::MatchSyncOp::create(builder, loc, retTy, args[0], arg1, + mlir::NVVM::MatchSyncKind::all) + .getResult(); + auto value = mlir::LLVM::ExtractValueOp::create(builder, loc, match, 0); + auto pred = mlir::LLVM::ExtractValueOp::create(builder, loc, match, 1); + auto conv = mlir::LLVM::ZExtOp::create(builder, loc, resultType, pred); + fir::StoreOp::create(builder, loc, conv, args[2]); + return value; +} + +// MATCH_ANY_SYNC +mlir::Value +CUDAIntrinsicLibrary::genMatchAnySync(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32(); + + mlir::Value arg1 = args[1]; + if (arg1.getType().isF32() || arg1.getType().isF64()) + arg1 = fir::ConvertOp::create( + builder, loc, is32 ? builder.getI32Type() : builder.getI64Type(), arg1); + + return mlir::NVVM::MatchSyncOp::create(builder, loc, resultType, args[0], + arg1, mlir::NVVM::MatchSyncKind::any) + .getResult(); +} + +// SYNCTHREADS +void CUDAIntrinsicLibrary::genSyncThreads( + llvm::ArrayRef<fir::ExtendedValue> args) { + mlir::NVVM::Barrier0Op::create(builder, loc); +} + +// SYNCTHREADS_AND +mlir::Value +CUDAIntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + mlir::Value arg = builder.createConvert(loc, builder.getI32Type(), args[0]); + return mlir::NVVM::BarrierOp::create( + builder, loc, resultType, {}, {}, + mlir::NVVM::BarrierReductionAttr::get( + builder.getContext(), mlir::NVVM::BarrierReduction::AND), + arg) + .getResult(0); +} + +// SYNCTHREADS_COUNT +mlir::Value +CUDAIntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + mlir::Value arg = builder.createConvert(loc, builder.getI32Type(), args[0]); + return mlir::NVVM::BarrierOp::create( + builder, 
loc, resultType, {}, {}, + mlir::NVVM::BarrierReductionAttr::get( + builder.getContext(), mlir::NVVM::BarrierReduction::POPC), + arg) + .getResult(0); +} + +// SYNCTHREADS_OR +mlir::Value +CUDAIntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + mlir::Value arg = builder.createConvert(loc, builder.getI32Type(), args[0]); + return mlir::NVVM::BarrierOp::create( + builder, loc, resultType, {}, {}, + mlir::NVVM::BarrierReductionAttr::get( + builder.getContext(), mlir::NVVM::BarrierReduction::OR), + arg) + .getResult(0); +} + +// SYNCWARP +void CUDAIntrinsicLibrary::genSyncWarp( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 1); + mlir::NVVM::SyncWarpOp::create(builder, loc, fir::getBase(args[0])); +} + +// THIS_CLUSTER +mlir::Value +CUDAIntrinsicLibrary::genThisCluster(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 0); + auto recTy = mlir::cast<fir::RecordType>(resultType); + assert(recTy && "RecordType expepected"); + mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); + mlir::Type i32Ty = builder.getI32Type(); + + // SIZE + mlir::Value size = mlir::NVVM::ClusterDim::create(builder, loc, i32Ty); + auto sizeFieldName = recTy.getTypeList()[1].first; + mlir::Type sizeFieldTy = recTy.getTypeList()[1].second; + mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext()); + mlir::Value sizeFieldIndex = fir::FieldIndexOp::create( + builder, loc, fieldIndexType, sizeFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value sizeCoord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex); + fir::StoreOp::create(builder, loc, size, sizeCoord); + + // RANK + mlir::Value rank = mlir::NVVM::ClusterId::create(builder, loc, i32Ty); + mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1); + rank = mlir::arith::AddIOp::create(builder, loc, rank, one); + auto rankFieldName = 
recTy.getTypeList()[2].first; + mlir::Type rankFieldTy = recTy.getTypeList()[2].second; + mlir::Value rankFieldIndex = fir::FieldIndexOp::create( + builder, loc, fieldIndexType, rankFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value rankCoord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex); + fir::StoreOp::create(builder, loc, rank, rankCoord); + + return res; +} + +// THIS_GRID +mlir::Value +CUDAIntrinsicLibrary::genThisGrid(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 0); + auto recTy = mlir::cast<fir::RecordType>(resultType); + assert(recTy && "RecordType expepected"); + mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); + mlir::Type i32Ty = builder.getI32Type(); + + mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty); + mlir::Value threadIdY = mlir::NVVM::ThreadIdYOp::create(builder, loc, i32Ty); + mlir::Value threadIdZ = mlir::NVVM::ThreadIdZOp::create(builder, loc, i32Ty); + + mlir::Value blockIdX = mlir::NVVM::BlockIdXOp::create(builder, loc, i32Ty); + mlir::Value blockIdY = mlir::NVVM::BlockIdYOp::create(builder, loc, i32Ty); + mlir::Value blockIdZ = mlir::NVVM::BlockIdZOp::create(builder, loc, i32Ty); + + mlir::Value blockDimX = mlir::NVVM::BlockDimXOp::create(builder, loc, i32Ty); + mlir::Value blockDimY = mlir::NVVM::BlockDimYOp::create(builder, loc, i32Ty); + mlir::Value blockDimZ = mlir::NVVM::BlockDimZOp::create(builder, loc, i32Ty); + mlir::Value gridDimX = mlir::NVVM::GridDimXOp::create(builder, loc, i32Ty); + mlir::Value gridDimY = mlir::NVVM::GridDimYOp::create(builder, loc, i32Ty); + mlir::Value gridDimZ = mlir::NVVM::GridDimZOp::create(builder, loc, i32Ty); + + // this_grid.size = ((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y)) * + // (blockDim.x * gridDim.x); + mlir::Value resZ = + mlir::arith::MulIOp::create(builder, loc, blockDimZ, gridDimZ); + mlir::Value resY = + 
mlir::arith::MulIOp::create(builder, loc, blockDimY, gridDimY); + mlir::Value resX = + mlir::arith::MulIOp::create(builder, loc, blockDimX, gridDimX); + mlir::Value resZY = mlir::arith::MulIOp::create(builder, loc, resZ, resY); + mlir::Value size = mlir::arith::MulIOp::create(builder, loc, resZY, resX); + + // tmp = ((blockIdx.z * gridDim.y * gridDim.x) + (blockIdx.y * gridDim.x)) + + // blockIdx.x; + // this_group.rank = tmp * ((blockDim.x * blockDim.y) * blockDim.z) + + // ((threadIdx.z * blockDim.y) * blockDim.x) + + // (threadIdx.y * blockDim.x) + threadIdx.x + 1; + mlir::Value r1 = + mlir::arith::MulIOp::create(builder, loc, blockIdZ, gridDimY); + mlir::Value r2 = mlir::arith::MulIOp::create(builder, loc, r1, gridDimX); + mlir::Value r3 = + mlir::arith::MulIOp::create(builder, loc, blockIdY, gridDimX); + mlir::Value r2r3 = mlir::arith::AddIOp::create(builder, loc, r2, r3); + mlir::Value tmp = mlir::arith::AddIOp::create(builder, loc, r2r3, blockIdX); + + mlir::Value bXbY = + mlir::arith::MulIOp::create(builder, loc, blockDimX, blockDimY); + mlir::Value bXbYbZ = + mlir::arith::MulIOp::create(builder, loc, bXbY, blockDimZ); + mlir::Value tZbY = + mlir::arith::MulIOp::create(builder, loc, threadIdZ, blockDimY); + mlir::Value tZbYbX = + mlir::arith::MulIOp::create(builder, loc, tZbY, blockDimX); + mlir::Value tYbX = + mlir::arith::MulIOp::create(builder, loc, threadIdY, blockDimX); + mlir::Value rank = mlir::arith::MulIOp::create(builder, loc, tmp, bXbYbZ); + rank = mlir::arith::AddIOp::create(builder, loc, rank, tZbYbX); + rank = mlir::arith::AddIOp::create(builder, loc, rank, tYbX); + rank = mlir::arith::AddIOp::create(builder, loc, rank, threadIdX); + mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1); + rank = mlir::arith::AddIOp::create(builder, loc, rank, one); + + auto sizeFieldName = recTy.getTypeList()[1].first; + mlir::Type sizeFieldTy = recTy.getTypeList()[1].second; + mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext()); 
+ mlir::Value sizeFieldIndex = fir::FieldIndexOp::create( + builder, loc, fieldIndexType, sizeFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value sizeCoord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex); + fir::StoreOp::create(builder, loc, size, sizeCoord); + + auto rankFieldName = recTy.getTypeList()[2].first; + mlir::Type rankFieldTy = recTy.getTypeList()[2].second; + mlir::Value rankFieldIndex = fir::FieldIndexOp::create( + builder, loc, fieldIndexType, rankFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value rankCoord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex); + fir::StoreOp::create(builder, loc, rank, rankCoord); + return res; +} + +// THIS_THREAD_BLOCK +mlir::Value +CUDAIntrinsicLibrary::genThisThreadBlock(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 0); + auto recTy = mlir::cast<fir::RecordType>(resultType); + assert(recTy && "RecordType expepected"); + mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); + mlir::Type i32Ty = builder.getI32Type(); + + // this_thread_block%size = blockDim.z * blockDim.y * blockDim.x; + mlir::Value blockDimX = mlir::NVVM::BlockDimXOp::create(builder, loc, i32Ty); + mlir::Value blockDimY = mlir::NVVM::BlockDimYOp::create(builder, loc, i32Ty); + mlir::Value blockDimZ = mlir::NVVM::BlockDimZOp::create(builder, loc, i32Ty); + mlir::Value size = + mlir::arith::MulIOp::create(builder, loc, blockDimZ, blockDimY); + size = mlir::arith::MulIOp::create(builder, loc, size, blockDimX); + + // this_thread_block%rank = ((threadIdx.z * blockDim.y) * blockDim.x) + + // (threadIdx.y * blockDim.x) + threadIdx.x + 1; + mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty); + mlir::Value threadIdY = mlir::NVVM::ThreadIdYOp::create(builder, loc, i32Ty); + mlir::Value threadIdZ = mlir::NVVM::ThreadIdZOp::create(builder, loc, 
i32Ty); + mlir::Value r1 = + mlir::arith::MulIOp::create(builder, loc, threadIdZ, blockDimY); + mlir::Value r2 = mlir::arith::MulIOp::create(builder, loc, r1, blockDimX); + mlir::Value r3 = + mlir::arith::MulIOp::create(builder, loc, threadIdY, blockDimX); + mlir::Value r2r3 = mlir::arith::AddIOp::create(builder, loc, r2, r3); + mlir::Value rank = mlir::arith::AddIOp::create(builder, loc, r2r3, threadIdX); + mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1); + rank = mlir::arith::AddIOp::create(builder, loc, rank, one); + + auto sizeFieldName = recTy.getTypeList()[1].first; + mlir::Type sizeFieldTy = recTy.getTypeList()[1].second; + mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext()); + mlir::Value sizeFieldIndex = fir::FieldIndexOp::create( + builder, loc, fieldIndexType, sizeFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value sizeCoord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex); + fir::StoreOp::create(builder, loc, size, sizeCoord); + + auto rankFieldName = recTy.getTypeList()[2].first; + mlir::Type rankFieldTy = recTy.getTypeList()[2].second; + mlir::Value rankFieldIndex = fir::FieldIndexOp::create( + builder, loc, fieldIndexType, rankFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value rankCoord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex); + fir::StoreOp::create(builder, loc, rank, rankCoord); + return res; +} + +// THIS_WARP +mlir::Value +CUDAIntrinsicLibrary::genThisWarp(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 0); + auto recTy = mlir::cast<fir::RecordType>(resultType); + assert(recTy && "RecordType expepected"); + mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); + mlir::Type i32Ty = builder.getI32Type(); + + // coalesced_group%size = 32 + mlir::Value size = builder.createIntegerConstant(loc, i32Ty, 32); + auto 
sizeFieldName = recTy.getTypeList()[1].first; + mlir::Type sizeFieldTy = recTy.getTypeList()[1].second; + mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext()); + mlir::Value sizeFieldIndex = fir::FieldIndexOp::create( + builder, loc, fieldIndexType, sizeFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value sizeCoord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex); + fir::StoreOp::create(builder, loc, size, sizeCoord); + + // coalesced_group%rank = threadIdx.x & 31 + 1 + mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty); + mlir::Value mask = builder.createIntegerConstant(loc, i32Ty, 31); + mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1); + mlir::Value masked = + mlir::arith::AndIOp::create(builder, loc, threadIdX, mask); + mlir::Value rank = mlir::arith::AddIOp::create(builder, loc, masked, one); + auto rankFieldName = recTy.getTypeList()[2].first; + mlir::Type rankFieldTy = recTy.getTypeList()[2].second; + mlir::Value rankFieldIndex = fir::FieldIndexOp::create( + builder, loc, fieldIndexType, rankFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value rankCoord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex); + fir::StoreOp::create(builder, loc, rank, rankCoord); + return res; +} + +// THREADFENCE, THREADFENCE_BLOCK, THREADFENCE_SYSTEM +template <mlir::NVVM::MemScopeKind scope> +void CUDAIntrinsicLibrary::genThreadFence( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 0); + mlir::NVVM::MembarOp::create(builder, loc, scope); +} + +// TMA_BULK_COMMIT_GROUP +void CUDAIntrinsicLibrary::genTMABulkCommitGroup( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 0); + mlir::NVVM::CpAsyncBulkCommitGroupOp::create(builder, loc); +} + +// TMA_BULK_G2S +void CUDAIntrinsicLibrary::genTMABulkG2S( + llvm::ArrayRef<fir::ExtendedValue> args) { + 
assert(args.size() == 4); + mlir::Value barrier = convertPtrToNVVMSpace( + builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared); + mlir::Value dst = + convertPtrToNVVMSpace(builder, loc, fir::getBase(args[2]), + mlir::NVVM::NVVMMemorySpace::SharedCluster); + mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]), + mlir::NVVM::NVVMMemorySpace::Global); + mlir::NVVM::CpAsyncBulkGlobalToSharedClusterOp::create( + builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {}); +} + +static void setAlignment(mlir::Value ptr, unsigned alignment) { + if (auto declareOp = mlir::dyn_cast<hlfir::DeclareOp>(ptr.getDefiningOp())) + if (auto sharedOp = mlir::dyn_cast<cuf::SharedMemoryOp>( + declareOp.getMemref().getDefiningOp())) + sharedOp.setAlignment(alignment); +} + +static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value barrier, mlir::Value src, + mlir::Value dst, mlir::Value nelem, + mlir::Value eleSize) { + mlir::Value size = mlir::arith::MulIOp::create(builder, loc, nelem, eleSize); + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext()); + barrier = builder.createConvert(loc, llvmPtrTy, barrier); + setAlignment(dst, kTMAAlignment); + dst = builder.createConvert(loc, llvmPtrTy, dst); + src = builder.createConvert(loc, llvmPtrTy, src); + mlir::NVVM::InlinePtxOp::create( + builder, loc, mlir::TypeRange{}, {dst, src, size, barrier}, {}, + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], " + "[%1], %2, [%3];", + {}); + mlir::NVVM::InlinePtxOp::create( + builder, loc, mlir::TypeRange{}, {barrier, size}, {}, + "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;", {}); +} + +// TMA_BULK_LOADC4 +void CUDAIntrinsicLibrary::genTMABulkLoadC4( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 4); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 8); + genTMABulkLoad(builder, loc, fir::getBase(args[0]), 
fir::getBase(args[1]), + fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADC8 +void CUDAIntrinsicLibrary::genTMABulkLoadC8( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 4); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 16); + genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADI4 +void CUDAIntrinsicLibrary::genTMABulkLoadI4( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 4); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 4); + genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADI8 +void CUDAIntrinsicLibrary::genTMABulkLoadI8( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 4); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 8); + genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADR2 +void CUDAIntrinsicLibrary::genTMABulkLoadR2( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 4); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 2); + genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADR4 +void CUDAIntrinsicLibrary::genTMABulkLoadR4( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 4); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 4); + genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADR8 +void CUDAIntrinsicLibrary::genTMABulkLoadR8( + 
llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 4); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 8); + genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_S2G +void CUDAIntrinsicLibrary::genTMABulkS2G( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 3); + mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[0]), + mlir::NVVM::NVVMMemorySpace::Shared); + mlir::Value dst = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]), + mlir::NVVM::NVVMMemorySpace::Global); + mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create( + builder, loc, dst, src, fir::getBase(args[2]), {}, {}); + + mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {}, + "cp.async.bulk.commit_group;", {}); + mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, + builder.getI32IntegerAttr(0), {}); +} + +static void genTMABulkStore(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value src, mlir::Value dst, mlir::Value count, + mlir::Value eleSize) { + mlir::Value size = mlir::arith::MulIOp::create(builder, loc, eleSize, count); + setAlignment(src, kTMAAlignment); + src = convertPtrToNVVMSpace(builder, loc, src, + mlir::NVVM::NVVMMemorySpace::Shared); + dst = convertPtrToNVVMSpace(builder, loc, dst, + mlir::NVVM::NVVMMemorySpace::Global); + mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(builder, loc, dst, src, + size, {}, {}); + mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {}, + "cp.async.bulk.commit_group;", {}); + mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, + builder.getI32IntegerAttr(0), {}); +} + +// TMA_BULK_STORE_C4 +void CUDAIntrinsicLibrary::genTMABulkStoreC4( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 8); + 
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_C8 +void CUDAIntrinsicLibrary::genTMABulkStoreC8( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 16); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_I4 +void CUDAIntrinsicLibrary::genTMABulkStoreI4( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 4); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_I8 +void CUDAIntrinsicLibrary::genTMABulkStoreI8( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 8); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_R2 +void CUDAIntrinsicLibrary::genTMABulkStoreR2( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 2); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_R4 +void CUDAIntrinsicLibrary::genTMABulkStoreR4( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 4); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_R8 +void CUDAIntrinsicLibrary::genTMABulkStoreR8( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 3); + 
mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 8); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_WAIT_GROUP +void CUDAIntrinsicLibrary::genTMABulkWaitGroup( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 0); + auto group = builder.getIntegerAttr(builder.getI32Type(), 0); + mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, group, {}); +} + +// ALL_SYNC, ANY_SYNC, BALLOT_SYNC +template <mlir::NVVM::VoteSyncKind kind> +mlir::Value +CUDAIntrinsicLibrary::genVoteSync(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + assert(args.size() == 2); + mlir::Value arg1 = + fir::ConvertOp::create(builder, loc, builder.getI1Type(), args[1]); + mlir::Type resTy = kind == mlir::NVVM::VoteSyncKind::ballot + ? builder.getI32Type() + : builder.getI1Type(); + auto voteRes = + mlir::NVVM::VoteSyncOp::create(builder, loc, resTy, args[0], arg1, kind) + .getResult(); + return fir::ConvertOp::create(builder, loc, resultType, voteRes); +} + +} // namespace fir diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp index cf7588f..2266f4d 100644 --- a/flang/lib/Optimizer/Builder/CUFCommon.cpp +++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp @@ -9,6 +9,7 @@ #include "flang/Optimizer/Builder/CUFCommon.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Dialect/CUF/CUFOps.h" +#include "flang/Optimizer/Dialect/Support/KindMapping.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" @@ -91,3 +92,66 @@ void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) { } } } + +int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type, + fir::KindMapping &kindMap, + bool emitErrorOnFailure) { + auto eleTy = fir::unwrapSequenceType(type); + if (auto 
t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) + return t.getWidth() / 8; + if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) + return t.getWidth() / 8; + if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) + return kindMap.getLogicalBitsize(t.getFKind()) / 8; + if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) { + int elemSize = + mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8; + return 2 * elemSize; + } + if (auto t{mlir::dyn_cast<fir::CharacterType>(eleTy)}) + return kindMap.getCharacterBitsize(t.getFKind()) / 8; + if (emitErrorOnFailure) + mlir::emitError(loc, "unsupported type"); + return 0; +} + +mlir::Value cuf::computeElementCount(mlir::PatternRewriter &rewriter, + mlir::Location loc, + mlir::Value shapeOperand, + mlir::Type seqType, + mlir::Type targetType) { + if (shapeOperand) { + // Dynamic extent - extract from shape operand + llvm::SmallVector<mlir::Value> extents; + if (auto shapeOp = + mlir::dyn_cast<fir::ShapeOp>(shapeOperand.getDefiningOp())) { + extents = shapeOp.getExtents(); + } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>( + shapeOperand.getDefiningOp())) { + for (auto i : llvm::enumerate(shapeShiftOp.getPairs())) + if (i.index() & 1) + extents.push_back(i.value()); + } + + if (extents.empty()) + return mlir::Value(); + + // Compute total element count by multiplying all dimensions + mlir::Value count = + fir::ConvertOp::create(rewriter, loc, targetType, extents[0]); + for (unsigned i = 1; i < extents.size(); ++i) { + auto operand = + fir::ConvertOp::create(rewriter, loc, targetType, extents[i]); + count = mlir::arith::MulIOp::create(rewriter, loc, count, operand); + } + return count; + } else { + // Static extent - use constant array size + if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(seqType)) { + mlir::IntegerAttr attr = + rewriter.getIntegerAttr(targetType, seqTy.getConstantArraySize()); + return mlir::arith::ConstantOp::create(rewriter, loc, targetType, attr); + } + } + return 
mlir::Value(); +} diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp index 5da27d1..6a9c84f 100644 --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Analysis/AliasAnalysis.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/Character.h" #include "flang/Optimizer/Builder/Complex.h" @@ -427,7 +428,8 @@ mlir::Value fir::FirOpBuilder::genTempDeclareOp( builder, loc, memref.getType(), memref, shape, typeParams, /*dummy_scope=*/nullptr, /*storage=*/nullptr, - /*storage_offset=*/0, nameAttr, fortranAttrs, cuf::DataAttributeAttr{}); + /*storage_offset=*/0, nameAttr, fortranAttrs, cuf::DataAttributeAttr{}, + /*dummy_arg_no=*/mlir::IntegerAttr{}); } mlir::Value fir::FirOpBuilder::genStackSave(mlir::Location loc) { @@ -858,21 +860,32 @@ mlir::Value fir::FirOpBuilder::genIsNullAddr(mlir::Location loc, mlir::arith::CmpIPredicate::eq); } -mlir::Value fir::FirOpBuilder::genExtentFromTriplet(mlir::Location loc, - mlir::Value lb, - mlir::Value ub, - mlir::Value step, - mlir::Type type) { +template <typename OpTy, typename... 
Args> +static mlir::Value createAndMaybeFold(bool fold, fir::FirOpBuilder &builder, + mlir::Location loc, Args &&...args) { + if (fold) + return builder.createOrFold<OpTy>(loc, std::forward<Args>(args)...); + return OpTy::create(builder, loc, std::forward<Args>(args)...); +} + +mlir::Value +fir::FirOpBuilder::genExtentFromTriplet(mlir::Location loc, mlir::Value lb, + mlir::Value ub, mlir::Value step, + mlir::Type type, bool fold) { auto zero = createIntegerConstant(loc, type, 0); lb = createConvert(loc, type, lb); ub = createConvert(loc, type, ub); step = createConvert(loc, type, step); - auto diff = mlir::arith::SubIOp::create(*this, loc, ub, lb); - auto add = mlir::arith::AddIOp::create(*this, loc, diff, step); - auto div = mlir::arith::DivSIOp::create(*this, loc, add, step); - auto cmp = mlir::arith::CmpIOp::create( - *this, loc, mlir::arith::CmpIPredicate::sgt, div, zero); - return mlir::arith::SelectOp::create(*this, loc, cmp, div, zero); + + auto diff = createAndMaybeFold<mlir::arith::SubIOp>(fold, *this, loc, ub, lb); + auto add = + createAndMaybeFold<mlir::arith::AddIOp>(fold, *this, loc, diff, step); + auto div = + createAndMaybeFold<mlir::arith::DivSIOp>(fold, *this, loc, add, step); + auto cmp = createAndMaybeFold<mlir::arith::CmpIOp>( + fold, *this, loc, mlir::arith::CmpIPredicate::sgt, div, zero); + return createAndMaybeFold<mlir::arith::SelectOp>(fold, *this, loc, cmp, div, + zero); } mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc, @@ -1392,12 +1405,10 @@ fir::ExtendedValue fir::factory::arraySectionElementToExtendedValue( return fir::factory::componentToExtendedValue(builder, loc, element); } -void fir::factory::genScalarAssignment(fir::FirOpBuilder &builder, - mlir::Location loc, - const fir::ExtendedValue &lhs, - const fir::ExtendedValue &rhs, - bool needFinalization, - bool isTemporaryLHS) { +void fir::factory::genScalarAssignment( + fir::FirOpBuilder &builder, mlir::Location loc, + const fir::ExtendedValue &lhs, const 
fir::ExtendedValue &rhs, + bool needFinalization, bool isTemporaryLHS, mlir::ArrayAttr accessGroups) { assert(lhs.rank() == 0 && rhs.rank() == 0 && "must be scalars"); auto type = fir::unwrapSequenceType( fir::unwrapPassByRefType(fir::getBase(lhs).getType())); @@ -1419,7 +1430,9 @@ void fir::factory::genScalarAssignment(fir::FirOpBuilder &builder, mlir::Value lhsAddr = fir::getBase(lhs); rhsVal = builder.createConvert(loc, fir::unwrapRefType(lhsAddr.getType()), rhsVal); - fir::StoreOp::create(builder, loc, rhsVal, lhsAddr); + fir::StoreOp store = fir::StoreOp::create(builder, loc, rhsVal, lhsAddr); + if (accessGroups) + store.setAccessGroupsAttr(accessGroups); } } @@ -1554,8 +1567,15 @@ void fir::factory::genRecordAssignment(fir::FirOpBuilder &builder, mlir::isa<fir::BaseBoxType>(fir::getBase(rhs).getType()); auto recTy = mlir::dyn_cast<fir::RecordType>(baseTy); assert(recTy && "must be a record type"); + + // Use alias analysis to guard the fast path. + fir::AliasAnalysis aa; + // Aliased SEQUENCE types must take the conservative (slow) path. 
+ bool disjoint = isTemporaryLHS || !recTy.isSequence() || + (aa.alias(fir::getBase(lhs), fir::getBase(rhs)) == + mlir::AliasResult::NoAlias); if ((needFinalization && mayHaveFinalizer(recTy, builder)) || - hasBoxOperands || !recordTypeCanBeMemCopied(recTy)) { + hasBoxOperands || !recordTypeCanBeMemCopied(recTy) || !disjoint) { auto to = fir::getBase(builder.createBox(loc, lhs)); auto from = fir::getBase(builder.createBox(loc, rhs)); // The runtime entry point may modify the LHS descriptor if it is @@ -1670,6 +1690,26 @@ mlir::Value fir::factory::createZeroValue(fir::FirOpBuilder &builder, "numeric or logical type"); } +mlir::Value fir::factory::createOneValue(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Type type) { + mlir::Type i1 = builder.getIntegerType(1); + if (mlir::isa<fir::LogicalType>(type) || type == i1) + return builder.createConvert(loc, type, builder.createBool(loc, true)); + if (fir::isa_integer(type)) + return builder.createIntegerConstant(loc, type, 1); + if (fir::isa_real(type)) + return builder.createRealOneConstant(loc, type); + if (fir::isa_complex(type)) { + fir::factory::Complex complexHelper(builder, loc); + mlir::Type partType = complexHelper.getComplexPartType(type); + mlir::Value realPart = builder.createRealOneConstant(loc, partType); + mlir::Value imagPart = builder.createRealZeroConstant(loc, partType); + return complexHelper.createComplex(type, realPart, imagPart); + } + fir::emitFatalError(loc, "internal: trying to generate one value of non " + "numeric or logical type"); +} + std::optional<std::int64_t> fir::factory::getExtentFromTriplet(mlir::Value lb, mlir::Value ub, mlir::Value stride) { diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index 93dfc57..3355bf1 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -250,7 +250,7 @@ hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder, const 
fir::ExtendedValue &exv, llvm::StringRef name, fir::FortranVariableFlagsAttr flags, mlir::Value dummyScope, mlir::Value storage, std::uint64_t storageOffset, - cuf::DataAttributeAttr dataAttr) { + cuf::DataAttributeAttr dataAttr, unsigned dummyArgNo) { mlir::Value base = fir::getBase(exv); assert(fir::conformsWithPassByRef(base.getType()) && @@ -281,7 +281,7 @@ hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder, [](const auto &) {}); auto declareOp = hlfir::DeclareOp::create( builder, loc, base, name, shapeOrShift, lenParams, dummyScope, storage, - storageOffset, flags, dataAttr); + storageOffset, flags, dataAttr, dummyArgNo); return mlir::cast<fir::FortranVariableOpInterface>(declareOp.getOperation()); } @@ -402,9 +402,9 @@ hlfir::Entity hlfir::genVariableBox(mlir::Location loc, fir::BoxType::get(var.getElementOrSequenceType(), isVolatile); if (forceBoxType) { boxType = forceBoxType; - mlir::Type baseType = - fir::ReferenceType::get(fir::unwrapRefType(forceBoxType.getEleTy())); - addr = builder.createConvert(loc, baseType, addr); + mlir::Type baseType = fir::ReferenceType::get( + fir::unwrapRefType(forceBoxType.getEleTy()), forceBoxType.isVolatile()); + addr = builder.createConvertWithVolatileCast(loc, baseType, addr); } auto embox = fir::EmboxOp::create(builder, loc, boxType, addr, shape, /*slice=*/mlir::Value{}, typeParams); @@ -1392,6 +1392,79 @@ bool hlfir::elementalOpMustProduceTemp(hlfir::ElementalOp elemental) { return false; } +static void combineAndStoreElement( + mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity lhs, + hlfir::Entity rhs, bool temporaryLHS, + std::function<hlfir::Entity(mlir::Location, fir::FirOpBuilder &, + hlfir::Entity, hlfir::Entity)> *combiner, + mlir::ArrayAttr accessGroups) { + hlfir::Entity valueToAssign = hlfir::loadTrivialScalar(loc, builder, rhs); + if (accessGroups) + if (auto load = valueToAssign.getDefiningOp<fir::LoadOp>()) + load.setAccessGroupsAttr(accessGroups); + if (combiner) { + hlfir::Entity 
lhsValue = hlfir::loadTrivialScalar(loc, builder, lhs); + if (accessGroups) + if (auto load = lhsValue.getDefiningOp<fir::LoadOp>()) + load.setAccessGroupsAttr(accessGroups); + valueToAssign = (*combiner)(loc, builder, lhsValue, valueToAssign); + } + auto assign = hlfir::AssignOp::create(builder, loc, valueToAssign, lhs, + /*realloc=*/false, + /*keep_lhs_length_if_realloc=*/false, + /*temporary_lhs=*/temporaryLHS); + if (accessGroups) + assign->setAttr(fir::getAccessGroupsAttrName(), accessGroups); +} + +void hlfir::genNoAliasArrayAssignment( + mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity rhs, + hlfir::Entity lhs, bool emitWorkshareLoop, bool temporaryLHS, + std::function<hlfir::Entity(mlir::Location, fir::FirOpBuilder &, + hlfir::Entity, hlfir::Entity)> *combiner, + mlir::ArrayAttr accessGroups) { + mlir::OpBuilder::InsertionGuard guard(builder); + rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs); + lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs); + mlir::Value lhsShape = hlfir::genShape(loc, builder, lhs); + llvm::SmallVector<mlir::Value> extents = + hlfir::getIndexExtents(loc, builder, lhsShape); + if (rhs.isArray()) { + mlir::Value rhsShape = hlfir::genShape(loc, builder, rhs); + llvm::SmallVector<mlir::Value> rhsExtents = + hlfir::getIndexExtents(loc, builder, rhsShape); + extents = fir::factory::deduceOptimalExtents(extents, rhsExtents); + } + hlfir::LoopNest loopNest = + hlfir::genLoopNest(loc, builder, extents, + /*isUnordered=*/true, emitWorkshareLoop); + builder.setInsertionPointToStart(loopNest.body); + auto rhsArrayElement = + hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices); + rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement); + auto lhsArrayElement = + hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); + combineAndStoreElement(loc, builder, lhsArrayElement, rhsArrayElement, + temporaryLHS, combiner, accessGroups); +} + +void hlfir::genNoAliasAssignment( + 
mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity rhs, + hlfir::Entity lhs, bool emitWorkshareLoop, bool temporaryLHS, + std::function<hlfir::Entity(mlir::Location, fir::FirOpBuilder &, + hlfir::Entity, hlfir::Entity)> *combiner, + mlir::ArrayAttr accessGroups) { + if (lhs.isArray()) { + genNoAliasArrayAssignment(loc, builder, rhs, lhs, emitWorkshareLoop, + temporaryLHS, combiner, accessGroups); + return; + } + rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs); + lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs); + combineAndStoreElement(loc, builder, lhs, rhs, temporaryLHS, combiner, + accessGroups); +} + std::pair<hlfir::Entity, bool> hlfir::createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity mold) { @@ -1624,25 +1697,38 @@ hlfir::genExtentsVector(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity hlfir::gen1DSection(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity array, int64_t dim, - mlir::ArrayRef<mlir::Value> lbounds, mlir::ArrayRef<mlir::Value> extents, mlir::ValueRange oneBasedIndices, mlir::ArrayRef<mlir::Value> typeParams) { assert(array.isVariable() && "array must be a variable"); assert(dim > 0 && dim <= array.getRank() && "invalid dim number"); + llvm::SmallVector<mlir::Value> lbounds = + getNonDefaultLowerBounds(loc, builder, array); mlir::Value one = builder.createIntegerConstant(loc, builder.getIndexType(), 1); hlfir::DesignateOp::Subscripts subscripts; unsigned indexId = 0; for (int i = 0; i < array.getRank(); ++i) { if (i == dim - 1) { - mlir::Value ubound = genUBound(loc, builder, lbounds[i], extents[i], one); - subscripts.emplace_back( - hlfir::DesignateOp::Triplet{lbounds[i], ubound, one}); + // (...,:, ..) 
+ if (lbounds.empty()) { + subscripts.emplace_back( + hlfir::DesignateOp::Triplet{one, extents[i], one}); + } else { + mlir::Value ubound = + genUBound(loc, builder, lbounds[i], extents[i], one); + subscripts.emplace_back( + hlfir::DesignateOp::Triplet{lbounds[i], ubound, one}); + } } else { - mlir::Value index = - genUBound(loc, builder, lbounds[i], oneBasedIndices[indexId++], one); - subscripts.emplace_back(index); + // (...,lb + one_based_index - 1, ..) + if (lbounds.empty()) { + subscripts.emplace_back(oneBasedIndices[indexId++]); + } else { + mlir::Value index = genUBound(loc, builder, lbounds[i], + oneBasedIndices[indexId++], one); + subscripts.emplace_back(index); + } } } mlir::Value sectionShape = @@ -1710,9 +1796,10 @@ bool hlfir::isSimplyContiguous(mlir::Value base, bool checkWhole) { return false; return mlir::TypeSwitch<mlir::Operation *, bool>(def) - .Case<fir::EmboxOp>( - [&](auto op) { return fir::isContiguousEmbox(op, checkWhole); }) - .Case<fir::ReboxOp>([&](auto op) { + .Case([&](fir::EmboxOp op) { + return fir::isContiguousEmbox(op, checkWhole); + }) + .Case([&](fir::ReboxOp op) { hlfir::Entity box{op.getBox()}; return fir::reboxPreservesContinuity( op, box.mayHaveNonDefaultLowerBounds(), checkWhole) && @@ -1721,7 +1808,7 @@ bool hlfir::isSimplyContiguous(mlir::Value base, bool checkWhole) { .Case<fir::DeclareOp, hlfir::DeclareOp>([&](auto op) { return isSimplyContiguous(op.getMemref(), checkWhole); }) - .Case<fir::ConvertOp>( - [&](auto op) { return isSimplyContiguous(op.getValue()); }) + .Case( + [&](fir::ConvertOp op) { return isSimplyContiguous(op.getValue()); }) .Default([](auto &&) { return false; }); } diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index ec0c802..d3c6739 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -15,7 +15,9 @@ #include "flang/Optimizer/Builder/IntrinsicCall.h" #include 
"flang/Common/static-multimap-view.h" +#include "flang/Lower/AbstractConverter.h" #include "flang/Optimizer/Builder/BoxValue.h" +#include "flang/Optimizer/Builder/CUDAIntrinsicCall.h" #include "flang/Optimizer/Builder/CUFCommon.h" #include "flang/Optimizer/Builder/Character.h" #include "flang/Optimizer/Builder/Complex.h" @@ -90,6 +92,11 @@ static bool isStaticallyAbsent(llvm::ArrayRef<mlir::Value> args, size_t argIndex) { return args.size() <= argIndex || !args[argIndex]; } +static bool isOptional(mlir::Value value) { + auto varIface = mlir::dyn_cast_or_null<fir::FortranVariableOpInterface>( + value.getDefiningOp()); + return varIface && varIface.isOptional(); +} /// Test if an ExtendedValue is present. This is used to test if an intrinsic /// argument is present at compile time. This does not imply that the related @@ -107,34 +114,6 @@ using I = IntrinsicLibrary; /// argument is an optional variable in the current scope). static constexpr bool handleDynamicOptional = true; -/// TODO: Move all CUDA Fortran intrinsic handlers into its own file similar to -/// PPC. 
-static const char __ldca_i4x4[] = "__ldca_i4x4_"; -static const char __ldca_i8x2[] = "__ldca_i8x2_"; -static const char __ldca_r2x2[] = "__ldca_r2x2_"; -static const char __ldca_r4x4[] = "__ldca_r4x4_"; -static const char __ldca_r8x2[] = "__ldca_r8x2_"; -static const char __ldcg_i4x4[] = "__ldcg_i4x4_"; -static const char __ldcg_i8x2[] = "__ldcg_i8x2_"; -static const char __ldcg_r2x2[] = "__ldcg_r2x2_"; -static const char __ldcg_r4x4[] = "__ldcg_r4x4_"; -static const char __ldcg_r8x2[] = "__ldcg_r8x2_"; -static const char __ldcs_i4x4[] = "__ldcs_i4x4_"; -static const char __ldcs_i8x2[] = "__ldcs_i8x2_"; -static const char __ldcs_r2x2[] = "__ldcs_r2x2_"; -static const char __ldcs_r4x4[] = "__ldcs_r4x4_"; -static const char __ldcs_r8x2[] = "__ldcs_r8x2_"; -static const char __ldcv_i4x4[] = "__ldcv_i4x4_"; -static const char __ldcv_i8x2[] = "__ldcv_i8x2_"; -static const char __ldcv_r2x2[] = "__ldcv_r2x2_"; -static const char __ldcv_r4x4[] = "__ldcv_r4x4_"; -static const char __ldcv_r8x2[] = "__ldcv_r8x2_"; -static const char __ldlu_i4x4[] = "__ldlu_i4x4_"; -static const char __ldlu_i8x2[] = "__ldlu_i8x2_"; -static const char __ldlu_r2x2[] = "__ldlu_r2x2_"; -static const char __ldlu_r4x4[] = "__ldlu_r4x4_"; -static const char __ldlu_r8x2[] = "__ldlu_r8x2_"; - /// Table that drives the fir generation depending on the intrinsic or intrinsic /// module procedure one to one mapping with Fortran arguments. If no mapping is /// defined here for a generic intrinsic, genRuntimeCall will be called @@ -143,106 +122,6 @@ static const char __ldlu_r8x2[] = "__ldlu_r8x2_"; /// argument must not be lowered by value. In which case, the lowering rules /// should be provided for all the intrinsic arguments for completeness. 
static constexpr IntrinsicHandler handlers[]{ - {"__ldca_i4x4", - &I::genCUDALDXXFunc<__ldca_i4x4, 4>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldca_i8x2", - &I::genCUDALDXXFunc<__ldca_i8x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldca_r2x2", - &I::genCUDALDXXFunc<__ldca_r2x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldca_r4x4", - &I::genCUDALDXXFunc<__ldca_r4x4, 4>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldca_r8x2", - &I::genCUDALDXXFunc<__ldca_r8x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcg_i4x4", - &I::genCUDALDXXFunc<__ldcg_i4x4, 4>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcg_i8x2", - &I::genCUDALDXXFunc<__ldcg_i8x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcg_r2x2", - &I::genCUDALDXXFunc<__ldcg_r2x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcg_r4x4", - &I::genCUDALDXXFunc<__ldcg_r4x4, 4>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcg_r8x2", - &I::genCUDALDXXFunc<__ldcg_r8x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcs_i4x4", - &I::genCUDALDXXFunc<__ldcs_i4x4, 4>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcs_i8x2", - &I::genCUDALDXXFunc<__ldcs_i8x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcs_r2x2", - &I::genCUDALDXXFunc<__ldcs_r2x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcs_r4x4", - &I::genCUDALDXXFunc<__ldcs_r4x4, 4>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcs_r8x2", - &I::genCUDALDXXFunc<__ldcs_r8x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcv_i4x4", - &I::genCUDALDXXFunc<__ldcv_i4x4, 4>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcv_i8x2", - &I::genCUDALDXXFunc<__ldcv_i8x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcv_r2x2", - &I::genCUDALDXXFunc<__ldcv_r2x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldcv_r4x4", - &I::genCUDALDXXFunc<__ldcv_r4x4, 4>, - {{{"a", asAddr}}}, - 
/*isElemental=*/false}, - {"__ldcv_r8x2", - &I::genCUDALDXXFunc<__ldcv_r8x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldlu_i4x4", - &I::genCUDALDXXFunc<__ldlu_i4x4, 4>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldlu_i8x2", - &I::genCUDALDXXFunc<__ldlu_i8x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldlu_r2x2", - &I::genCUDALDXXFunc<__ldlu_r2x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldlu_r4x4", - &I::genCUDALDXXFunc<__ldlu_r4x4, 4>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, - {"__ldlu_r8x2", - &I::genCUDALDXXFunc<__ldlu_r8x2, 2>, - {{{"a", asAddr}}}, - /*isElemental=*/false}, {"abort", &I::genAbort}, {"abs", &I::genAbs}, {"achar", &I::genChar}, @@ -262,10 +141,6 @@ static constexpr IntrinsicHandler handlers[]{ &I::genAll, {{{"mask", asAddr}, {"dim", asValue}}}, /*isElemental=*/false}, - {"all_sync", - &I::genVoteSync<mlir::NVVM::VoteSyncKind::all>, - {{{"mask", asValue}, {"pred", asValue}}}, - /*isElemental=*/false}, {"allocated", &I::genAllocated, {{{"array", asInquired}, {"scalar", asInquired}}}, @@ -275,10 +150,6 @@ static constexpr IntrinsicHandler handlers[]{ &I::genAny, {{{"mask", asAddr}, {"dim", asValue}}}, /*isElemental=*/false}, - {"any_sync", - &I::genVoteSync<mlir::NVVM::VoteSyncKind::any>, - {{{"mask", asValue}, {"pred", asValue}}}, - /*isElemental=*/false}, {"asind", &I::genAsind}, {"asinpi", &I::genAsinpi}, {"associated", @@ -289,75 +160,6 @@ static constexpr IntrinsicHandler handlers[]{ {"atan2pi", &I::genAtanpi}, {"atand", &I::genAtand}, {"atanpi", &I::genAtanpi}, - {"atomicaddd", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicaddf", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicaddi", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicaddl", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicandi", &I::genAtomicAnd, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomiccasd", - &I::genAtomicCas, - 
{{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}}, - false}, - {"atomiccasf", - &I::genAtomicCas, - {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}}, - false}, - {"atomiccasi", - &I::genAtomicCas, - {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}}, - false}, - {"atomiccasul", - &I::genAtomicCas, - {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}}, - false}, - {"atomicdeci", &I::genAtomicDec, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicexchd", - &I::genAtomicExch, - {{{"a", asAddr}, {"v", asValue}}}, - false}, - {"atomicexchf", - &I::genAtomicExch, - {{{"a", asAddr}, {"v", asValue}}}, - false}, - {"atomicexchi", - &I::genAtomicExch, - {{{"a", asAddr}, {"v", asValue}}}, - false}, - {"atomicexchul", - &I::genAtomicExch, - {{{"a", asAddr}, {"v", asValue}}}, - false}, - {"atomicinci", &I::genAtomicInc, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicmaxd", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicmaxf", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicmaxi", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicmaxl", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicmind", &I::genAtomicMin, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicminf", &I::genAtomicMin, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicmini", &I::genAtomicMin, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicminl", &I::genAtomicMin, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicori", &I::genAtomicOr, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicsubd", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicsubf", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicsubi", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicsubl", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"atomicxori", &I::genAtomicXor, {{{"a", asAddr}, {"v", asValue}}}, false}, - {"ballot_sync", - 
&I::genVoteSync<mlir::NVVM::VoteSyncKind::ballot>, - {{{"mask", asValue}, {"pred", asValue}}}, - /*isElemental=*/false}, - {"barrier_arrive", - &I::genBarrierArrive, - {{{"barrier", asAddr}}}, - /*isElemental=*/false}, - {"barrier_arrive_cnt", - &I::genBarrierArriveCnt, - {{{"barrier", asAddr}, {"count", asValue}}}, - /*isElemental=*/false}, - {"barrier_init", - &I::genBarrierInit, - {{{"barrier", asAddr}, {"count", asValue}}}, - /*isElemental=*/false}, {"bessel_jn", &I::genBesselJn, {{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}}, @@ -391,6 +193,12 @@ static constexpr IntrinsicHandler handlers[]{ &I::genCFProcPointer, {{{"cptr", asValue}, {"fptr", asInquired}}}, /*isElemental=*/false}, + {"c_f_strpointer", + &I::genCFStrPointer, + {{{"cstrptr_or_cstrarray", asValue}, + {"fstrptr", asInquired}, + {"nchars", asValue, handleDynamicOptional}}}, + /*isElemental=*/false}, {"c_funloc", &I::genCFunLoc, {{{"x", asBox}}}, /*isElemental=*/false}, {"c_loc", &I::genCLoc, {{{"x", asBox}}}, /*isElemental=*/false}, {"c_ptr_eq", &I::genCPtrCompare<mlir::arith::CmpIPredicate::eq>}, @@ -401,11 +209,6 @@ static constexpr IntrinsicHandler handlers[]{ &I::genChdir, {{{"name", asAddr}, {"status", asAddr, handleDynamicOptional}}}, /*isElemental=*/false}, - {"clock", &I::genNVVMTime<mlir::NVVM::ClockOp>, {}, /*isElemental=*/false}, - {"clock64", - &I::genNVVMTime<mlir::NVVM::Clock64Op>, - {}, - /*isElemental=*/false}, {"cmplx", &I::genCmplx, {{{"x", asValue}, {"y", asValue, handleDynamicOptional}}}}, @@ -502,9 +305,9 @@ static constexpr IntrinsicHandler handlers[]{ &I::genExtendsTypeOf, {{{"a", asBox}, {"mold", asBox}}}, /*isElemental=*/false}, - {"fence_proxy_async", - &I::genFenceProxyAsync, - {}, + {"f_c_string", + &I::genFCString, + {{{"string", asAddr}, {"asis", asValue, handleDynamicOptional}}}, /*isElemental=*/false}, {"findloc", &I::genFindloc, @@ -516,6 +319,10 @@ static constexpr IntrinsicHandler handlers[]{ {"back", asValue, handleDynamicOptional}}}, 
/*isElemental=*/false}, {"floor", &I::genFloor}, + {"flush", + &I::genFlush, + {{{"unit", asAddr}}}, + /*isElemental=*/false}, {"fraction", &I::genFraction}, {"free", &I::genFree}, {"fseek", @@ -553,6 +360,10 @@ static constexpr IntrinsicHandler handlers[]{ {"trim_name", asAddr, handleDynamicOptional}, {"errmsg", asBox, handleDynamicOptional}}}, /*isElemental=*/false}, + {"get_team", + &I::genGetTeam, + {{{"level", asValue, handleDynamicOptional}}}, + /*isElemental=*/false}, {"getcwd", &I::genGetCwd, {{{"c", asBox}, {"status", asAddr, handleDynamicOptional}}}, @@ -560,10 +371,6 @@ static constexpr IntrinsicHandler handlers[]{ {"getgid", &I::genGetGID}, {"getpid", &I::genGetPID}, {"getuid", &I::genGetUID}, - {"globaltimer", - &I::genNVVMTime<mlir::NVVM::GlobalTimerOp>, - {}, - /*isElemental=*/false}, {"hostnm", &I::genHostnm, {{{"c", asBox}, {"status", asAddr, handleDynamicOptional}}}, @@ -703,6 +510,10 @@ static constexpr IntrinsicHandler handlers[]{ {"dim", asValue}, {"mask", asBox, handleDynamicOptional}}}, /*isElemental=*/false}, + {"irand", + &I::genIrand, + {{{"i", asAddr, handleDynamicOptional}}}, + /*isElemental=*/false}, {"is_contiguous", &I::genIsContiguous, {{{"array", asBox}}}, @@ -731,38 +542,6 @@ static constexpr IntrinsicHandler handlers[]{ {"malloc", &I::genMalloc}, {"maskl", &I::genMask<mlir::arith::ShLIOp>}, {"maskr", &I::genMask<mlir::arith::ShRUIOp>}, - {"match_all_syncjd", - &I::genMatchAllSync, - {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, - /*isElemental=*/false}, - {"match_all_syncjf", - &I::genMatchAllSync, - {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, - /*isElemental=*/false}, - {"match_all_syncjj", - &I::genMatchAllSync, - {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, - /*isElemental=*/false}, - {"match_all_syncjx", - &I::genMatchAllSync, - {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}}, - /*isElemental=*/false}, - {"match_any_syncjd", - &I::genMatchAnySync, - {{{"mask", 
asValue}, {"value", asValue}}}, - /*isElemental=*/false}, - {"match_any_syncjf", - &I::genMatchAnySync, - {{{"mask", asValue}, {"value", asValue}}}, - /*isElemental=*/false}, - {"match_any_syncjj", - &I::genMatchAnySync, - {{{"mask", asValue}, {"value", asValue}}}, - /*isElemental=*/false}, - {"match_any_syncjx", - &I::genMatchAnySync, - {{{"mask", asValue}, {"value", asValue}}}, - /*isElemental=*/false}, {"matmul", &I::genMatmul, {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}}, @@ -861,6 +640,10 @@ static constexpr IntrinsicHandler handlers[]{ &I::genPutenv, {{{"str", asAddr}, {"status", asAddr, handleDynamicOptional}}}, /*isElemental=*/false}, + {"rand", + &I::genRand, + {{{"i", asAddr, handleDynamicOptional}}}, + /*isElemental=*/false}, {"random_init", &I::genRandomInit, {{{"repeatable", asValue}, {"image_distinct", asValue}}}, @@ -955,6 +738,10 @@ static constexpr IntrinsicHandler handlers[]{ {"shifta", &I::genShiftA}, {"shiftl", &I::genShift<mlir::arith::ShLIOp>}, {"shiftr", &I::genShift<mlir::arith::ShRUIOp>}, + {"show_descriptor", + &I::genShowDescriptor, + {{{"d", asInquired}}}, + /*isElemental=*/false}, {"sign", &I::genSign}, {"signal", &I::genSignalSubroutine, @@ -988,11 +775,6 @@ static constexpr IntrinsicHandler handlers[]{ {"dim", asValue}, {"mask", asBox, handleDynamicOptional}}}, /*isElemental=*/false}, - {"syncthreads", &I::genSyncThreads, {}, /*isElemental=*/false}, - {"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false}, - {"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false}, - {"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false}, - {"syncwarp", &I::genSyncWarp, {}, /*isElemental=*/false}, {"system", &I::genSystem, {{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}}, @@ -1003,38 +785,17 @@ static constexpr IntrinsicHandler handlers[]{ /*isElemental=*/false}, {"tand", &I::genTand}, {"tanpi", &I::genTanpi}, - {"this_grid", &I::genThisGrid, {}, /*isElemental=*/false}, + {"team_number", 
+ &I::genTeamNumber, + {{{"team", asBox, handleDynamicOptional}}}, + /*isElemental=*/false}, {"this_image", &I::genThisImage, {{{"coarray", asBox}, {"dim", asAddr}, {"team", asBox, handleDynamicOptional}}}, /*isElemental=*/false}, - {"this_thread_block", &I::genThisThreadBlock, {}, /*isElemental=*/false}, - {"this_warp", &I::genThisWarp, {}, /*isElemental=*/false}, - {"threadfence", &I::genThreadFence, {}, /*isElemental=*/false}, - {"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false}, - {"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false}, {"time", &I::genTime, {}, /*isElemental=*/false}, - {"tma_bulk_commit_group", - &I::genTMABulkCommitGroup, - {{}}, - /*isElemental=*/false}, - {"tma_bulk_g2s", - &I::genTMABulkG2S, - {{{"barrier", asAddr}, - {"src", asAddr}, - {"dst", asAddr}, - {"nbytes", asValue}}}, - /*isElemental=*/false}, - {"tma_bulk_s2g", - &I::genTMABulkS2G, - {{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}}, - /*isElemental=*/false}, - {"tma_bulk_wait_group", - &I::genTMABulkWaitGroup, - {{}}, - /*isElemental=*/false}, {"trailz", &I::genTrailz}, {"transfer", &I::genTransfer, @@ -1758,8 +1519,10 @@ static constexpr MathOperation mathOperations[] = { genComplexMathOp<mlir::complex::SinOp>}, {"sin", RTNAME_STRING(CSinF128), FuncTypeComplex16Complex16, genLibF128Call}, - {"sinh", "sinhf", genFuncType<Ty::Real<4>, Ty::Real<4>>, genLibCall}, - {"sinh", "sinh", genFuncType<Ty::Real<8>, Ty::Real<8>>, genLibCall}, + {"sinh", "sinhf", genFuncType<Ty::Real<4>, Ty::Real<4>>, + genMathOp<mlir::math::SinhOp>}, + {"sinh", "sinh", genFuncType<Ty::Real<8>, Ty::Real<8>>, + genMathOp<mlir::math::SinhOp>}, {"sinh", RTNAME_STRING(SinhF128), FuncTypeReal16Real16, genLibF128Call}, {"sinh", "csinhf", genFuncType<Ty::Complex<4>, Ty::Complex<4>>, genLibCall}, {"sinh", "csinh", genFuncType<Ty::Complex<8>, Ty::Complex<8>>, genLibCall}, @@ -2124,6 +1887,9 @@ lookupIntrinsicHandler(fir::FirOpBuilder &builder, if (isPPCTarget) if 
(const IntrinsicHandler *ppcHandler = findPPCIntrinsicHandler(name)) return std::make_optional<IntrinsicHandlerEntry>(ppcHandler); + // TODO: Look for CUDA intrinsic handlers only if CUDA is enabled. + if (const IntrinsicHandler *cudaHandler = findCUDAIntrinsicHandler(name)) + return std::make_optional<IntrinsicHandlerEntry>(cudaHandler); // Subroutines should have a handler. if (!resultType) return std::nullopt; @@ -3010,157 +2776,6 @@ mlir::Value IntrinsicLibrary::genAtanpi(mlir::Type resultType, return mlir::arith::MulFOp::create(builder, loc, atan, factor); } -static mlir::Value genAtomBinOp(fir::FirOpBuilder &builder, mlir::Location &loc, - mlir::LLVM::AtomicBinOp binOp, mlir::Value arg0, - mlir::Value arg1) { - auto llvmPointerType = mlir::LLVM::LLVMPointerType::get(builder.getContext()); - arg0 = builder.createConvert(loc, llvmPointerType, arg0); - return mlir::LLVM::AtomicRMWOp::create(builder, loc, binOp, arg0, arg1, - mlir::LLVM::AtomicOrdering::seq_cst); -} - -mlir::Value IntrinsicLibrary::genAtomicAdd(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 2); - - mlir::LLVM::AtomicBinOp binOp = - mlir::isa<mlir::IntegerType>(args[1].getType()) - ? mlir::LLVM::AtomicBinOp::add - : mlir::LLVM::AtomicBinOp::fadd; - return genAtomBinOp(builder, loc, binOp, args[0], args[1]); -} - -mlir::Value IntrinsicLibrary::genAtomicSub(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 2); - - mlir::LLVM::AtomicBinOp binOp = - mlir::isa<mlir::IntegerType>(args[1].getType()) - ? 
mlir::LLVM::AtomicBinOp::sub - : mlir::LLVM::AtomicBinOp::fsub; - return genAtomBinOp(builder, loc, binOp, args[0], args[1]); -} - -mlir::Value IntrinsicLibrary::genAtomicAnd(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 2); - assert(mlir::isa<mlir::IntegerType>(args[1].getType())); - - mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::_and; - return genAtomBinOp(builder, loc, binOp, args[0], args[1]); -} - -mlir::Value IntrinsicLibrary::genAtomicOr(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 2); - assert(mlir::isa<mlir::IntegerType>(args[1].getType())); - - mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::_or; - return genAtomBinOp(builder, loc, binOp, args[0], args[1]); -} - -// ATOMICCAS -fir::ExtendedValue -IntrinsicLibrary::genAtomicCas(mlir::Type resultType, - llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 3); - auto successOrdering = mlir::LLVM::AtomicOrdering::acq_rel; - auto failureOrdering = mlir::LLVM::AtomicOrdering::monotonic; - auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(resultType.getContext()); - - mlir::Value arg0 = fir::getBase(args[0]); - mlir::Value arg1 = fir::getBase(args[1]); - mlir::Value arg2 = fir::getBase(args[2]); - - auto bitCastFloat = [&](mlir::Value arg) -> mlir::Value { - if (mlir::isa<mlir::Float32Type>(arg.getType())) - return mlir::LLVM::BitcastOp::create(builder, loc, builder.getI32Type(), - arg); - if (mlir::isa<mlir::Float64Type>(arg.getType())) - return mlir::LLVM::BitcastOp::create(builder, loc, builder.getI64Type(), - arg); - return arg; - }; - - arg1 = bitCastFloat(arg1); - arg2 = bitCastFloat(arg2); - - if (arg1.getType() != arg2.getType()) { - // arg1 and arg2 need to have the same type in AtomicCmpXchgOp. 
- arg2 = builder.createConvert(loc, arg1.getType(), arg2); - } - - auto address = - mlir::UnrealizedConversionCastOp::create(builder, loc, llvmPtrTy, arg0) - .getResult(0); - auto cmpxchg = mlir::LLVM::AtomicCmpXchgOp::create( - builder, loc, address, arg1, arg2, successOrdering, failureOrdering); - return mlir::LLVM::ExtractValueOp::create(builder, loc, cmpxchg, 1); -} - -mlir::Value IntrinsicLibrary::genAtomicDec(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 2); - assert(mlir::isa<mlir::IntegerType>(args[1].getType())); - - mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::udec_wrap; - return genAtomBinOp(builder, loc, binOp, args[0], args[1]); -} - -// ATOMICEXCH -fir::ExtendedValue -IntrinsicLibrary::genAtomicExch(mlir::Type resultType, - llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 2); - mlir::Value arg0 = fir::getBase(args[0]); - mlir::Value arg1 = fir::getBase(args[1]); - assert(arg1.getType().isIntOrFloat()); - - mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::xchg; - return genAtomBinOp(builder, loc, binOp, arg0, arg1); -} - -mlir::Value IntrinsicLibrary::genAtomicInc(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 2); - assert(mlir::isa<mlir::IntegerType>(args[1].getType())); - - mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::uinc_wrap; - return genAtomBinOp(builder, loc, binOp, args[0], args[1]); -} - -mlir::Value IntrinsicLibrary::genAtomicMax(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 2); - - mlir::LLVM::AtomicBinOp binOp = - mlir::isa<mlir::IntegerType>(args[1].getType()) - ? 
mlir::LLVM::AtomicBinOp::max - : mlir::LLVM::AtomicBinOp::fmax; - return genAtomBinOp(builder, loc, binOp, args[0], args[1]); -} - -mlir::Value IntrinsicLibrary::genAtomicMin(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 2); - - mlir::LLVM::AtomicBinOp binOp = - mlir::isa<mlir::IntegerType>(args[1].getType()) - ? mlir::LLVM::AtomicBinOp::min - : mlir::LLVM::AtomicBinOp::fmin; - return genAtomBinOp(builder, loc, binOp, args[0], args[1]); -} - -// ATOMICXOR -fir::ExtendedValue -IntrinsicLibrary::genAtomicXor(mlir::Type resultType, - llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 2); - mlir::Value arg0 = fir::getBase(args[0]); - mlir::Value arg1 = fir::getBase(args[1]); - return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::_xor, arg0, arg1); -} - // ASSOCIATED fir::ExtendedValue IntrinsicLibrary::genAssociated(mlir::Type resultType, @@ -3212,63 +2827,6 @@ IntrinsicLibrary::genAssociated(mlir::Type resultType, return fir::runtime::genAssociated(builder, loc, pointerBox, targetBox); } -static mlir::Value convertPtrToNVVMSpace(fir::FirOpBuilder &builder, - mlir::Location loc, - mlir::Value barrier, - mlir::NVVM::NVVMMemorySpace space) { - mlir::Value llvmPtr = fir::ConvertOp::create( - builder, loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()), - barrier); - mlir::Value addrCast = mlir::LLVM::AddrSpaceCastOp::create( - builder, loc, - mlir::LLVM::LLVMPointerType::get(builder.getContext(), - static_cast<unsigned>(space)), - llvmPtr); - return addrCast; -} - -// BARRIER_ARRIVE (CUDA) -mlir::Value -IntrinsicLibrary::genBarrierArrive(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 1); - mlir::Value barrier = convertPtrToNVVMSpace( - builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared); - return mlir::NVVM::MBarrierArriveSharedOp::create(builder, loc, resultType, - barrier) - .getResult(); -} - -// BARRIER_ARRIBVE_CNT (CUDA) -mlir::Value 
-IntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 2); - mlir::Value barrier = convertPtrToNVVMSpace( - builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared); - mlir::Value token = fir::AllocaOp::create(builder, loc, resultType); - // TODO: the MBarrierArriveExpectTxOp is not taking the state argument and - // currently just the sink symbol `_`. - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive - mlir::NVVM::MBarrierArriveExpectTxOp::create(builder, loc, barrier, args[1], - {}); - return fir::LoadOp::create(builder, loc, token); -} - -// BARRIER_INIT (CUDA) -void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 2); - mlir::Value barrier = convertPtrToNVVMSpace( - builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared); - mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, barrier, - fir::getBase(args[1]), {}); - auto kind = mlir::NVVM::ProxyKindAttr::get( - builder.getContext(), mlir::NVVM::ProxyKind::async_shared); - auto space = mlir::NVVM::SharedSpaceAttr::get( - builder.getContext(), mlir::NVVM::SharedSpace::shared_cta); - mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space); -} - // BESSEL_JN fir::ExtendedValue IntrinsicLibrary::genBesselJn(mlir::Type resultType, @@ -3516,11 +3074,23 @@ static mlir::Value getAddrFromBox(fir::FirOpBuilder &builder, return addr; } +static void clocDeviceArgRewrite(fir::ExtendedValue arg) { + // Special case for device address in c_loc. 
+ if (auto emboxOp = mlir::dyn_cast_or_null<fir::EmboxOp>( + fir::getBase(arg).getDefiningOp())) + if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>( + emboxOp.getMemref().getDefiningOp())) + if (declareOp.getDataAttr() && + declareOp.getDataAttr() == cuf::DataAttribute::Device) + emboxOp.getMemrefMutable().assign(declareOp.getMemref()); +} + static fir::ExtendedValue genCLocOrCFunLoc(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args, bool isFunc = false, bool isDevLoc = false) { assert(args.size() == 1); + clocDeviceArgRewrite(args[0]); mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); mlir::Value resAddr; if (isDevLoc) @@ -3686,6 +3256,99 @@ void IntrinsicLibrary::genCFProcPointer( fir::StoreOp::create(builder, loc, cptrBox, fptr); } +// C_F_STRPOINTER +void IntrinsicLibrary::genCFStrPointer( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 3); + + mlir::Value cStrAddr; + mlir::Value strLen; + + const mlir::Value firstArg = fir::getBase(args[0]); + const mlir::Type firstArgType = fir::unwrapRefType(firstArg.getType()); + const bool isCstrptr = mlir::isa<fir::RecordType>(firstArgType); + + if (isCstrptr) { + // CSTRPTR form: Extract address from C_PTR + cStrAddr = fir::factory::genCPtrOrCFunptrValue(builder, loc, firstArg); + + assert(isStaticallyPresent(args[2])); + mlir::Value nchars = fir::getBase(args[2]); + if (fir::isa_ref_type(nchars.getType())) { + strLen = fir::LoadOp::create(builder, loc, nchars); + } else { + strLen = nchars; + } + } else { + // CSTRARRAY form: Get address from CHARACTER array + if (const auto boxCharTy = + mlir::dyn_cast<fir::BoxCharType>(firstArg.getType())) { + const auto charTy = mlir::cast<fir::CharacterType>(boxCharTy.getEleTy()); + const auto addrTy = builder.getRefType(charTy); + auto unboxed = fir::UnboxCharOp::create( + builder, loc, mlir::TypeRange{addrTy, builder.getIndexType()}, + firstArg); + cStrAddr = 
unboxed.getResult(0); + } else if (mlir::isa<fir::BoxType>(firstArg.getType())) { + cStrAddr = fir::BoxAddrOp::create(builder, loc, firstArg); + } else { + cStrAddr = firstArg; + } + + // Handle optional NCHARS argument + if (isStaticallyPresent(args[2])) { + mlir::Value nchars = fir::getBase(args[2]); + if (fir::isa_ref_type(nchars.getType())) { + strLen = fir::LoadOp::create(builder, loc, nchars); + } else { + strLen = nchars; + } + } else { + const mlir::Type i8PtrTy = builder.getRefType(builder.getIntegerType(8)); + const mlir::Value strPtr = builder.createConvert(loc, i8PtrTy, cStrAddr); + + const mlir::Type i64Ty = builder.getIntegerType(64); + const mlir::FunctionType strlenType = + mlir::FunctionType::get(builder.getContext(), {i8PtrTy}, {i64Ty}); + + mlir::func::FuncOp strlenFunc = builder.getNamedFunction("strlen"); + if (!strlenFunc) { + strlenFunc = builder.createFunction(loc, "strlen", strlenType); + strlenFunc->setAttr( + fir::getSymbolAttrName(), + mlir::StringAttr::get(builder.getContext(), "strlen")); + } + auto call = fir::CallOp::create(builder, loc, strlenFunc, {strPtr}); + strLen = call.getResult(0); + } + } + + // Handle FSTRPTR (second argument) + const auto *fStrPtr = args[1].getBoxOf<fir::MutableBoxValue>(); + assert(fStrPtr && "FSTRPTR must be a pointer"); + + const mlir::Value lenIdx = + builder.createConvert(loc, builder.getIndexType(), strLen); + + const mlir::Type charPtrType = fir::PointerType::get(fir::CharacterType::get( + builder.getContext(), 1, fir::CharacterType::unknownLen())); + const mlir::Value charPtr = builder.createConvert(loc, charPtrType, cStrAddr); + + const fir::CharBoxValue charBox{charPtr, lenIdx}; + fir::factory::associateMutableBox(builder, loc, *fStrPtr, charBox, + /*lbounds=*/mlir::ValueRange{}); + + // CUDA synchronization if needed + if (auto declare = mlir::dyn_cast_or_null<hlfir::DeclareOp>( + fStrPtr->getAddr().getDefiningOp())) + if (declare.getMemref().getDefiningOp() && + 
mlir::isa<fir::AddrOfOp>(declare.getMemref().getDefiningOp())) + if (cuf::isRegisteredDeviceAttr(declare.getDataAttr()) && + !cuf::isCUDADeviceContext(builder.getRegion())) + fir::runtime::cuda::genSyncGlobalDescriptor(builder, loc, + declare.getMemref()); +} + // C_FUNLOC fir::ExtendedValue IntrinsicLibrary::genCFunLoc(mlir::Type resultType, @@ -3990,30 +3653,6 @@ IntrinsicLibrary::genCshift(mlir::Type resultType, return readAndAddCleanUp(resultMutableBox, resultType, "CSHIFT"); } -// __LDCA, __LDCS, __LDLU, __LDCV -template <const char *fctName, int extent> -fir::ExtendedValue -IntrinsicLibrary::genCUDALDXXFunc(mlir::Type resultType, - llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 1); - mlir::Type resTy = fir::SequenceType::get(extent, resultType); - mlir::Value arg = fir::getBase(args[0]); - mlir::Value res = fir::AllocaOp::create(builder, loc, resTy); - if (mlir::isa<fir::BaseBoxType>(arg.getType())) - arg = fir::BoxAddrOp::create(builder, loc, arg); - mlir::Type refResTy = fir::ReferenceType::get(resTy); - mlir::FunctionType ftype = - mlir::FunctionType::get(arg.getContext(), {refResTy, refResTy}, {}); - auto funcOp = builder.createFunction(loc, fctName, ftype); - llvm::SmallVector<mlir::Value> funcArgs; - funcArgs.push_back(res); - funcArgs.push_back(arg); - fir::CallOp::create(builder, loc, funcOp, funcArgs); - mlir::Value ext = - builder.createIntegerConstant(loc, builder.getIndexType(), extent); - return fir::ArrayBoxValue(res, {ext}); -} - // DATE_AND_TIME void IntrinsicLibrary::genDateAndTime(llvm::ArrayRef<fir::ExtendedValue> args) { assert(args.size() == 4 && "date_and_time has 4 args"); @@ -4317,9 +3956,6 @@ void IntrinsicLibrary::genExit(llvm::ArrayRef<fir::ExtendedValue> args) { EXIT_SUCCESS) : fir::getBase(args[0]); - assert(status.getType() == builder.getDefaultIntegerType() && - "STATUS parameter must be an INTEGER of default kind"); - fir::runtime::genExit(builder, loc, status); } @@ -4346,15 +3982,30 @@ 
IntrinsicLibrary::genExtendsTypeOf(mlir::Type resultType, fir::getBase(args[1]))); } -// FENCE_PROXY_ASYNC (CUDA) -void IntrinsicLibrary::genFenceProxyAsync( - llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 0); - auto kind = mlir::NVVM::ProxyKindAttr::get( - builder.getContext(), mlir::NVVM::ProxyKind::async_shared); - auto space = mlir::NVVM::SharedSpaceAttr::get( - builder.getContext(), mlir::NVVM::SharedSpace::shared_cta); - mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space); +// F_C_STRING +fir::ExtendedValue +IntrinsicLibrary::genFCString(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() >= 1 && args.size() <= 2); + + mlir::Value string = builder.createBox(loc, args[0]); + + // Handle optional ASIS argument + mlir::Value asis = isStaticallyAbsent(args, 1) + ? builder.createBool(loc, false) + : fir::getBase(args[1]); + + // Create mutable fir.box to be passed to the runtime for the result. + fir::MutableBoxValue resultMutableBox = + fir::factory::createTempMutableBox(builder, loc, resultType); + mlir::Value resultIrBox = + fir::factory::getMutableIRBox(builder, loc, resultMutableBox); + + fir::runtime::genFCString(builder, loc, resultIrBox, string, asis); + + // Read result from mutable fir.box and add it to the list of temps to be + // finalized by the StatementContext. + return readAndAddCleanUp(resultMutableBox, resultType, "F_C_STRING"); } // FINDLOC @@ -4439,6 +4090,40 @@ mlir::Value IntrinsicLibrary::genFloor(mlir::Type resultType, return builder.createConvert(loc, resultType, floor); } +// FLUSH +void IntrinsicLibrary::genFlush(llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 1); + + mlir::Value unit; + if (isStaticallyAbsent(args[0])) + // Give a sentinal value of `-1` on the `()` case. 
+ unit = builder.createIntegerConstant(loc, builder.getI32Type(), -1); + else { + unit = fir::getBase(args[0]); + if (isOptional(unit)) { + mlir::Value isPresent = + fir::IsPresentOp::create(builder, loc, builder.getI1Type(), unit); + unit = builder + .genIfOp(loc, builder.getI32Type(), isPresent, + /*withElseRegion=*/true) + .genThen([&]() { + mlir::Value loaded = fir::LoadOp::create(builder, loc, unit); + fir::ResultOp::create(builder, loc, loaded); + }) + .genElse([&]() { + mlir::Value negOne = builder.createIntegerConstant( + loc, builder.getI32Type(), -1); + fir::ResultOp::create(builder, loc, negOne); + }) + .getResults()[0]; + } else { + unit = fir::LoadOp::create(builder, loc, unit); + } + } + + fir::runtime::genFlush(builder, loc, unit); +} + // FRACTION mlir::Value IntrinsicLibrary::genFraction(mlir::Type resultType, llvm::ArrayRef<mlir::Value> args) { @@ -4518,6 +4203,15 @@ IntrinsicLibrary::genFtell(std::optional<mlir::Type> resultType, } } +// GET_TEAM +mlir::Value IntrinsicLibrary::genGetTeam(mlir::Type resultType, + llvm::ArrayRef<mlir::Value> args) { + converter->checkCoarrayEnabled(); + assert(args.size() == 1); + return mif::GetTeamOp::create(builder, loc, fir::BoxType::get(resultType), + /*level*/ args[0]); +} + // GETCWD fir::ExtendedValue IntrinsicLibrary::genGetCwd(std::optional<mlir::Type> resultType, @@ -6603,6 +6297,20 @@ IntrinsicLibrary::genIparity(mlir::Type resultType, "IPARITY", resultType, args); } +// IRAND +fir::ExtendedValue +IntrinsicLibrary::genIrand(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 1); + mlir::Value i = + isStaticallyPresent(args[0]) + ? 
fir::getBase(args[0]) + : fir::AbsentOp::create(builder, loc, + builder.getRefType(builder.getI32Type())) + .getResult(); + return fir::runtime::genIrand(builder, loc, i); +} + // IS_CONTIGUOUS fir::ExtendedValue IntrinsicLibrary::genIsContiguous(mlir::Type resultType, @@ -6786,12 +6494,6 @@ IntrinsicLibrary::genCharacterCompare(mlir::Type resultType, fir::getBase(args[1]), fir::getLen(args[1])); } -static bool isOptional(mlir::Value value) { - auto varIface = mlir::dyn_cast_or_null<fir::FortranVariableOpInterface>( - value.getDefiningOp()); - return varIface && varIface.isOptional(); -} - // LOC fir::ExtendedValue IntrinsicLibrary::genLoc(mlir::Type resultType, @@ -6867,67 +6569,6 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType, return result; } -// MATCH_ALL_SYNC -mlir::Value -IntrinsicLibrary::genMatchAllSync(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 3); - bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32(); - - mlir::Type i1Ty = builder.getI1Type(); - mlir::MLIRContext *context = builder.getContext(); - - mlir::Value arg1 = args[1]; - if (arg1.getType().isF32() || arg1.getType().isF64()) - arg1 = fir::ConvertOp::create( - builder, loc, is32 ? 
builder.getI32Type() : builder.getI64Type(), arg1); - - mlir::Type retTy = - mlir::LLVM::LLVMStructType::getLiteral(context, {resultType, i1Ty}); - auto match = - mlir::NVVM::MatchSyncOp::create(builder, loc, retTy, args[0], arg1, - mlir::NVVM::MatchSyncKind::all) - .getResult(); - auto value = mlir::LLVM::ExtractValueOp::create(builder, loc, match, 0); - auto pred = mlir::LLVM::ExtractValueOp::create(builder, loc, match, 1); - auto conv = mlir::LLVM::ZExtOp::create(builder, loc, resultType, pred); - fir::StoreOp::create(builder, loc, conv, args[2]); - return value; -} - -// ALL_SYNC, ANY_SYNC, BALLOT_SYNC -template <mlir::NVVM::VoteSyncKind kind> -mlir::Value IntrinsicLibrary::genVoteSync(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 2); - mlir::Value arg1 = - fir::ConvertOp::create(builder, loc, builder.getI1Type(), args[1]); - mlir::Type resTy = kind == mlir::NVVM::VoteSyncKind::ballot - ? builder.getI32Type() - : builder.getI1Type(); - auto voteRes = - mlir::NVVM::VoteSyncOp::create(builder, loc, resTy, args[0], arg1, kind) - .getResult(); - return fir::ConvertOp::create(builder, loc, resultType, voteRes); -} - -// MATCH_ANY_SYNC -mlir::Value -IntrinsicLibrary::genMatchAnySync(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 2); - bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32(); - - mlir::Value arg1 = args[1]; - if (arg1.getType().isF32() || arg1.getType().isF64()) - arg1 = fir::ConvertOp::create( - builder, loc, is32 ? 
builder.getI32Type() : builder.getI64Type(), arg1); - - return mlir::NVVM::MatchSyncOp::create(builder, loc, resultType, args[0], - arg1, mlir::NVVM::MatchSyncKind::any) - .getResult(); -} - // MATMUL fir::ExtendedValue IntrinsicLibrary::genMatmul(mlir::Type resultType, @@ -7075,11 +6716,9 @@ static mlir::Value genFastMod(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value IntrinsicLibrary::genMod(mlir::Type resultType, llvm::ArrayRef<mlir::Value> args) { auto mod = builder.getModule(); - bool dontUseFastRealMod = false; - bool canUseApprox = mlir::arith::bitEnumContainsAny( - builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn); - if (auto attr = mod->getAttrOfType<mlir::BoolAttr>("fir.no_fast_real_mod")) - dontUseFastRealMod = attr.getValue(); + bool useFastRealMod = false; + if (auto attr = mod->getAttrOfType<mlir::BoolAttr>("fir.fast_real_mod")) + useFastRealMod = attr.getValue(); assert(args.size() == 2); if (resultType.isUnsignedInteger()) { @@ -7092,7 +6731,7 @@ mlir::Value IntrinsicLibrary::genMod(mlir::Type resultType, if (mlir::isa<mlir::IntegerType>(resultType)) return mlir::arith::RemSIOp::create(builder, loc, args[0], args[1]); - if (resultType.isFloat() && canUseApprox && !dontUseFastRealMod) { + if (resultType.isFloat() && useFastRealMod) { // Treat MOD as an approximate function and code-gen inline code // instead of calling into the Fortran runtime library. 
return builder.createConvert(loc, resultType, @@ -7545,14 +7184,6 @@ IntrinsicLibrary::genNumImages(mlir::Type resultType, return mif::NumImagesOp::create(builder, loc).getResult(); } -// CLOCK, CLOCK64, GLOBALTIMER -template <typename OpTy> -mlir::Value IntrinsicLibrary::genNVVMTime(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 0 && "expect no arguments"); - return OpTy::create(builder, loc, resultType).getResult(); -} - // PACK fir::ExtendedValue IntrinsicLibrary::genPack(mlir::Type resultType, @@ -7706,6 +7337,19 @@ IntrinsicLibrary::genPutenv(std::optional<mlir::Type> resultType, return {}; } +// RAND +fir::ExtendedValue +IntrinsicLibrary::genRand(mlir::Type, llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 1); + mlir::Value i = + isStaticallyPresent(args[0]) + ? fir::getBase(args[0]) + : fir::AbsentOp::create(builder, loc, + builder.getRefType(builder.getI32Type())) + .getResult(); + return fir::runtime::genRand(builder, loc, i); +} + // RANDOM_INIT void IntrinsicLibrary::genRandomInit(llvm::ArrayRef<fir::ExtendedValue> args) { assert(args.size() == 2); @@ -8371,6 +8015,47 @@ mlir::Value IntrinsicLibrary::genShiftA(mlir::Type resultType, return result; } +void IntrinsicLibrary::genShowDescriptor( + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 1 && "expected single argument for show_descriptor"); + const mlir::Value arg = fir::getBase(args[0]); + + // Use consistent !fir.ref<!fir.box<none>> argument type + auto targetType = fir::BoxType::get(builder.getNoneType()); + auto targetRefType = fir::ReferenceType::get(targetType); + + mlir::Value descrAddr = nullptr; + if (fir::isBoxAddress(arg.getType())) { + // If it's already a reference to a box, convert it to correct type and + // pass it directly + descrAddr = builder.createConvert(loc, targetRefType, arg); + } else { + // At this point, arg is either SSA descriptor or a non-descriptor entity. 
+ // If necessary, wrap non-descriptor entity in a descriptor. + mlir::Value descriptor = nullptr; + if (fir::isa_box_type(arg.getType())) { + descriptor = arg; + } else if (fir::isa_ref_type(arg.getType())) { + // Note: here use full extended value args[0] + descriptor = builder.createBox(loc, args[0]); + } else { + // arg is a value (e.g. constant), spill it to a temporary + // because createBox expects a memory reference. + mlir::Value temp = builder.createTemporary(loc, arg.getType()); + builder.createStoreWithConvert(loc, arg, temp); + + // Note: here use full extended value args[0] + descriptor = builder.createBox(loc, fir::substBase(args[0], temp)); + } + + // Spill it to the stack + descrAddr = builder.createTemporary(loc, targetType); + builder.createStoreWithConvert(loc, descriptor, descrAddr); + } + + fir::runtime::genShowDescriptor(builder, loc, descrAddr); +} + // SIGNAL void IntrinsicLibrary::genSignalSubroutine( llvm::ArrayRef<fir::ExtendedValue> args) { @@ -8527,90 +8212,16 @@ mlir::Value IntrinsicLibrary::genTanpi(mlir::Type resultType, return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg}); } -// THIS_GRID -mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 0); - auto recTy = mlir::cast<fir::RecordType>(resultType); - assert(recTy && "RecordType expepected"); - mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); - mlir::Type i32Ty = builder.getI32Type(); +// TEAM_NUMBER +fir::ExtendedValue +IntrinsicLibrary::genTeamNumber(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + converter->checkCoarrayEnabled(); + assert(args.size() == 1); - mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty); - mlir::Value threadIdY = mlir::NVVM::ThreadIdYOp::create(builder, loc, i32Ty); - mlir::Value threadIdZ = mlir::NVVM::ThreadIdZOp::create(builder, loc, i32Ty); - - mlir::Value blockIdX = 
mlir::NVVM::BlockIdXOp::create(builder, loc, i32Ty); - mlir::Value blockIdY = mlir::NVVM::BlockIdYOp::create(builder, loc, i32Ty); - mlir::Value blockIdZ = mlir::NVVM::BlockIdZOp::create(builder, loc, i32Ty); - - mlir::Value blockDimX = mlir::NVVM::BlockDimXOp::create(builder, loc, i32Ty); - mlir::Value blockDimY = mlir::NVVM::BlockDimYOp::create(builder, loc, i32Ty); - mlir::Value blockDimZ = mlir::NVVM::BlockDimZOp::create(builder, loc, i32Ty); - mlir::Value gridDimX = mlir::NVVM::GridDimXOp::create(builder, loc, i32Ty); - mlir::Value gridDimY = mlir::NVVM::GridDimYOp::create(builder, loc, i32Ty); - mlir::Value gridDimZ = mlir::NVVM::GridDimZOp::create(builder, loc, i32Ty); - - // this_grid.size = ((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y)) * - // (blockDim.x * gridDim.x); - mlir::Value resZ = - mlir::arith::MulIOp::create(builder, loc, blockDimZ, gridDimZ); - mlir::Value resY = - mlir::arith::MulIOp::create(builder, loc, blockDimY, gridDimY); - mlir::Value resX = - mlir::arith::MulIOp::create(builder, loc, blockDimX, gridDimX); - mlir::Value resZY = mlir::arith::MulIOp::create(builder, loc, resZ, resY); - mlir::Value size = mlir::arith::MulIOp::create(builder, loc, resZY, resX); - - // tmp = ((blockIdx.z * gridDim.y * gridDim.x) + (blockIdx.y * gridDim.x)) + - // blockIdx.x; - // this_group.rank = tmp * ((blockDim.x * blockDim.y) * blockDim.z) + - // ((threadIdx.z * blockDim.y) * blockDim.x) + - // (threadIdx.y * blockDim.x) + threadIdx.x + 1; - mlir::Value r1 = - mlir::arith::MulIOp::create(builder, loc, blockIdZ, gridDimY); - mlir::Value r2 = mlir::arith::MulIOp::create(builder, loc, r1, gridDimX); - mlir::Value r3 = - mlir::arith::MulIOp::create(builder, loc, blockIdY, gridDimX); - mlir::Value r2r3 = mlir::arith::AddIOp::create(builder, loc, r2, r3); - mlir::Value tmp = mlir::arith::AddIOp::create(builder, loc, r2r3, blockIdX); - - mlir::Value bXbY = - mlir::arith::MulIOp::create(builder, loc, blockDimX, blockDimY); - mlir::Value bXbYbZ = - 
mlir::arith::MulIOp::create(builder, loc, bXbY, blockDimZ); - mlir::Value tZbY = - mlir::arith::MulIOp::create(builder, loc, threadIdZ, blockDimY); - mlir::Value tZbYbX = - mlir::arith::MulIOp::create(builder, loc, tZbY, blockDimX); - mlir::Value tYbX = - mlir::arith::MulIOp::create(builder, loc, threadIdY, blockDimX); - mlir::Value rank = mlir::arith::MulIOp::create(builder, loc, tmp, bXbYbZ); - rank = mlir::arith::AddIOp::create(builder, loc, rank, tZbYbX); - rank = mlir::arith::AddIOp::create(builder, loc, rank, tYbX); - rank = mlir::arith::AddIOp::create(builder, loc, rank, threadIdX); - mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1); - rank = mlir::arith::AddIOp::create(builder, loc, rank, one); - - auto sizeFieldName = recTy.getTypeList()[1].first; - mlir::Type sizeFieldTy = recTy.getTypeList()[1].second; - mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext()); - mlir::Value sizeFieldIndex = fir::FieldIndexOp::create( - builder, loc, fieldIndexType, sizeFieldName, recTy, - /*typeParams=*/mlir::ValueRange{}); - mlir::Value sizeCoord = fir::CoordinateOp::create( - builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex); - fir::StoreOp::create(builder, loc, size, sizeCoord); - - auto rankFieldName = recTy.getTypeList()[2].first; - mlir::Type rankFieldTy = recTy.getTypeList()[2].second; - mlir::Value rankFieldIndex = fir::FieldIndexOp::create( - builder, loc, fieldIndexType, rankFieldName, recTy, - /*typeParams=*/mlir::ValueRange{}); - mlir::Value rankCoord = fir::CoordinateOp::create( - builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex); - fir::StoreOp::create(builder, loc, rank, rankCoord); - return res; + mlir::Value res = mif::TeamNumberOp::create(builder, loc, + /*team*/ fir::getBase(args[0])); + return builder.createConvert(loc, resultType, res); } // THIS_IMAGE @@ -8628,99 +8239,6 @@ IntrinsicLibrary::genThisImage(mlir::Type resultType, return builder.createConvert(loc, resultType, res); } -// 
THIS_THREAD_BLOCK -mlir::Value -IntrinsicLibrary::genThisThreadBlock(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 0); - auto recTy = mlir::cast<fir::RecordType>(resultType); - assert(recTy && "RecordType expepected"); - mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); - mlir::Type i32Ty = builder.getI32Type(); - - // this_thread_block%size = blockDim.z * blockDim.y * blockDim.x; - mlir::Value blockDimX = mlir::NVVM::BlockDimXOp::create(builder, loc, i32Ty); - mlir::Value blockDimY = mlir::NVVM::BlockDimYOp::create(builder, loc, i32Ty); - mlir::Value blockDimZ = mlir::NVVM::BlockDimZOp::create(builder, loc, i32Ty); - mlir::Value size = - mlir::arith::MulIOp::create(builder, loc, blockDimZ, blockDimY); - size = mlir::arith::MulIOp::create(builder, loc, size, blockDimX); - - // this_thread_block%rank = ((threadIdx.z * blockDim.y) * blockDim.x) + - // (threadIdx.y * blockDim.x) + threadIdx.x + 1; - mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty); - mlir::Value threadIdY = mlir::NVVM::ThreadIdYOp::create(builder, loc, i32Ty); - mlir::Value threadIdZ = mlir::NVVM::ThreadIdZOp::create(builder, loc, i32Ty); - mlir::Value r1 = - mlir::arith::MulIOp::create(builder, loc, threadIdZ, blockDimY); - mlir::Value r2 = mlir::arith::MulIOp::create(builder, loc, r1, blockDimX); - mlir::Value r3 = - mlir::arith::MulIOp::create(builder, loc, threadIdY, blockDimX); - mlir::Value r2r3 = mlir::arith::AddIOp::create(builder, loc, r2, r3); - mlir::Value rank = mlir::arith::AddIOp::create(builder, loc, r2r3, threadIdX); - mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1); - rank = mlir::arith::AddIOp::create(builder, loc, rank, one); - - auto sizeFieldName = recTy.getTypeList()[1].first; - mlir::Type sizeFieldTy = recTy.getTypeList()[1].second; - mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext()); - mlir::Value sizeFieldIndex = fir::FieldIndexOp::create( - builder, loc, 
fieldIndexType, sizeFieldName, recTy, - /*typeParams=*/mlir::ValueRange{}); - mlir::Value sizeCoord = fir::CoordinateOp::create( - builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex); - fir::StoreOp::create(builder, loc, size, sizeCoord); - - auto rankFieldName = recTy.getTypeList()[2].first; - mlir::Type rankFieldTy = recTy.getTypeList()[2].second; - mlir::Value rankFieldIndex = fir::FieldIndexOp::create( - builder, loc, fieldIndexType, rankFieldName, recTy, - /*typeParams=*/mlir::ValueRange{}); - mlir::Value rankCoord = fir::CoordinateOp::create( - builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex); - fir::StoreOp::create(builder, loc, rank, rankCoord); - return res; -} - -// THIS_WARP -mlir::Value IntrinsicLibrary::genThisWarp(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - assert(args.size() == 0); - auto recTy = mlir::cast<fir::RecordType>(resultType); - assert(recTy && "RecordType expepected"); - mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); - mlir::Type i32Ty = builder.getI32Type(); - - // coalesced_group%size = 32 - mlir::Value size = builder.createIntegerConstant(loc, i32Ty, 32); - auto sizeFieldName = recTy.getTypeList()[1].first; - mlir::Type sizeFieldTy = recTy.getTypeList()[1].second; - mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext()); - mlir::Value sizeFieldIndex = fir::FieldIndexOp::create( - builder, loc, fieldIndexType, sizeFieldName, recTy, - /*typeParams=*/mlir::ValueRange{}); - mlir::Value sizeCoord = fir::CoordinateOp::create( - builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex); - fir::StoreOp::create(builder, loc, size, sizeCoord); - - // coalesced_group%rank = threadIdx.x & 31 + 1 - mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty); - mlir::Value mask = builder.createIntegerConstant(loc, i32Ty, 31); - mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1); - mlir::Value masked = - 
mlir::arith::AndIOp::create(builder, loc, threadIdX, mask); - mlir::Value rank = mlir::arith::AddIOp::create(builder, loc, masked, one); - auto rankFieldName = recTy.getTypeList()[2].first; - mlir::Type rankFieldTy = recTy.getTypeList()[2].second; - mlir::Value rankFieldIndex = fir::FieldIndexOp::create( - builder, loc, fieldIndexType, rankFieldName, recTy, - /*typeParams=*/mlir::ValueRange{}); - mlir::Value rankCoord = fir::CoordinateOp::create( - builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex); - fir::StoreOp::create(builder, loc, rank, rankCoord); - return res; -} - // TRAILZ mlir::Value IntrinsicLibrary::genTrailz(mlir::Type resultType, llvm::ArrayRef<mlir::Value> args) { @@ -8942,59 +8460,6 @@ IntrinsicLibrary::genSum(mlir::Type resultType, resultType, args); } -// SYNCTHREADS -void IntrinsicLibrary::genSyncThreads(llvm::ArrayRef<fir::ExtendedValue> args) { - mlir::NVVM::Barrier0Op::create(builder, loc); -} - -// SYNCTHREADS_AND -mlir::Value -IntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.and"; - mlir::MLIRContext *context = builder.getContext(); - mlir::FunctionType ftype = - mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - auto funcOp = builder.createFunction(loc, funcName, ftype); - return fir::CallOp::create(builder, loc, funcOp, args).getResult(0); -} - -// SYNCTHREADS_COUNT -mlir::Value -IntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.popc"; - mlir::MLIRContext *context = builder.getContext(); - mlir::FunctionType ftype = - mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - auto funcOp = builder.createFunction(loc, funcName, ftype); - return fir::CallOp::create(builder, loc, funcOp, args).getResult(0); -} - -// SYNCTHREADS_OR -mlir::Value 
-IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType, - llvm::ArrayRef<mlir::Value> args) { - constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.or"; - mlir::MLIRContext *context = builder.getContext(); - mlir::FunctionType ftype = - mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - auto funcOp = builder.createFunction(loc, funcName, ftype); - return fir::CallOp::create(builder, loc, funcOp, args).getResult(0); -} - -// SYNCWARP -void IntrinsicLibrary::genSyncWarp(llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 1); - constexpr llvm::StringLiteral funcName = "llvm.nvvm.bar.warp.sync"; - mlir::Value mask = fir::getBase(args[0]); - mlir::FunctionType funcType = - mlir::FunctionType::get(builder.getContext(), {mask.getType()}, {}); - auto funcOp = builder.createFunction(loc, funcName, funcType); - llvm::SmallVector<mlir::Value> argsList{mask}; - fir::CallOp::create(builder, loc, funcOp, argsList); -} - // SYSTEM fir::ExtendedValue IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType, @@ -9126,38 +8591,6 @@ IntrinsicLibrary::genTranspose(mlir::Type resultType, return readAndAddCleanUp(resultMutableBox, resultType, "TRANSPOSE"); } -// THREADFENCE -void IntrinsicLibrary::genThreadFence(llvm::ArrayRef<fir::ExtendedValue> args) { - constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.gl"; - mlir::FunctionType funcType = - mlir::FunctionType::get(builder.getContext(), {}, {}); - auto funcOp = builder.createFunction(loc, funcName, funcType); - llvm::SmallVector<mlir::Value> noArgs; - fir::CallOp::create(builder, loc, funcOp, noArgs); -} - -// THREADFENCE_BLOCK -void IntrinsicLibrary::genThreadFenceBlock( - llvm::ArrayRef<fir::ExtendedValue> args) { - constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.cta"; - mlir::FunctionType funcType = - mlir::FunctionType::get(builder.getContext(), {}, {}); - auto funcOp = builder.createFunction(loc, funcName, funcType); - llvm::SmallVector<mlir::Value> 
noArgs; - fir::CallOp::create(builder, loc, funcOp, noArgs); -} - -// THREADFENCE_SYSTEM -void IntrinsicLibrary::genThreadFenceSystem( - llvm::ArrayRef<fir::ExtendedValue> args) { - constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.sys"; - mlir::FunctionType funcType = - mlir::FunctionType::get(builder.getContext(), {}, {}); - auto funcOp = builder.createFunction(loc, funcName, funcType); - llvm::SmallVector<mlir::Value> noArgs; - fir::CallOp::create(builder, loc, funcOp, noArgs); -} - // TIME mlir::Value IntrinsicLibrary::genTime(mlir::Type resultType, llvm::ArrayRef<mlir::Value> args) { @@ -9166,46 +8599,6 @@ mlir::Value IntrinsicLibrary::genTime(mlir::Type resultType, fir::runtime::genTime(builder, loc)); } -// TMA_BULK_COMMIT_GROUP (CUDA) -void IntrinsicLibrary::genTMABulkCommitGroup( - llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 0); - mlir::NVVM::CpAsyncBulkCommitGroupOp::create(builder, loc); -} - -// TMA_BULK_G2S (CUDA) -void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 4); - mlir::Value barrier = convertPtrToNVVMSpace( - builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared); - mlir::Value dst = - convertPtrToNVVMSpace(builder, loc, fir::getBase(args[2]), - mlir::NVVM::NVVMMemorySpace::SharedCluster); - mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]), - mlir::NVVM::NVVMMemorySpace::Global); - mlir::NVVM::CpAsyncBulkGlobalToSharedClusterOp::create( - builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {}); -} - -// TMA_BULK_S2G (CUDA) -void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 3); - mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[0]), - mlir::NVVM::NVVMMemorySpace::Shared); - mlir::Value dst = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]), - mlir::NVVM::NVVMMemorySpace::Global); - 
mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create( - builder, loc, dst, src, fir::getBase(args[2]), {}, {}); -} - -// TMA_BULK_WAIT_GROUP (CUDA) -void IntrinsicLibrary::genTMABulkWaitGroup( - llvm::ArrayRef<fir::ExtendedValue> args) { - assert(args.size() == 0); - auto group = builder.getIntegerAttr(builder.getI32Type(), 0); - mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, group, {}); -} - // TRIM fir::ExtendedValue IntrinsicLibrary::genTrim(mlir::Type resultType, @@ -9620,6 +9013,9 @@ getIntrinsicArgumentLowering(llvm::StringRef specificName) { if (const IntrinsicHandler *ppcHandler = findPPCIntrinsicHandler(name)) if (!ppcHandler->argLoweringRules.hasDefaultRules()) return &ppcHandler->argLoweringRules; + if (const IntrinsicHandler *cudaHandler = findCUDAIntrinsicHandler(name)) + if (!cudaHandler->argLoweringRules.hasDefaultRules()) + return &cudaHandler->argLoweringRules; return nullptr; } diff --git a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp index 265e268..5a4e517 100644 --- a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp @@ -15,6 +15,7 @@ #include "flang/Optimizer/Builder/PPCIntrinsicCall.h" #include "flang/Evaluate/common.h" +#include "flang/Lower/AbstractConverter.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/MutableBox.h" #include "mlir/Dialect/Index/IR/IndexOps.h" diff --git a/flang/lib/Optimizer/Builder/Runtime/Allocatable.cpp b/flang/lib/Optimizer/Builder/Runtime/Allocatable.cpp index cc9f828..89f5f45 100644 --- a/flang/lib/Optimizer/Builder/Runtime/Allocatable.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Allocatable.cpp @@ -86,8 +86,9 @@ void fir::runtime::genAllocatableAllocate(fir::FirOpBuilder &builder, mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType()); errMsg = fir::AbsentOp::create(builder, loc, boxNoneTy).getResult(); } - llvm::SmallVector<mlir::Value> args{ - 
fir::runtime::createArguments(builder, loc, fTy, desc, asyncObject, - hasStat, errMsg, sourceFile, sourceLine)}; + mlir::Value deviceInit = builder.createBool(loc, false); + llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( + builder, loc, fTy, desc, asyncObject, hasStat, errMsg, sourceFile, + sourceLine, deviceInit)}; fir::CallOp::create(builder, loc, func, args); } diff --git a/flang/lib/Optimizer/Builder/Runtime/Character.cpp b/flang/lib/Optimizer/Builder/Runtime/Character.cpp index 540ecba..e297125 100644 --- a/flang/lib/Optimizer/Builder/Runtime/Character.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Character.cpp @@ -94,27 +94,34 @@ fir::runtime::genCharCompare(fir::FirOpBuilder &builder, mlir::Location loc, mlir::arith::CmpIPredicate cmp, mlir::Value lhsBuff, mlir::Value lhsLen, mlir::Value rhsBuff, mlir::Value rhsLen) { - mlir::func::FuncOp beginFunc; - switch (discoverKind(lhsBuff.getType())) { + int lhsKind = discoverKind(lhsBuff.getType()); + int rhsKind = discoverKind(rhsBuff.getType()); + if (lhsKind != rhsKind) { + fir::emitFatalError(loc, "runtime does not support comparison of different " + "CHARACTER kind values"); + } + mlir::func::FuncOp func; + switch (lhsKind) { case 1: - beginFunc = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar1)>( + func = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar1)>( loc, builder); break; case 2: - beginFunc = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar2)>( + func = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar2)>( loc, builder); break; case 4: - beginFunc = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar4)>( + func = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar4)>( loc, builder); break; default: - llvm_unreachable("runtime does not support CHARACTER KIND"); + fir::emitFatalError( + loc, "unsupported CHARACTER kind value. 
Runtime expects 1, 2, or 4."); } - auto fTy = beginFunc.getFunctionType(); + auto fTy = func.getFunctionType(); auto args = fir::runtime::createArguments(builder, loc, fTy, lhsBuff, rhsBuff, lhsLen, rhsLen); - auto tri = fir::CallOp::create(builder, loc, beginFunc, args).getResult(0); + auto tri = fir::CallOp::create(builder, loc, func, args).getResult(0); auto zero = builder.createIntegerConstant(loc, tri.getType(), 0); return mlir::arith::CmpIOp::create(builder, loc, cmp, tri, zero); } @@ -140,6 +147,19 @@ mlir::Value fir::runtime::genCharCompare(fir::FirOpBuilder &builder, rhsBuffer, fir::getLen(rhs)); } +void fir::runtime::genFCString(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value resultBox, mlir::Value stringBox, + mlir::Value asis) { + auto func = fir::runtime::getRuntimeFunc<mkRTKey(FCString)>(loc, builder); + auto fTy = func.getFunctionType(); + auto sourceFile = fir::factory::locationToFilename(builder, loc); + auto sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); + auto args = fir::runtime::createArguments( + builder, loc, fTy, resultBox, stringBox, asis, sourceFile, sourceLine); + fir::CallOp::create(builder, loc, func, args); +} + mlir::Value fir::runtime::genIndex(fir::FirOpBuilder &builder, mlir::Location loc, int kind, mlir::Value stringBase, diff --git a/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp b/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp index 110b1b2..a5f16f8 100644 --- a/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp @@ -137,6 +137,15 @@ void fir::runtime::genEtime(fir::FirOpBuilder &builder, mlir::Location loc, fir::CallOp::create(builder, loc, runtimeFunc, args); } +void fir::runtime::genFlush(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value unit) { + auto runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(Flush)>(loc, builder); + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, 
loc, runtimeFunc.getFunctionType(), unit); + + fir::CallOp::create(builder, loc, runtimeFunc, args); +} + void fir::runtime::genFree(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value ptr) { auto runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(Free)>(loc, builder); @@ -461,3 +470,34 @@ mlir::Value fir::runtime::genChdir(fir::FirOpBuilder &builder, fir::runtime::createArguments(builder, loc, func.getFunctionType(), name); return fir::CallOp::create(builder, loc, func, args).getResult(0); } + +mlir::Value fir::runtime::genIrand(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value i) { + auto runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(Irand)>(loc, builder); + mlir::FunctionType runtimeFuncTy = runtimeFunc.getFunctionType(); + + llvm::SmallVector<mlir::Value> args = + fir::runtime::createArguments(builder, loc, runtimeFuncTy, i); + return fir::CallOp::create(builder, loc, runtimeFunc, args).getResult(0); +} + +mlir::Value fir::runtime::genRand(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value i) { + auto runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(Rand)>(loc, builder); + mlir::FunctionType runtimeFuncTy = runtimeFunc.getFunctionType(); + + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, runtimeFuncTy.getInput(2)); + + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, loc, runtimeFuncTy, i, sourceFile, sourceLine); + return fir::CallOp::create(builder, loc, runtimeFunc, args).getResult(0); +} + +void fir::runtime::genShowDescriptor(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value descAddr) { + mlir::func::FuncOp func{ + fir::runtime::getRuntimeFunc<mkRTKey(ShowDescriptor)>(loc, builder)}; + fir::CallOp::create(builder, loc, func, descAddr); +} diff --git a/flang/lib/Optimizer/Builder/Runtime/Main.cpp b/flang/lib/Optimizer/Builder/Runtime/Main.cpp index 9ce5e17..2b748de 100644 --- 
a/flang/lib/Optimizer/Builder/Runtime/Main.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Main.cpp @@ -74,8 +74,8 @@ void fir::runtime::genMain( mif::InitOp::create(builder, loc); fir::CallOp::create(builder, loc, qqMainFn); - fir::CallOp::create(builder, loc, stopFn); mlir::Value ret = builder.createIntegerConstant(loc, argcTy, 0); + fir::CallOp::create(builder, loc, stopFn); mlir::func::ReturnOp::create(builder, loc, ret); } diff --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp index 157d435..343d848 100644 --- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp @@ -1841,7 +1841,7 @@ mlir::Value fir::runtime::genReduce(fir::FirOpBuilder &builder, assert((fir::isa_real(eleTy) || fir::isa_integer(eleTy) || mlir::isa<fir::LogicalType>(eleTy)) && - "expect real, interger or logical"); + "expect real, integer or logical"); auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy); mlir::func::FuncOp func; diff --git a/flang/lib/Optimizer/Builder/TemporaryStorage.cpp b/flang/lib/Optimizer/Builder/TemporaryStorage.cpp index 7e329e3..5db40af 100644 --- a/flang/lib/Optimizer/Builder/TemporaryStorage.cpp +++ b/flang/lib/Optimizer/Builder/TemporaryStorage.cpp @@ -258,13 +258,9 @@ void fir::factory::AnyVariableStack::pushValue(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value variable) { hlfir::Entity entity{variable}; - mlir::Type storageElementType = - hlfir::getFortranElementType(retValueBox.getType()); - auto [box, maybeCleanUp] = - hlfir::convertToBox(loc, builder, entity, storageElementType); + mlir::Value box = + hlfir::genVariableBox(loc, builder, entity, entity.getBoxType()); fir::runtime::genPushDescriptor(loc, builder, opaquePtr, fir::getBase(box)); - if (maybeCleanUp) - (*maybeCleanUp)(); } void fir::factory::AnyVariableStack::resetFetchPosition( diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp 
b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 70bb43a2..6257017 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -39,6 +39,7 @@ #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" @@ -680,6 +681,22 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> { if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr()) llvmCall.setResAttrsAttr(resAttrs); + if (auto inlineAttr = call.getInlineAttrAttr()) { + llvmCall->removeAttr("inline_attr"); + if (inlineAttr.getValue() == fir::FortranInlineEnum::no_inline) { + llvmCall.setNoInlineAttr(rewriter.getUnitAttr()); + } else if (inlineAttr.getValue() == fir::FortranInlineEnum::inline_hint) { + llvmCall.setInlineHintAttr(rewriter.getUnitAttr()); + } else if (inlineAttr.getValue() == + fir::FortranInlineEnum::always_inline) { + llvmCall.setAlwaysInlineAttr(rewriter.getUnitAttr()); + } + } + + if (std::optional<mlir::ArrayAttr> optionalAccessGroups = + call.getAccessGroups()) + llvmCall.setAccessGroups(*optionalAccessGroups); + if (memAttr) llvmCall.setMemoryEffectsAttr( mlir::cast<mlir::LLVM::MemoryEffectsAttr>(memAttr)); @@ -749,6 +766,44 @@ struct VolatileCastOpConversion } }; +/// Lower `fir.assumed_size_extent` to constant -1 of index type. 
+struct AssumedSizeExtentOpConversion + : public fir::FIROpConversion<fir::AssumedSizeExtentOp> { + using FIROpConversion::FIROpConversion; + + llvm::LogicalResult + matchAndRewrite(fir::AssumedSizeExtentOp op, OpAdaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); + mlir::Type ity = lowerTy().indexType(); + auto cst = fir::genConstantIndex(loc, ity, rewriter, -1); + rewriter.replaceOp(op, cst.getResult()); + return mlir::success(); + } +}; + +/// Lower `fir.is_assumed_size_extent` to integer equality with -1. +struct IsAssumedSizeExtentOpConversion + : public fir::FIROpConversion<fir::IsAssumedSizeExtentOp> { + using FIROpConversion::FIROpConversion; + + llvm::LogicalResult + matchAndRewrite(fir::IsAssumedSizeExtentOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); + mlir::Value val = adaptor.getVal(); + mlir::Type valTy = val.getType(); + // Create constant -1 of the operand type. + auto negOneAttr = rewriter.getIntegerAttr(valTy, -1); + auto negOne = + mlir::LLVM::ConstantOp::create(rewriter, loc, valTy, negOneAttr); + auto cmp = mlir::LLVM::ICmpOp::create( + rewriter, loc, mlir::LLVM::ICmpPredicate::eq, val, negOne); + rewriter.replaceOp(op, cmp.getResult()); + return mlir::success(); + } +}; + /// convert value of from-type to value of to-type struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> { using FIROpConversion::FIROpConversion; @@ -762,6 +817,60 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> { mlir::ConversionPatternRewriter &rewriter) const override { auto fromFirTy = convert.getValue().getType(); auto toFirTy = convert.getRes().getType(); + + // Handle conversions between pointer-like values and memref descriptors. + // These are produced by FIR-to-MemRef lowering and represent descriptor + // conversion rather than pure value conversions. 
+ if (auto memRefTy = mlir::dyn_cast<mlir::MemRefType>(toFirTy)) { + mlir::Location loc = convert.getLoc(); + mlir::Value basePtr = adaptor.getValue(); + assert(basePtr && "null base pointer"); + + auto [strides, offset] = memRefTy.getStridesAndOffset(); + bool hasStaticLayout = + mlir::ShapedType::isStatic(offset) && + llvm::none_of(strides, mlir::ShapedType::isDynamic); + + auto *firConv = + static_cast<const fir::LLVMTypeConverter *>(this->getTypeConverter()); + assert(firConv && "expected non-null LLVMTypeConverter"); + + if (memRefTy.hasStaticShape() && hasStaticLayout) { + // Static shape and layout: build a fully-populated descriptor. + mlir::Value memrefDesc = mlir::MemRefDescriptor::fromStaticShape( + rewriter, loc, *firConv, memRefTy, basePtr); + rewriter.replaceOp(convert, memrefDesc); + return mlir::success(); + } + + // Dynamic shape or layout: create an LLVM memref descriptor and insert + // the base pointer field, letting the rest of the fields be populated + // by subsequent lowering. + mlir::Type llvmMemRefTy = firConv->convertType(memRefTy); + auto undef = mlir::LLVM::UndefOp::create(rewriter, loc, llvmMemRefTy); + auto insert = + mlir::LLVM::InsertValueOp::create(rewriter, loc, undef, basePtr, 1); + rewriter.replaceOp(convert, insert); + return mlir::success(); + } + + if (auto memRefTy = mlir::dyn_cast<mlir::MemRefType>(fromFirTy)) { + // Legalize conversions *from* memref descriptors to pointer-like values + // by extracting the underlying buffer pointer from the descriptor. 
+ mlir::Location loc = convert.getLoc(); + mlir::Value base = adaptor.getValue(); + auto alignedPtr = + mlir::LLVM::ExtractValueOp::create(rewriter, loc, base, 1); + auto offset = mlir::LLVM::ExtractValueOp::create(rewriter, loc, base, 2); + mlir::Type elementType = + this->getTypeConverter()->convertType(memRefTy.getElementType()); + auto gepOp = mlir::LLVM::GEPOp::create(rewriter, loc, + alignedPtr.getType(), elementType, + alignedPtr, offset.getResult()); + rewriter.replaceOp(convert, gepOp); + return mlir::success(); + } + auto fromTy = convertType(fromFirTy); auto toTy = convertType(toFirTy); mlir::Value op0 = adaptor.getOperands()[0]; @@ -1113,7 +1222,7 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> { mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy); if (auto scaleSize = fir::genAllocationScaleSize(loc, heap.getInType(), ity, rewriter)) - size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize); + size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize); for (mlir::Value opnd : adaptor.getOperands()) size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, integerCast(loc, rewriter, ity, opnd)); @@ -3296,6 +3405,26 @@ private: } }; +/// `fir.prefetch` --> `llvm.prefetch` +struct PrefetchOpConversion : public fir::FIROpConversion<fir::PrefetchOp> { + using FIROpConversion::FIROpConversion; + + llvm::LogicalResult + matchAndRewrite(fir::PrefetchOp prefetch, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::IntegerAttr rw = mlir::IntegerAttr::get(rewriter.getI32Type(), + prefetch.getRwAttr() ? 1 : 0); + mlir::IntegerAttr localityHint = prefetch.getLocalityHintAttr(); + mlir::IntegerAttr cacheType = mlir::IntegerAttr::get( + rewriter.getI32Type(), prefetch.getCacheTypeAttr() ? 
1 : 0); + mlir::LLVM::Prefetch::create(rewriter, prefetch.getLoc(), + adaptor.getOperands().front(), rw, + localityHint, cacheType); + rewriter.eraseOp(prefetch); + return mlir::success(); + } +}; + /// `fir.load` --> `llvm.load` struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> { using FIROpConversion::FIROpConversion; @@ -3352,6 +3481,9 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> { loadOp.setTBAATags(*optionalTag); else attachTBAATag(loadOp, load.getType(), load.getType(), nullptr); + if (std::optional<mlir::ArrayAttr> optionalAccessGroups = + load.getAccessGroups()) + loadOp.setAccessGroups(*optionalAccessGroups); rewriter.replaceOp(load, loadOp.getResult()); } return mlir::success(); @@ -3396,6 +3528,20 @@ struct NoReassocOpConversion : public fir::FIROpConversion<fir::NoReassocOp> { } }; +/// Erase `fir.use_stmt` operations during LLVM lowering. +/// These operations are only used for debug info generation by the +/// AddDebugInfo pass and have no runtime representation. 
+struct UseStmtOpConversion : public fir::FIROpConversion<fir::UseStmtOp> { + using FIROpConversion::FIROpConversion; + + llvm::LogicalResult + matchAndRewrite(fir::UseStmtOp useStmt, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(useStmt); + return mlir::success(); + } +}; + static void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest, std::optional<mlir::ValueRange> destOps, mlir::ConversionPatternRewriter &rewriter, @@ -3466,6 +3612,11 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> { mlir::Block *dest = caseOp.getSuccessor(t); std::optional<mlir::ValueRange> destOps = caseOp.getSuccessorOperands(adaptor.getOperands(), t); + // Convert block signature if needed + if (destOps && !destOps->empty()) + if (auto conversion = getTypeConverter()->convertBlockSignature(dest)) + dest = rewriter.applySignatureConversion(dest, *conversion, + getTypeConverter()); std::optional<mlir::ValueRange> cmpOps = *caseOp.getCompareOperands(adaptor.getOperands(), t); mlir::Attribute attr = cases[t]; @@ -3683,6 +3834,10 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> { if (store.getNontemporal()) storeOp.setNontemporal(true); + if (std::optional<mlir::ArrayAttr> optionalAccessGroups = + store.getAccessGroups()) + storeOp.setAccessGroups(*optionalAccessGroups); + newOp = storeOp; } if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa()) @@ -4360,6 +4515,7 @@ void fir::populateFIRToLLVMConversionPatterns( AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion, BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion, BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion, + AssumedSizeExtentOpConversion, IsAssumedSizeExtentOpConversion, BoxOffsetOpConversion, BoxProcHostOpConversion, BoxRankOpConversion, BoxTypeCodeOpConversion, BoxTypeDescOpConversion, CallOpConversion, CmpcOpConversion, VolatileCastOpConversion, 
ConvertOpConversion, @@ -4372,14 +4528,15 @@ void fir::populateFIRToLLVMConversionPatterns( FirEndOpConversion, FreeMemOpConversion, GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion, - NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion, - SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion, - ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion, - SliceOpConversion, StoreOpConversion, StringLitOpConversion, - SubcOpConversion, TypeDescOpConversion, TypeInfoOpConversion, - UnboxCharOpConversion, UnboxProcOpConversion, UndefOpConversion, - UnreachableOpConversion, XArrayCoorOpConversion, XEmboxOpConversion, - XReboxOpConversion, ZeroOpConversion>(converter, options); + NegcOpConversion, NoReassocOpConversion, PrefetchOpConversion, + SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion, + SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion, + ShiftOpConversion, SliceOpConversion, StoreOpConversion, + StringLitOpConversion, SubcOpConversion, TypeDescOpConversion, + TypeInfoOpConversion, UnboxCharOpConversion, UnboxProcOpConversion, + UndefOpConversion, UnreachableOpConversion, UseStmtOpConversion, + XArrayCoorOpConversion, XEmboxOpConversion, XReboxOpConversion, + ZeroOpConversion>(converter, options); // Patterns that are populated without a type converter do not trigger // target materializations for the operands of the root op. 
diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 381b2a2..3e1fe1d 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -242,10 +242,11 @@ struct TargetAllocMemOpConversion loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout()); if (auto scaleSize = fir::genAllocationScaleSize( loc, allocmemOp.getInType(), ity, rewriter)) - size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize); + size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize); for (mlir::Value opnd : adaptor.getOperands().drop_front()) - size = rewriter.create<mlir::LLVM::MulOp>( - loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd)); + size = mlir::LLVM::MulOp::create( + rewriter, loc, ity, size, + integerCast(lowerTy(), loc, rewriter, ity, opnd)); auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); auto mallocTy = mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); @@ -259,6 +260,21 @@ struct TargetAllocMemOpConversion return mlir::success(); } }; + +struct DeclareMapperOpConversion + : public OpenMPFIROpConversion<mlir::omp::DeclareMapperOp> { + using OpenMPFIROpConversion::OpenMPFIROpConversion; + + llvm::LogicalResult + matchAndRewrite(mlir::omp::DeclareMapperOp curOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.startOpModification(curOp); + curOp.setType(convertObjectType(lowerTy(), curOp.getType())); + rewriter.finalizeOpModification(curOp); + return mlir::success(); + } +}; + } // namespace void fir::populateOpenMPFIRToLLVMConversionPatterns( @@ -266,4 +282,5 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns( patterns.add<MapInfoOpConversion>(converter); patterns.add<PrivateClauseOpConversion>(converter); patterns.add<TargetAllocMemOpConversion>(converter); + patterns.add<DeclareMapperOpConversion>(converter); } diff --git 
a/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp b/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp index ac432c7..81488d7 100644 --- a/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp +++ b/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp @@ -289,7 +289,6 @@ PackArrayConversion::genRepackedBox(fir::FirOpBuilder &builder, fir::factory::genDimInfoFromBox(builder, loc, box, &lbounds, &extents, /*strides=*/nullptr); // Get the type parameters from the box, if needed. - llvm::SmallVector<mlir::Value> assumedTypeParams; if (numTypeParams != 0) { if (auto charType = mlir::dyn_cast<fir::CharacterType>(boxType.unwrapInnerType())) diff --git a/flang/lib/Optimizer/CodeGen/PassDetail.h b/flang/lib/Optimizer/CodeGen/PassDetail.h index f703013..252da02 100644 --- a/flang/lib/Optimizer/CodeGen/PassDetail.h +++ b/flang/lib/Optimizer/CodeGen/PassDetail.h @@ -18,7 +18,7 @@ namespace fir { -#define GEN_PASS_CLASSES +#define GEN_PASS_DECL #include "flang/Optimizer/CodeGen/CGPasses.h.inc" } // namespace fir diff --git a/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp index 1b1d43c..3b137d1 100644 --- a/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp @@ -302,11 +302,16 @@ public: else return mlir::failure(); } + // Extract dummy_arg_no attribute if present + mlir::IntegerAttr dummyArgNoAttr; + if (auto attr = declareOp->getAttrOfType<mlir::IntegerAttr>("dummy_arg_no")) + dummyArgNoAttr = attr; // FIXME: Add FortranAttrs and CudaAttrs auto xDeclOp = fir::cg::XDeclareOp::create( rewriter, loc, declareOp.getType(), declareOp.getMemref(), shapeOpers, shiftOpers, declareOp.getTypeparams(), declareOp.getDummyScope(), - declareOp.getUniqName()); + declareOp.getStorage(), declareOp.getStorageOffset(), + declareOp.getUniqName(), dummyArgNoAttr); LLVM_DEBUG(llvm::dbgs() << "rewriting " << declareOp << " to " << xDeclOp << '\n'); rewriter.replaceOp(declareOp, xDeclOp.getOperation()->getResults()); 
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp index b60a72e..9b6c9be 100644 --- a/flang/lib/Optimizer/CodeGen/Target.cpp +++ b/flang/lib/Optimizer/CodeGen/Target.cpp @@ -353,7 +353,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> { ArgClass ¤t = byteOffset < 8 ? Lo : Hi; // System V AMD64 ABI 3.2.3. version 1.0 llvm::TypeSwitch<mlir::Type>(type) - .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) { + .Case([&](mlir::IntegerType intTy) { if (intTy.getWidth() == 128) Hi = Lo = ArgClass::Integer; else @@ -371,7 +371,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> { current = ArgClass::SSE; } }) - .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { + .Case([&](mlir::ComplexType cmplx) { const auto *sem = &floatToSemantics(kindMap, cmplx.getElementType()); if (sem == &llvm::APFloat::x87DoubleExtended()) { current = ArgClass::ComplexX87; @@ -382,23 +382,23 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> { byteOffset, Lo, Hi); } }) - .template Case<fir::LogicalType>([&](fir::LogicalType logical) { + .Case([&](fir::LogicalType logical) { if (kindMap.getLogicalBitsize(logical.getFKind()) == 128) Hi = Lo = ArgClass::Integer; else current = ArgClass::Integer; }) - .template Case<fir::CharacterType>( + .Case( [&](fir::CharacterType character) { current = ArgClass::Integer; }) - .template Case<fir::SequenceType>([&](fir::SequenceType seqTy) { + .Case([&](fir::SequenceType seqTy) { // Array component. classifyArray(loc, seqTy, byteOffset, Lo, Hi); }) - .template Case<fir::RecordType>([&](fir::RecordType recTy) { + .Case([&](fir::RecordType recTy) { // Component that is a derived type. classifyStruct(loc, recTy, byteOffset, Lo, Hi); }) - .template Case<fir::VectorType>([&](fir::VectorType vecTy) { + .Case([&](fir::VectorType vecTy) { // Previously marshalled SSE eight byte for a previous struct // argument. 
auto *sem = fir::isa_real(vecTy.getEleTy()) @@ -939,23 +939,23 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> { NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const { return llvm::TypeSwitch<mlir::Type, NRegs>(type) - .Case<mlir::IntegerType>([&](auto intTy) { + .Case([&](mlir::IntegerType intTy) { return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false}; }) - .Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; }) - .Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; }) - .Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; }) - .Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; }) - .Case<fir::SequenceType>([&](auto ty) { + .Case([&](mlir::FloatType) { return NRegs{1, true}; }) + .Case([&](mlir::ComplexType) { return NRegs{2, true}; }) + .Case([&](fir::LogicalType) { return NRegs{1, false}; }) + .Case([&](fir::CharacterType) { return NRegs{1, false}; }) + .Case([&](fir::SequenceType ty) { assert(ty.getShape().size() == 1 && "invalid array dimensions in BIND(C)"); NRegs nregs = usedRegsForType(loc, ty.getEleTy()); nregs.n *= ty.getShape()[0]; return nregs; }) - .Case<fir::RecordType>( - [&](auto ty) { return usedRegsForRecordType(loc, ty); }) - .Case<fir::VectorType>([&](auto) { + .Case( + [&](fir::RecordType ty) { return usedRegsForRecordType(loc, ty); }) + .Case([&](fir::VectorType) { TODO(loc, "passing vector argument to C by value is not supported"); return NRegs{}; }) @@ -1167,13 +1167,12 @@ struct TargetPPC64le : public GenericTarget<TargetPPC64le> { unsigned getElemWidth(mlir::Type ty) const { unsigned width{}; llvm::TypeSwitch<mlir::Type>(ty) - .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { + .Case([&](mlir::ComplexType cmplx) { auto elemType{ mlir::dyn_cast<mlir::FloatType>(cmplx.getElementType())}; width = elemType.getWidth(); }) - .template Case<mlir::FloatType>( - [&](mlir::FloatType real) { width = real.getWidth(); }); + .Case([&](mlir::FloatType real) { width = 
real.getWidth(); }); return width; } @@ -1594,15 +1593,15 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> { llvm::SmallVector<mlir::Type> flatTypes; llvm::TypeSwitch<mlir::Type>(type) - .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) { + .Case([&](mlir::IntegerType intTy) { if (intTy.getWidth() != 0) flatTypes.push_back(intTy); }) - .template Case<mlir::FloatType>([&](mlir::FloatType floatTy) { + .Case([&](mlir::FloatType floatTy) { if (floatTy.getWidth() != 0) flatTypes.push_back(floatTy); }) - .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { + .Case([&](mlir::ComplexType cmplx) { const auto *sem = &floatToSemantics(kindMap, cmplx.getElementType()); if (sem == &llvm::APFloat::IEEEsingle() || sem == &llvm::APFloat::IEEEdouble() || @@ -1614,21 +1613,21 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> { "IEEEquad) as a structure component for BIND(C), " "VALUE derived type argument and type return"); }) - .template Case<fir::LogicalType>([&](fir::LogicalType logicalTy) { + .Case([&](fir::LogicalType logicalTy) { const unsigned width = kindMap.getLogicalBitsize(logicalTy.getFKind()); if (width != 0) flatTypes.push_back( mlir::IntegerType::get(type.getContext(), width)); }) - .template Case<fir::CharacterType>([&](fir::CharacterType charTy) { + .Case([&](fir::CharacterType charTy) { assert(kindMap.getCharacterBitsize(charTy.getFKind()) <= 8 && "the bit size of characterType as an interoperable type must " "not exceed 8"); for (unsigned i = 0; i < charTy.getLen(); ++i) flatTypes.push_back(mlir::IntegerType::get(type.getContext(), 8)); }) - .template Case<fir::SequenceType>([&](fir::SequenceType seqTy) { + .Case([&](fir::SequenceType seqTy) { if (!seqTy.hasDynamicExtents()) { const std::uint64_t numOfEle = seqTy.getConstantArraySize(); mlir::Type eleTy = seqTy.getEleTy(); @@ -1646,7 +1645,7 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> { "component for BIND(C), " "VALUE 
derived type argument and type return"); }) - .template Case<fir::RecordType>([&](fir::RecordType recTy) { + .Case([&](fir::RecordType recTy) { for (auto &component : recTy.getTypeList()) { mlir::Type eleTy = component.second; llvm::SmallVector<mlir::Type> subTypeList = @@ -1655,7 +1654,7 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> { llvm::copy(subTypeList, std::back_inserter(flatTypes)); } }) - .template Case<fir::VectorType>([&](fir::VectorType vecTy) { + .Case([&](fir::VectorType vecTy) { auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash( loc, vecTy, getDataLayout(), kindMap); if (sizeAndAlign.first == 2 * GRLenInChar) @@ -1742,7 +1741,7 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> { return true; llvm::TypeSwitch<mlir::Type>(type) - .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) { + .Case([&](mlir::IntegerType intTy) { const unsigned width = intTy.getWidth(); if (width > 128) TODO(loc, @@ -1754,7 +1753,7 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> { else if (width <= 2 * GRLen) GARsLeft = GARsLeft - 2; }) - .template Case<mlir::FloatType>([&](mlir::FloatType floatTy) { + .Case([&](mlir::FloatType floatTy) { const unsigned width = floatTy.getWidth(); if (width > 128) TODO(loc, "floatType with width exceeding 128 bits is unsupported"); diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index ac285b5..3ef4703 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -143,7 +143,8 @@ public: llvm::SmallVector<mlir::Type> operandsTypes; for (auto arg : gpuLaunchFunc.getKernelOperands()) operandsTypes.push_back(arg.getType()); - auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {}); + auto fctTy = mlir::FunctionType::get(&context, operandsTypes, + gpuLaunchFunc.getResultTypes()); if (!hasPortableSignature(fctTy, op)) convertCallOp(gpuLaunchFunc, 
fctTy); } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) { @@ -392,12 +393,12 @@ public: if (fnTy.getResults().size() == 1) { mlir::Type ty = fnTy.getResult(0); llvm::TypeSwitch<mlir::Type>(ty) - .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { + .Case([&](mlir::ComplexType cmplx) { wrap = rewriteCallComplexResultType(loc, cmplx, newResTys, newInTyAndAttrs, newOpers, savedStackPtr); }) - .template Case<fir::RecordType>([&](fir::RecordType recTy) { + .Case([&](fir::RecordType recTy) { wrap = rewriteCallStructResultType(loc, recTy, newResTys, newInTyAndAttrs, newOpers, savedStackPtr); @@ -421,7 +422,7 @@ public: mlir::Value oper = std::get<1>(e.value()); unsigned index = e.index(); llvm::TypeSwitch<mlir::Type>(ty) - .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) { + .Case([&](fir::BoxCharType boxTy) { if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { if (noCharacterConversion) { newInTyAndAttrs.push_back( @@ -455,15 +456,15 @@ public: } } }) - .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { + .Case([&](mlir::ComplexType cmplx) { rewriteCallComplexInputType(loc, cmplx, oper, newInTyAndAttrs, newOpers, savedStackPtr); }) - .template Case<fir::RecordType>([&](fir::RecordType recTy) { + .Case([&](fir::RecordType recTy) { rewriteCallStructInputType(loc, recTy, oper, newInTyAndAttrs, newOpers, savedStackPtr); }) - .template Case<mlir::TupleType>([&](mlir::TupleType tuple) { + .Case([&](mlir::TupleType tuple) { if (fir::isCharacterProcedureTuple(tuple)) { mlir::ModuleOp module = getModule(); if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { @@ -520,10 +521,14 @@ public: llvm::SmallVector<mlir::Value, 1> newCallResults; // TODO propagate/update call argument and result attributes. 
if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) { + mlir::Value asyncToken = callOp.getAsyncToken(); auto newCall = A::create(*rewriter, loc, callOp.getKernel(), callOp.getGridSizeOperandValues(), callOp.getBlockSizeOperandValues(), - callOp.getDynamicSharedMemorySize(), newOpers); + callOp.getDynamicSharedMemorySize(), newOpers, + asyncToken ? asyncToken.getType() : nullptr, + callOp.getAsyncDependencies(), + /*clusterSize=*/std::nullopt); if (callOp.getClusterSizeX()) newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX()); if (callOp.getClusterSizeY()) @@ -702,10 +707,10 @@ public: auto loc = addrOp.getLoc(); for (mlir::Type ty : addrTy.getResults()) { llvm::TypeSwitch<mlir::Type>(ty) - .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { + .Case([&](mlir::ComplexType ty) { lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs); }) - .Case<fir::RecordType>([&](fir::RecordType ty) { + .Case([&](fir::RecordType ty) { lowerStructSignatureRes(loc, ty, newResTys, newInTyAndAttrs); }) .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); @@ -713,7 +718,7 @@ public: llvm::SmallVector<mlir::Type> trailingInTys; for (mlir::Type ty : addrTy.getInputs()) { llvm::TypeSwitch<mlir::Type>(ty) - .Case<fir::BoxCharType>([&](auto box) { + .Case([&](fir::BoxCharType box) { if (noCharacterConversion) { newInTyAndAttrs.push_back( fir::CodeGenSpecifics::getTypeAndAttr(box)); @@ -728,10 +733,10 @@ public: } } }) - .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { + .Case([&](mlir::ComplexType ty) { lowerComplexSignatureArg(loc, ty, newInTyAndAttrs); }) - .Case<mlir::TupleType>([&](mlir::TupleType tuple) { + .Case([&](mlir::TupleType tuple) { if (fir::isCharacterProcedureTuple(tuple)) { newInTyAndAttrs.push_back( fir::CodeGenSpecifics::getTypeAndAttr(tuple.getType(0))); @@ -741,7 +746,7 @@ public: fir::CodeGenSpecifics::getTypeAndAttr(ty)); } }) - .template Case<fir::RecordType>([&](fir::RecordType recTy) { + .Case([&](fir::RecordType 
recTy) { lowerStructSignatureArg(loc, recTy, newInTyAndAttrs); }) .Default([&](mlir::Type ty) { @@ -872,16 +877,24 @@ public: } } + // Count the number of arguments that have to stay in place at the end of + // the argument list. + unsigned trailingArgs = 0; + if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) { + trailingArgs = + func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions(); + } + // Convert return value(s) for (auto ty : funcTy.getResults()) llvm::TypeSwitch<mlir::Type>(ty) - .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { + .Case([&](mlir::ComplexType cmplx) { if (noComplexConversion) newResTys.push_back(cmplx); else doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups); }) - .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) { + .Case([&](mlir::IntegerType intTy) { auto m = specifics->integerArgumentType(func.getLoc(), intTy); assert(m.size() == 1); auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]); @@ -895,7 +908,7 @@ public: rewriter->getUnitAttr())); newResTys.push_back(retTy); }) - .template Case<fir::RecordType>([&](fir::RecordType recTy) { + .Case([&](fir::RecordType recTy) { doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups); }) .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); @@ -910,7 +923,7 @@ public: auto ty = e.value(); unsigned index = e.index(); llvm::TypeSwitch<mlir::Type>(ty) - .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) { + .Case([&](fir::BoxCharType boxTy) { if (noCharacterConversion) { newInTyAndAttrs.push_back( fir::CodeGenSpecifics::getTypeAndAttr(boxTy)); @@ -933,10 +946,10 @@ public: } } }) - .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { + .Case([&](mlir::ComplexType cmplx) { doComplexArg(func, cmplx, newInTyAndAttrs, fixups); }) - .template Case<mlir::TupleType>([&](mlir::TupleType tuple) { + .Case([&](mlir::TupleType tuple) { if (fir::isCharacterProcedureTuple(tuple)) { 
fixups.emplace_back(FixupTy::Codes::TrailingCharProc, newInTyAndAttrs.size(), trailingTys.size()); @@ -948,7 +961,7 @@ public: fir::CodeGenSpecifics::getTypeAndAttr(ty)); } }) - .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) { + .Case([&](mlir::IntegerType intTy) { auto m = specifics->integerArgumentType(func.getLoc(), intTy); assert(m.size() == 1); auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]); @@ -965,7 +978,7 @@ public: newInTyAndAttrs.push_back(m[0]); }) - .template Case<fir::RecordType>([&](fir::RecordType recTy) { + .Case([&](fir::RecordType recTy) { doStructArg(func, recTy, newInTyAndAttrs, fixups); }) .Default([&](mlir::Type ty) { @@ -981,6 +994,16 @@ public: } } + // Add the argument at the end if the number of trailing arguments is 0, + // otherwise insert the argument at the appropriate index. + auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) { + unsigned inputIndex = func.front().getArguments().size() - trailingArgs; + auto newArg = trailingArgs == 0 + ? func.front().addArgument(ty, loc) + : func.front().insertArgument(inputIndex, ty, loc); + return newArg; + }; + if (!func.empty()) { // If the function has a body, then apply the fixups to the arguments and // return ops as required. These fixups are done in place. @@ -1117,8 +1140,7 @@ public: // original arguments. (Boxchar arguments.) auto newBufArg = func.front().insertArgument(fixup.index, fixupType, loc); - auto newLenArg = - func.front().addArgument(trailingTys[fixup.second], loc); + auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc); auto boxTy = oldArgTys[fixup.index - offset]; rewriter->setInsertionPointToStart(&func.front()); auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg, @@ -1133,8 +1155,7 @@ public: // appended after all the original arguments. 
auto newProcPointerArg = func.front().insertArgument(fixup.index, fixupType, loc); - auto newLenArg = - func.front().addArgument(trailingTys[fixup.second], loc); + auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc); auto tupleType = oldArgTys[fixup.index - offset]; rewriter->setInsertionPointToStart(&func.front()); fir::FirOpBuilder builder(*rewriter, getModule()); diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp index 2283560..3c4162c 100644 --- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp +++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp @@ -163,8 +163,8 @@ LLVMTypeConverter::convertRecordType(fir::RecordType derived, return mlir::success(); } callStack.push_back(derived); - auto popConversionCallStack = - llvm::make_scope_exit([&callStack]() { callStack.pop_back(); }); + llvm::scope_exit popConversionCallStack( + [&callStack]() { callStack.pop_back(); }); llvm::SmallVector<mlir::Type> members; for (auto mem : derived.getTypeList()) { diff --git a/flang/lib/Optimizer/Dialect/CMakeLists.txt b/flang/lib/Optimizer/Dialect/CMakeLists.txt index 65d1f2c..f81989a 100644 --- a/flang/lib/Optimizer/Dialect/CMakeLists.txt +++ b/flang/lib/Optimizer/Dialect/CMakeLists.txt @@ -6,6 +6,7 @@ add_subdirectory(MIF) add_flang_library(FIRDialect FIRAttr.cpp FIRDialect.cpp + FIROperationMoveOpInterface.cpp FIROps.cpp FIRType.cpp FirAliasTagOpInterface.cpp @@ -15,6 +16,7 @@ add_flang_library(FIRDialect DEPENDS CanonicalizationPatternsIncGen + FIROperationMoveOpInterfaceIncGen FIROpsIncGen FIRSafeTempArrayCopyAttrInterfaceIncGen CUFAttrsIncGen diff --git a/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp b/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp index bd0499f..3f58065 100644 --- a/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp +++ b/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp @@ -52,4 +52,18 @@ bool hasDataAttr(mlir::Operation *op, cuf::DataAttribute value) { return 
false; } +bool isDeviceDataAttribute(cuf::DataAttribute attr) { + return attr == cuf::DataAttribute::Device || + attr == cuf::DataAttribute::Managed || + attr == cuf::DataAttribute::Constant || + attr == cuf::DataAttribute::Shared || + attr == cuf::DataAttribute::Unified; +} + +bool hasDeviceDataAttr(mlir::Operation *op) { + if (auto dataAttr = getDataAttr(op)) + return isDeviceDataAttribute(dataAttr.getValue()); + return false; +} + } // namespace cuf diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp index 687007d..a157c47 100644 --- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp +++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp @@ -274,6 +274,26 @@ llvm::LogicalResult cuf::KernelOp::verify() { return checkStreamType(*this); } +bool cuf::KernelOp::canMoveFromDescendant(mlir::Operation *descendant, + mlir::Operation *candidate) { + // Moving operations out of loops inside cuf.kernel is always legal. + return true; +} + +bool cuf::KernelOp::canMoveOutOf(mlir::Operation *candidate) { + // In general, some movement of operations out of cuf.kernel is allowed. + if (!candidate) + return true; + + // Operations that have !fir.ref operands cannot be moved + // out of cuf.kernel, because this may break implicit data mapping + // passes that may run after LICM. + return !llvm::any_of(candidate->getOperands(), + [&](mlir::Value candidateOperand) { + return fir::isa_ref_type(candidateOperand.getType()); + }); +} + //===----------------------------------------------------------------------===// // RegisterKernelOp //===----------------------------------------------------------------------===// @@ -333,7 +353,8 @@ void cuf::SharedMemoryOp::build( bindcName.empty() ? 
mlir::StringAttr{} : builder.getStringAttr(bindcName); build(builder, result, wrapAllocaResultType(inType), mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape, - /*offset=*/mlir::Value{}); + /*offset=*/mlir::Value{}, /*alignment=*/mlir::IntegerAttr{}, + /*isStatic=*/nullptr); result.addAttributes(attributes); } diff --git a/flang/lib/Optimizer/Dialect/FIROperationMoveOpInterface.cpp b/flang/lib/Optimizer/Dialect/FIROperationMoveOpInterface.cpp new file mode 100644 index 0000000..dcf5323 --- /dev/null +++ b/flang/lib/Optimizer/Dialect/FIROperationMoveOpInterface.cpp @@ -0,0 +1,49 @@ +//===-- FIROperationMoveOpInterface.cpp -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/FIROperationMoveOpInterface.h" + +#include "flang/Optimizer/Dialect/FIROperationMoveOpInterface.cpp.inc" + +llvm::LogicalResult +fir::detail::verifyOperationMoveOpInterface(mlir::Operation *op) { + // It does not make sense to use this interface for operations + // without any regions. + if (op->getNumRegions() == 0) + return op->emitOpError("must contain at least one region"); + return llvm::success(); +} + +bool fir::canMoveFromDescendant(mlir::Operation *op, + mlir::Operation *descendant, + mlir::Operation *candidate) { + // Perform some sanity checks. 
+ assert(op->isProperAncestor(descendant) && + "op must be an ancestor of descendant"); + if (candidate) + assert(descendant->isProperAncestor(candidate) && + "descendant must be an ancestor of candidate"); + if (auto iface = mlir::dyn_cast<OperationMoveOpInterface>(op)) + return iface.canMoveFromDescendant(descendant, candidate); + + return true; +} + +bool fir::canMoveOutOf(mlir::Operation *op, mlir::Operation *candidate) { + if (candidate) + assert(op->isProperAncestor(candidate) && + "op must be an ancestor of candidate"); + if (auto iface = mlir::dyn_cast<OperationMoveOpInterface>(op)) + return iface.canMoveOutOf(candidate); + + return true; +} diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 1712af1..9c22b61 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -174,6 +174,32 @@ static void printAllocatableOp(mlir::OpAsmPrinter &p, OP &op) { p.printOptionalAttrDict(op->getAttrs(), {"in_type", "operandSegmentSizes"}); } +bool fir::mayBeAbsentBox(mlir::Value val) { + assert(mlir::isa<fir::BaseBoxType>(val.getType()) && "expected box argument"); + while (val) { + mlir::Operation *defOp = val.getDefiningOp(); + if (!defOp) + return true; + + if (auto varIface = mlir::dyn_cast<fir::FortranVariableOpInterface>(defOp)) + return varIface.isOptional(); + + // Check for fir.embox and fir.rebox before checking for + // FortranObjectViewOpInterface, which they support. + // A box created by fir.embox/rebox cannot be absent. 
+ if (mlir::isa<fir::ReboxOp, fir::EmboxOp, fir::LoadOp>(defOp)) + return false; + + if (auto viewIface = + mlir::dyn_cast<fir::FortranObjectViewOpInterface>(defOp)) { + val = viewIface.getViewSource(mlir::cast<mlir::OpResult>(val)); + continue; + } + break; + } + return true; +} + //===----------------------------------------------------------------------===// // AllocaOp //===----------------------------------------------------------------------===// @@ -186,6 +212,36 @@ static mlir::Type wrapAllocaResultType(mlir::Type intype) { return fir::ReferenceType::get(intype); } +llvm::SmallVector<mlir::MemorySlot> fir::AllocaOp::getPromotableSlots() { + // TODO: support promotion of dynamic allocas + if (isDynamic()) + return {}; + + return {mlir::MemorySlot{getResult(), getAllocatedType()}}; +} + +mlir::Value fir::AllocaOp::getDefaultValue(const mlir::MemorySlot &slot, + mlir::OpBuilder &builder) { + return fir::UndefOp::create(builder, getLoc(), slot.elemType); +} + +void fir::AllocaOp::handleBlockArgument(const mlir::MemorySlot &slot, + mlir::BlockArgument argument, + mlir::OpBuilder &builder) {} + +std::optional<mlir::PromotableAllocationOpInterface> +fir::AllocaOp::handlePromotionComplete(const mlir::MemorySlot &slot, + mlir::Value defaultValue, + mlir::OpBuilder &builder) { + if (defaultValue && defaultValue.use_empty()) { + assert(mlir::isa<fir::UndefOp>(defaultValue.getDefiningOp()) && + "Expected undef op to be the default value"); + defaultValue.getDefiningOp()->erase(); + } + this->erase(); + return std::nullopt; +} + mlir::Type fir::AllocaOp::getAllocatedType() { return mlir::cast<fir::ReferenceType>(getType()).getEleTy(); } @@ -834,6 +890,11 @@ void fir::ArrayCoorOp::getCanonicalizationPatterns( patterns.add<SimplifyArrayCoorOp>(context); } +std::optional<std::int64_t> fir::ArrayCoorOp::getViewOffset(mlir::OpResult) { + // TODO: we can try to compute the constant offset. 
+ return std::nullopt; +} + //===----------------------------------------------------------------------===// // ArrayLoadOp //===----------------------------------------------------------------------===// @@ -1054,17 +1115,16 @@ void fir::BoxAddrOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value val) { mlir::Type type = llvm::TypeSwitch<mlir::Type, mlir::Type>(val.getType()) - .Case<fir::BaseBoxType>([&](fir::BaseBoxType ty) -> mlir::Type { + .Case([&](fir::BaseBoxType ty) -> mlir::Type { mlir::Type eleTy = ty.getEleTy(); if (fir::isa_ref_type(eleTy)) return eleTy; return fir::ReferenceType::get(eleTy); }) - .Case<fir::BoxCharType>([&](fir::BoxCharType ty) -> mlir::Type { + .Case([&](fir::BoxCharType ty) -> mlir::Type { return fir::ReferenceType::get(ty.getEleTy()); }) - .Case<fir::BoxProcType>( - [&](fir::BoxProcType ty) { return ty.getEleTy(); }) + .Case([&](fir::BoxProcType ty) { return ty.getEleTy(); }) .Default([&](const auto &) { return mlir::Type{}; }); assert(type && "bad val type"); build(builder, result, type, val); @@ -1086,6 +1146,22 @@ mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) { return {}; } +std::optional<std::int64_t> fir::BoxAddrOp::getViewOffset(mlir::OpResult) { + // fir.box_addr just returns the base address stored inside a box, + // so the direct accesses through the base address and through the box + // are not offsetted. + return 0; +} + +mlir::Speculation::Speculatability fir::BoxAddrOp::getSpeculatability() { + // Do not speculate fir.box_addr with BoxProcType and BoxCharType + // inputs. + if (!mlir::isa<fir::BaseBoxType>(getVal().getType())) + return mlir::Speculation::NotSpeculatable; + return mayBeAbsentBox(getVal()) ? 
mlir::Speculation::NotSpeculatable + : mlir::Speculation::Speculatable; +} + //===----------------------------------------------------------------------===// // BoxCharLenOp //===----------------------------------------------------------------------===// @@ -1110,6 +1186,11 @@ mlir::Type fir::BoxDimsOp::getTupleType() { return mlir::TupleType::get(getContext(), triple); } +mlir::Speculation::Speculatability fir::BoxDimsOp::getSpeculatability() { + return mayBeAbsentBox(getVal()) ? mlir::Speculation::NotSpeculatable + : mlir::Speculation::Speculatable; +} + //===----------------------------------------------------------------------===// // BoxRankOp //===----------------------------------------------------------------------===// @@ -1588,6 +1669,22 @@ llvm::LogicalResult fir::ConvertOp::verify() { << getValue().getType() << " / " << getType(); } +mlir::Speculation::Speculatability fir::ConvertOp::getSpeculatability() { + // fir.convert is speculatable, in general. The only concern may be + // converting from or/and to floating point types, which may trigger + // some FP exceptions. Disallow speculating such converts for the time being. + // Also disallow speculation for converts to/from non-FIR types, except + // for some builtin types. + auto canSpeculateType = [](mlir::Type ty) { + if (fir::isa_fir_type(ty) || fir::isa_integer(ty)) + return true; + return false; + }; + return (canSpeculateType(getValue().getType()) && canSpeculateType(getType())) + ? 
mlir::Speculation::Speculatable + : mlir::Speculation::NotSpeculatable; +} + //===----------------------------------------------------------------------===// // CoordinateOp //===----------------------------------------------------------------------===// @@ -1627,11 +1724,11 @@ void fir::CoordinateOp::build(mlir::OpBuilder &builder, bool anyField = false; for (fir::IntOrValue index : coor) { llvm::TypeSwitch<fir::IntOrValue>(index) - .Case<mlir::IntegerAttr>([&](mlir::IntegerAttr intAttr) { + .Case([&](mlir::IntegerAttr intAttr) { fieldIndices.push_back(intAttr.getInt()); anyField = true; }) - .Case<mlir::Value>([&](mlir::Value value) { + .Case([&](mlir::Value value) { dynamicIndices.push_back(value); fieldIndices.push_back(fir::CoordinateOp::kDynamicIndex); }); @@ -1654,7 +1751,7 @@ void fir::CoordinateOp::print(mlir::OpAsmPrinter &p) { for (auto index : getIndices()) { p << ", "; llvm::TypeSwitch<fir::IntOrValue>(index) - .Case<mlir::IntegerAttr>([&](mlir::IntegerAttr intAttr) { + .Case([&](mlir::IntegerAttr intAttr) { if (auto recordType = llvm::dyn_cast<fir::RecordType>(eleTy)) { int fieldId = intAttr.getInt(); if (fieldId < static_cast<int>(recordType.getNumFields())) { @@ -1669,7 +1766,7 @@ void fir::CoordinateOp::print(mlir::OpAsmPrinter &p) { // investigated. p << intAttr; }) - .Case<mlir::Value>([&](mlir::Value value) { p << value; }); + .Case([&](mlir::Value value) { p << value; }); } } p.printOptionalAttrDict( @@ -1820,6 +1917,20 @@ fir::CoordinateIndicesAdaptor fir::CoordinateOp::getIndices() { return CoordinateIndicesAdaptor(getFieldIndicesAttr(), getCoor()); } +std::optional<std::int64_t> fir::CoordinateOp::getViewOffset(mlir::OpResult) { + // TODO: we can try to compute the constant offset. + return std::nullopt; +} + +mlir::Speculation::Speculatability fir::CoordinateOp::getSpeculatability() { + const mlir::Type refTy = getRef().getType(); + if (fir::isa_ref_type(refTy)) + return mlir::Speculation::Speculatable; + + return mayBeAbsentBox(getRef()) ? 
mlir::Speculation::NotSpeculatable + : mlir::Speculation::Speculatable; +} + //===----------------------------------------------------------------------===// // DispatchOp //===----------------------------------------------------------------------===// @@ -2066,6 +2177,20 @@ bool fir::isContiguousEmbox(fir::EmboxOp embox, bool checkWhole) { return false; } +std::optional<std::int64_t> fir::EmboxOp::getViewOffset(mlir::OpResult) { + // The address offset is zero, unless there is a slice. + // TODO: we can handle slices that leave the base address untouched. + if (!getSlice()) + return 0; + return std::nullopt; +} + +mlir::Speculation::Speculatability fir::EmboxOp::getSpeculatability() { + return (getSourceBox() && mayBeAbsentBox(getSourceBox())) + ? mlir::Speculation::NotSpeculatable + : mlir::Speculation::Speculatable; +} + //===----------------------------------------------------------------------===// // EmboxCharOp //===----------------------------------------------------------------------===// @@ -2836,6 +2961,39 @@ llvm::SmallVector<mlir::Attribute> fir::LenParamIndexOp::getAttributes() { // LoadOp //===----------------------------------------------------------------------===// +bool fir::LoadOp::loadsFrom(const mlir::MemorySlot &slot) { + return getMemref() == slot.ptr; +} + +bool fir::LoadOp::storesTo(const mlir::MemorySlot &slot) { return false; } + +mlir::Value fir::LoadOp::getStored(const mlir::MemorySlot &slot, + mlir::OpBuilder &builder, + mlir::Value reachingDef, + const mlir::DataLayout &dataLayout) { + return mlir::Value(); +} + +bool fir::LoadOp::canUsesBeRemoved( + const mlir::MemorySlot &slot, + const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses, + mlir::SmallVectorImpl<mlir::OpOperand *> &newBlockingUses, + const mlir::DataLayout &dataLayout) { + if (blockingUses.size() != 1) + return false; + mlir::Value blockingUse = (*blockingUses.begin())->get(); + return blockingUse == slot.ptr && getMemref() == slot.ptr; +} + +mlir::DeletionKind 
fir::LoadOp::removeBlockingUses( + const mlir::MemorySlot &slot, + const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses, + mlir::OpBuilder &builder, mlir::Value reachingDefinition, + const mlir::DataLayout &dataLayout) { + getResult().replaceAllUsesWith(reachingDefinition); + return mlir::DeletionKind::Delete; +} + void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value refVal) { if (!refVal) { @@ -3205,11 +3363,19 @@ mlir::ParseResult fir::DTEntryOp::parse(mlir::OpAsmParser &parser, parser.parseAttribute(calleeAttr, fir::DTEntryOp::getProcAttrNameStr(), result.attributes)) return mlir::failure(); + + // Optional "deferred" keyword. + if (succeeded(parser.parseOptionalKeyword("deferred"))) { + result.addAttribute(fir::DTEntryOp::getDeferredAttrNameStr(), + parser.getBuilder().getUnitAttr()); + } return mlir::success(); } void fir::DTEntryOp::print(mlir::OpAsmPrinter &p) { p << ' ' << getMethodAttr() << ", " << getProcAttr(); + if ((*this)->getAttr(fir::DTEntryOp::getDeferredAttrNameStr())) + p << " deferred"; } //===----------------------------------------------------------------------===// @@ -3313,6 +3479,19 @@ llvm::LogicalResult fir::ReboxOp::verify() { return mlir::success(); } +std::optional<std::int64_t> fir::ReboxOp::getViewOffset(mlir::OpResult) { + // The address offset is zero, unless there is a slice. + // TODO: we can handle slices that leave the base address untouched. + if (!getSlice()) + return 0; + return std::nullopt; +} + +mlir::Speculation::Speculatability fir::ReboxOp::getSpeculatability() { + return mayBeAbsentBox(getBox()) ? 
mlir::Speculation::NotSpeculatable + : mlir::Speculation::Speculatable; +} + //===----------------------------------------------------------------------===// // ReboxAssumedRankOp //===----------------------------------------------------------------------===// @@ -4215,6 +4394,39 @@ llvm::LogicalResult fir::SliceOp::verify() { // StoreOp //===----------------------------------------------------------------------===// +bool fir::StoreOp::loadsFrom(const mlir::MemorySlot &slot) { return false; } + +bool fir::StoreOp::storesTo(const mlir::MemorySlot &slot) { + return getMemref() == slot.ptr; +} + +mlir::Value fir::StoreOp::getStored(const mlir::MemorySlot &slot, + mlir::OpBuilder &builder, + mlir::Value reachingDef, + const mlir::DataLayout &dataLayout) { + return getValue(); +} + +bool fir::StoreOp::canUsesBeRemoved( + const mlir::MemorySlot &slot, + const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses, + mlir::SmallVectorImpl<mlir::OpOperand *> &newBlockingUses, + const mlir::DataLayout &dataLayout) { + if (blockingUses.size() != 1) + return false; + mlir::Value blockingUse = (*blockingUses.begin())->get(); + return blockingUse == slot.ptr && getMemref() == slot.ptr && + getValue() != slot.ptr; +} + +mlir::DeletionKind fir::StoreOp::removeBlockingUses( + const mlir::MemorySlot &slot, + const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses, + mlir::OpBuilder &builder, mlir::Value reachingDefinition, + const mlir::DataLayout &dataLayout) { + return mlir::DeletionKind::Delete; +} + mlir::Type fir::StoreOp::elementType(mlir::Type refType) { return fir::dyn_cast_ptrEleTy(refType); } @@ -4252,7 +4464,7 @@ llvm::LogicalResult fir::StoreOp::verify() { void fir::StoreOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value value, mlir::Value memref) { - build(builder, result, value, memref, {}); + build(builder, result, value, memref, {}, {}, {}); } void fir::StoreOp::getEffects( @@ -4265,6 +4477,84 @@ void fir::StoreOp::getEffects( } 
//===----------------------------------------------------------------------===// +// PrefetchOp +//===----------------------------------------------------------------------===// + +mlir::ParseResult fir::PrefetchOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::OpAsmParser::UnresolvedOperand memref; + if (parser.parseOperand(memref)) + return mlir::failure(); + + if (mlir::succeeded(parser.parseLBrace())) { + llvm::StringRef kw; + if (parser.parseKeyword(&kw)) + return mlir::failure(); + + if (kw == "read") + result.addAttribute("rw", parser.getBuilder().getBoolAttr(false)); + else if (kw == "write") + result.addAttribute("rw", parser.getBuilder().getUnitAttr()); + else + return parser.emitError(parser.getCurrentLocation(), + "Expected either read or write keyword"); + + if (parser.parseComma()) + return mlir::failure(); + + if (parser.parseKeyword(&kw)) + return mlir::failure(); + if (kw == "instruction") { + result.addAttribute("cacheType", parser.getBuilder().getBoolAttr(false)); + } else if (kw == "data") { + result.addAttribute("cacheType", parser.getBuilder().getUnitAttr()); + } else + return parser.emitError(parser.getCurrentLocation(), + "Expected either intruction or data keyword"); + + if (parser.parseComma()) + return mlir::failure(); + + if (mlir::succeeded(parser.parseKeyword("localityHint"))) { + if (parser.parseEqual()) + return mlir::failure(); + mlir::Attribute intAttr; + if (parser.parseAttribute(intAttr)) + return mlir::failure(); + result.addAttribute("localityHint", intAttr); + } + if (parser.parseRBrace()) + return mlir::failure(); + } + mlir::Type type; + if (parser.parseColonType(type)) + return mlir::failure(); + + if (parser.resolveOperand(memref, type, result.operands)) + return mlir::failure(); + return mlir::success(); +} + +void fir::PrefetchOp::print(mlir::OpAsmPrinter &p) { + p << " "; + p.printOperand(getMemref()); + p << " {"; + if (getRw()) + p << "write"; + else + p << "read"; + p << ", "; + if 
(getCacheType()) + p << "data"; + else + p << "instruction"; + p << ", localityHint = "; + p << getLocalityHint(); + p << " : " << getLocalityHintAttr().getType(); + p << "} : " << getMemref().getType(); +} + +//===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// @@ -4484,7 +4774,7 @@ void fir::IfOp::getSuccessorRegions( llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) { // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { - regions.push_back(mlir::RegionSuccessor(getResults())); + regions.push_back(mlir::RegionSuccessor::parent()); return; } @@ -4494,11 +4784,18 @@ void fir::IfOp::getSuccessorRegions( // Don't consider the else region if it is empty. mlir::Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(mlir::RegionSuccessor()); + regions.push_back(mlir::RegionSuccessor::parent()); else regions.push_back(mlir::RegionSuccessor(elseRegion)); } +mlir::ValueRange +fir::IfOp::getSuccessorInputs(mlir::RegionSuccessor successor) { + if (successor.isParent()) + return getOperation()->getResults(); + return mlir::ValueRange(); +} + void fir::IfOp::getEntrySuccessorRegions( llvm::ArrayRef<mlir::Attribute> operands, llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) { @@ -4513,7 +4810,7 @@ void fir::IfOp::getEntrySuccessorRegions( if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(getResults()); + regions.push_back(mlir::RegionSuccessor::parent()); } } @@ -4887,7 +5184,7 @@ bool fir::isDummyArgument(mlir::Value v) { mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) { for (auto i = path.begin(), end = path.end(); eleTy && i < end;) { eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy) - .Case<fir::RecordType>([&](fir::RecordType ty) { + .Case([&](fir::RecordType ty) { if (auto *op = 
(*i++).getDefiningOp()) { if (auto off = mlir::dyn_cast<fir::FieldIndexOp>(op)) return ty.getType(off.getFieldName()); @@ -4896,7 +5193,7 @@ mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) { } return mlir::Type{}; }) - .Case<fir::SequenceType>([&](fir::SequenceType ty) { + .Case([&](fir::SequenceType ty) { bool valid = true; const auto rank = ty.getDimension(); for (std::remove_const_t<decltype(rank)> ii = 0; @@ -4904,13 +5201,13 @@ mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) { valid = i < end && fir::isa_integer((*i++).getType()); return valid ? ty.getEleTy() : mlir::Type{}; }) - .Case<mlir::TupleType>([&](mlir::TupleType ty) { + .Case([&](mlir::TupleType ty) { if (auto *op = (*i++).getDefiningOp()) if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op)) return ty.getType(fir::toInt(off)); return mlir::Type{}; }) - .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { + .Case([&](mlir::ComplexType ty) { if (fir::isa_integer((*i++).getType())) return ty.getElementType(); return mlir::Type{}; @@ -5143,6 +5440,34 @@ void fir::BoxTotalElementsOp::getCanonicalizationPatterns( } //===----------------------------------------------------------------------===// +// IsAssumedSizeExtentOp and AssumedSizeExtentOp +//===----------------------------------------------------------------------===// + +namespace { +struct FoldIsAssumedSizeExtentOnCtor + : public mlir::OpRewritePattern<fir::IsAssumedSizeExtentOp> { + using mlir::OpRewritePattern<fir::IsAssumedSizeExtentOp>::OpRewritePattern; + mlir::LogicalResult + matchAndRewrite(fir::IsAssumedSizeExtentOp op, + mlir::PatternRewriter &rewriter) const override { + if (llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>( + op.getVal().getDefiningOp())) { + mlir::Type i1 = rewriter.getI1Type(); + rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( + op, i1, rewriter.getIntegerAttr(i1, 1)); + return mlir::success(); + } + return mlir::failure(); + } +}; +} // namespace + +void 
fir::IsAssumedSizeExtentOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + patterns.add<FoldIsAssumedSizeExtentOnCtor>(context); +} + +//===----------------------------------------------------------------------===// // LocalitySpecifierOp //===----------------------------------------------------------------------===// diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index fe35b08..ccdc8e4 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -183,18 +183,21 @@ struct RecordTypeStorage : public mlir::TypeStorage { bool isPacked() const { return packed; } void pack(bool p) { packed = p; } + bool isSequence() const { return sequence; } + void setSequence(bool s) { sequence = s; } protected: std::string name; bool finalized; bool packed; + bool sequence; std::vector<RecordType::TypePair> lens; std::vector<RecordType::TypePair> types; private: RecordTypeStorage() = delete; explicit RecordTypeStorage(llvm::StringRef name) - : name{name}, finalized{false}, packed{false} {} + : name{name}, finalized{false}, packed{false}, sequence{false} {} }; } // namespace detail @@ -226,8 +229,7 @@ mlir::Type getDerivedType(mlir::Type ty) { return seq.getEleTy(); return p.getEleTy(); }) - .Case<fir::BaseBoxType>( - [](auto p) { return getDerivedType(p.getEleTy()); }) + .Case([](fir::BaseBoxType p) { return getDerivedType(p.getEleTy()); }) .Default([](mlir::Type t) { return t; }); } @@ -423,7 +425,7 @@ mlir::Type unwrapInnerType(mlir::Type ty) { return seqTy.getEleTy(); return eleTy; }) - .Case<fir::RecordType>([](auto t) { return t; }) + .Case([](fir::RecordType t) { return t; }) .Default([](mlir::Type) { return mlir::Type{}; }); } @@ -685,7 +687,7 @@ std::string getTypeAsString(mlir::Type ty, const fir::KindMapping &kindMap, mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType, bool turnBoxIntoClass) { return 
llvm::TypeSwitch<mlir::Type, mlir::Type>(type) - .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type { + .Case([&](fir::SequenceType seqTy) -> mlir::Type { return fir::SequenceType::get(seqTy.getShape(), newElementType); }) .Case<fir::ReferenceType, fir::ClassType>([&](auto t) -> mlir::Type { @@ -699,7 +701,7 @@ mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType, return FIRT::get( changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass)); }) - .Case<fir::BoxType>([&](fir::BoxType t) -> mlir::Type { + .Case([&](fir::BoxType t) -> mlir::Type { mlir::Type newInnerType = changeElementType(t.getEleTy(), newElementType, false); if (turnBoxIntoClass) @@ -1014,6 +1016,14 @@ mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) { if (parser.parseLess() || parser.parseKeyword(&name)) return {}; RecordType result = RecordType::get(parser.getContext(), name); + // Optional SEQUENCE attribute: ", sequence" + if (!parser.parseOptionalComma()) { + if (parser.parseKeyword("sequence")) { + parser.emitError(parser.getNameLoc(), "expected 'sequence' keyword"); + return {}; + } + result.setSequence(true); + } RecordType::TypeVector lenParamList; if (!parser.parseOptionalLParen()) { @@ -1069,6 +1079,8 @@ mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) { void fir::RecordType::print(mlir::AsmPrinter &printer) const { printer << "<" << getName(); + if (isSequence()) + printer << ",sequence"; if (!recordTypeVisited.count(uniqueKey())) { recordTypeVisited.insert(uniqueKey()); if (getLenParamList().size()) { @@ -1123,6 +1135,10 @@ void fir::RecordType::pack(bool p) { getImpl()->pack(p); } bool fir::RecordType::isPacked() const { return getImpl()->isPacked(); } +bool fir::RecordType::isSequence() const { return getImpl()->isSequence(); } + +void fir::RecordType::setSequence(bool s) { getImpl()->setSequence(s); } + detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const { return getImpl(); } @@ -1438,7 +1454,7 @@ static 
mlir::Type changeTypeShape(mlir::Type type, std::optional<fir::SequenceType::ShapeRef> newShape) { return llvm::TypeSwitch<mlir::Type, mlir::Type>(type) - .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type { + .Case([&](fir::SequenceType seqTy) -> mlir::Type { if (newShape) return fir::SequenceType::get(*newShape, seqTy.getEleTy()); return seqTy.getEleTy(); @@ -1498,10 +1514,10 @@ fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewAttr( break; } return llvm::TypeSwitch<fir::BaseBoxType, fir::BaseBoxType>(*this) - .Case<fir::BoxType>([baseType](auto b) { + .Case([baseType](fir::BoxType b) { return fir::BoxType::get(baseType, b.isVolatile()); }) - .Case<fir::ClassType>([baseType](auto b) { + .Case([baseType](fir::ClassType b) { return fir::ClassType::get(baseType, b.isVolatile()); }); } diff --git a/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt b/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt index d52ab09..d53937eb 100644 --- a/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt +++ b/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt @@ -3,18 +3,21 @@ add_flang_library(MIFDialect MIFOps.cpp DEPENDS - MLIRIR MIFOpsIncGen LINK_LIBS FIRDialect FIRDialectSupport - FIRSupport - MLIRIR - MLIRTargetLLVMIRExport LINK_COMPONENTS AsmParser AsmPrinter Remarks + + MLIR_DEPS + MLIRIR + + MLIR_LIBS + MLIRIR + MLIRTargetLLVMIRExport ) diff --git a/flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp b/flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp index c6cc2e8..8b04226 100644 --- a/flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp +++ b/flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp @@ -15,9 +15,6 @@ #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/SmallVector.h" -#define GET_OP_CLASSES -#include "flang/Optimizer/Dialect/MIF/MIFOps.cpp.inc" - //===----------------------------------------------------------------------===// // NumImagesOp //===----------------------------------------------------------------------===// @@ -151,3 +148,59 @@ llvm::LogicalResult mif::CoSumOp::verify() { return 
emitOpError("`A` shall be of numeric type."); return mlir::success(); } + +//===----------------------------------------------------------------------===// +// ChangeTeamOp +//===----------------------------------------------------------------------===// + +void mif::ChangeTeamOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, mlir::Value team, + llvm::ArrayRef<mlir::NamedAttribute> attributes) { + build(builder, result, team, /*stat*/ mlir::Value{}, /*errmsg*/ mlir::Value{}, + attributes); +} + +void mif::ChangeTeamOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, mlir::Value team, + mlir::Value stat, mlir::Value errmsg, + llvm::ArrayRef<mlir::NamedAttribute> attributes) { + std::int32_t argStat = 0, argErrmsg = 0; + result.addOperands(team); + if (stat) { + result.addOperands(stat); + argStat++; + } + if (errmsg) { + result.addOperands(errmsg); + argErrmsg++; + } + + mlir::Region *bodyRegion = result.addRegion(); + bodyRegion->push_back(new mlir::Block{}); + + result.addAttribute(getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({1, argStat, argErrmsg})); + result.addAttributes(attributes); +} + +static mlir::ParseResult parseChangeTeamOpBody(mlir::OpAsmParser &parser, + mlir::Region &body) { + if (parser.parseRegion(body)) + return mlir::failure(); + + mlir::Operation *terminator = body.back().getTerminator(); + if (!terminator || !mlir::isa<mif::EndTeamOp>(terminator)) + return parser.emitError(parser.getNameLoc(), + "missing mif.end_team terminator"); + + return mlir::success(); +} + +static void printChangeTeamOpBody(mlir::OpAsmPrinter &p, mif::ChangeTeamOp op, + mlir::Region &body) { + p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/true, + /*printBlockTerminators=*/true); +} + +#define GET_OP_CLASSES +#include "flang/Optimizer/Dialect/MIF/MIFOps.cpp.inc" diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp index 1b1abef..e0fee2f 100644 --- 
a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp @@ -87,8 +87,8 @@ bool hlfir::isFortranVariableType(mlir::Type type) { return mlir::isa<fir::BaseBoxType>(eleType) || !fir::hasDynamicSize(eleType); }) - .Case<fir::BaseBoxType, fir::BoxCharType>([](auto) { return true; }) - .Case<fir::VectorType>([](auto) { return true; }) + .Case<fir::BaseBoxType, fir::BoxCharType>([](mlir::Type) { return true; }) + .Case([](fir::VectorType) { return true; }) .Default([](mlir::Type) { return false; }); } diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index 1332dc5..e42c064 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -261,14 +261,12 @@ updateDeclaredInputTypeWithVolatility(mlir::Type inputType, mlir::Value memref, return std::make_pair(inputType, memref); } -void hlfir::DeclareOp::build(mlir::OpBuilder &builder, - mlir::OperationState &result, mlir::Value memref, - llvm::StringRef uniq_name, mlir::Value shape, - mlir::ValueRange typeparams, - mlir::Value dummy_scope, mlir::Value storage, - std::uint64_t storage_offset, - fir::FortranVariableFlagsAttr fortran_attrs, - cuf::DataAttributeAttr data_attr) { +void hlfir::DeclareOp::build( + mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value memref, + llvm::StringRef uniq_name, mlir::Value shape, mlir::ValueRange typeparams, + mlir::Value dummy_scope, mlir::Value storage, std::uint64_t storage_offset, + fir::FortranVariableFlagsAttr fortran_attrs, + cuf::DataAttributeAttr data_attr, unsigned dummy_arg_no) { auto nameAttr = builder.getStringAttr(uniq_name); mlir::Type inputType = memref.getType(); bool hasExplicitLbs = hasExplicitLowerBounds(shape); @@ -279,9 +277,12 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder, } auto [hlfirVariableType, firVarType] = getDeclareOutputTypes(inputType, hasExplicitLbs); + mlir::IntegerAttr argNoAttr; + if (dummy_arg_no 
> 0) + argNoAttr = builder.getUI32IntegerAttr(dummy_arg_no); build(builder, result, {hlfirVariableType, firVarType}, memref, shape, typeparams, dummy_scope, storage, storage_offset, nameAttr, - fortran_attrs, data_attr, /*skip_rebox=*/mlir::UnitAttr{}); + fortran_attrs, data_attr, /*skip_rebox=*/mlir::UnitAttr{}, argNoAttr); } llvm::LogicalResult hlfir::DeclareOp::verify() { @@ -591,6 +592,12 @@ llvm::LogicalResult hlfir::DesignateOp::verify() { return mlir::success(); } +std::optional<std::int64_t> hlfir::DesignateOp::getViewOffset(mlir::OpResult) { + // TODO: we can compute the constant offset + // based on the component/indices/etc. + return std::nullopt; +} + //===----------------------------------------------------------------------===// // ParentComponentOp //===----------------------------------------------------------------------===// diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp index 6a57bf2..13d9fc2 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -149,13 +149,18 @@ public: !assignOp.isTemporaryLHS() && mlir::isa<fir::RecordType>(fir::getElementTypeOf(lhsExv)); + mlir::ArrayAttr accessGroups; + if (auto attrs = assignOp.getOperation()->getAttrOfType<mlir::ArrayAttr>( + fir::getAccessGroupsAttrName())) + accessGroups = attrs; + // genScalarAssignment() must take care of potential overlap // between LHS and RHS. Note that the overlap is possible // also for components of LHS/RHS, and the Assign() runtime // must take care of it. 
- fir::factory::genScalarAssignment(builder, loc, lhsExv, rhsExv, - needFinalization, - assignOp.isTemporaryLHS()); + fir::factory::genScalarAssignment( + builder, loc, lhsExv, rhsExv, needFinalization, + assignOp.isTemporaryLHS(), accessGroups); } rewriter.eraseOp(assignOp); return mlir::success(); @@ -308,7 +313,8 @@ public: declareOp.getTypeparams(), declareOp.getDummyScope(), /*storage=*/declareOp.getStorage(), /*storage_offset=*/declareOp.getStorageOffset(), - declareOp.getUniqName(), fortranAttrs, dataAttr); + declareOp.getUniqName(), fortranAttrs, dataAttr, + declareOp.getDummyArgNoAttr()); // Propagate other attributes from hlfir.declare to fir.declare. // OpenACC's acc.declare is one example. Right now, the propagation @@ -467,7 +473,7 @@ public: if (designate.getComponent()) { mlir::Type baseRecordType = baseEntity.getFortranElementType(); if (fir::isRecordWithTypeParameters(baseRecordType)) - TODO(loc, "hlfir.designate with a parametrized derived type base"); + TODO(loc, "hlfir.designate with a parameterized derived type base"); fieldIndex = fir::FieldIndexOp::create( builder, loc, fir::FieldType::get(builder.getContext()), designate.getComponent().value(), baseRecordType, @@ -493,7 +499,7 @@ public: return mlir::success(); } TODO(loc, - "addressing parametrized derived type automatic components"); + "addressing parameterized derived type automatic components"); } baseEleTy = hlfir::getFortranElementType(componentType); shape = designate.getComponentShape(); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp index 86d3974..356552f 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "flang/Optimizer/Analysis/AliasAnalysis.h" +#include "flang/Optimizer/Analysis/ArraySectionAnalyzer.h" 
#include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" @@ -93,40 +94,32 @@ public: // and proceed with the inlining. fir::AliasAnalysis aliasAnalysis; mlir::AliasResult aliasRes = aliasAnalysis.alias(lhs, rhs); - // TODO: use areIdenticalOrDisjointSlices() from - // OptimizedBufferization.cpp to check if we can still do the expansion. if (!aliasRes.isNo()) { - LLVM_DEBUG(llvm::dbgs() << "InlineHLFIRAssign:\n" - << "\tLHS: " << lhs << "\n" - << "\tRHS: " << rhs << "\n" - << "\tALIAS: " << aliasRes << "\n"); - return rewriter.notifyMatchFailure(assign, "RHS/LHS may alias"); + // Alias analysis reports potential aliasing, but we can use + // ArraySectionAnalyzer to check if the slices are disjoint + // or identical (which is safe for element-wise assignment). + fir::ArraySectionAnalyzer::SlicesOverlapKind overlap = + fir::ArraySectionAnalyzer::analyze(lhs, rhs); + if (overlap == fir::ArraySectionAnalyzer::SlicesOverlapKind::Unknown) { + LLVM_DEBUG(llvm::dbgs() << "InlineHLFIRAssign:\n" + << "\tLHS: " << lhs << "\n" + << "\tRHS: " << rhs << "\n" + << "\tALIAS: " << aliasRes << "\n"); + return rewriter.notifyMatchFailure(assign, "RHS/LHS may alias"); + } } } mlir::Location loc = assign->getLoc(); fir::FirOpBuilder builder(rewriter, assign.getOperation()); builder.setInsertionPoint(assign); - rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs); - lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs); - mlir::Value lhsShape = hlfir::genShape(loc, builder, lhs); - llvm::SmallVector<mlir::Value> lhsExtents = - hlfir::getIndexExtents(loc, builder, lhsShape); - mlir::Value rhsShape = hlfir::genShape(loc, builder, rhs); - llvm::SmallVector<mlir::Value> rhsExtents = - hlfir::getIndexExtents(loc, builder, rhsShape); - llvm::SmallVector<mlir::Value> extents = - fir::factory::deduceOptimalExtents(lhsExtents, rhsExtents); - hlfir::LoopNest loopNest = - hlfir::genLoopNest(loc, 
builder, extents, /*isUnordered=*/true, - flangomp::shouldUseWorkshareLowering(assign)); - builder.setInsertionPointToStart(loopNest.body); - auto rhsArrayElement = - hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices); - rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement); - auto lhsArrayElement = - hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); - hlfir::AssignOp::create(builder, loc, rhsArrayElement, lhsArrayElement); + mlir::ArrayAttr accessGroups; + if (auto attrs = assign.getOperation()->getAttrOfType<mlir::ArrayAttr>( + fir::getAccessGroupsAttrName())) + accessGroups = attrs; + hlfir::genNoAliasArrayAssignment( + loc, builder, rhs, lhs, flangomp::shouldUseWorkshareLowering(assign), + /*temporaryLHS=*/false, /*combiner=*/nullptr, accessGroups); rewriter.eraseOp(assign); return mlir::success(); } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp index 32998ab..a3fd19d 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp @@ -96,7 +96,7 @@ struct MaskedArrayExpr { /// hlfir.elemental_addr that form the elemental tree producing /// the expression value. hlfir.elemental that produce values /// used inside transformational operations are not part of this set. - llvm::SmallPtrSet<mlir::Operation *, 4> elementalParts{}; + hlfir::ElementalTree elementalParts; /// Was generateNoneElementalPart called? bool noneElementalPartWasGenerated = false; /// Is this expression the mask expression of the outer where statement? @@ -517,7 +517,10 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) { } else { // TODO: preserve allocatable assignment aspects for forall once // they are conveyed in hlfir.region_assign. 
- hlfir::AssignOp::create(builder, loc, rhsEntity, lhsEntity); + auto assignOp = hlfir::AssignOp::create(builder, loc, rhsEntity, lhsEntity); + if (auto accessGroups = regionAssignOp->getAttrOfType<mlir::ArrayAttr>( + fir::getAccessGroupsAttrName())) + assignOp->setAttr(fir::getAccessGroupsAttrName(), accessGroups); } generateCleanupIfAny(loweredLhs.elementalCleanup); if (loweredLhs.vectorSubscriptLoopNest) @@ -897,62 +900,11 @@ bool OrderedAssignmentRewriter::isRequiredInCurrentRun( return false; } -/// Is the apply using all the elemental indices in order? -static bool isInOrderApply(hlfir::ApplyOp apply, - hlfir::ElementalOpInterface elemental) { - mlir::Region::BlockArgListType elementalIndices = elemental.getIndices(); - if (elementalIndices.size() != apply.getIndices().size()) - return false; - for (auto [elementalIdx, applyIdx] : - llvm::zip(elementalIndices, apply.getIndices())) - if (elementalIdx != applyIdx) - return false; - return true; -} - -/// Gather the tree of hlfir::ElementalOpInterface use-def, if any, starting -/// from \p elemental, which may be a nullptr. -static void -gatherElementalTree(hlfir::ElementalOpInterface elemental, - llvm::SmallPtrSetImpl<mlir::Operation *> &elementalOps, - bool isOutOfOrder) { - if (elemental) { - // Only inline an applied elemental that must be executed in order if the - // applying indices are in order. An hlfir::Elemental may have been created - // for a transformational like transpose, and Fortran 2018 standard - // section 10.2.3.2, point 10 imply that impure elemental sub-expression - // evaluations should not be masked if they are the arguments of - // transformational expressions. 
- if (isOutOfOrder && elemental.isOrdered()) - return; - elementalOps.insert(elemental.getOperation()); - for (mlir::Operation &op : elemental.getElementalRegion().getOps()) - if (auto apply = mlir::dyn_cast<hlfir::ApplyOp>(op)) { - bool isUnorderedApply = - isOutOfOrder || !isInOrderApply(apply, elemental); - auto maybeElemental = - mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>( - apply.getExpr().getDefiningOp()); - gatherElementalTree(maybeElemental, elementalOps, isUnorderedApply); - } - } -} - MaskedArrayExpr::MaskedArrayExpr(mlir::Location loc, mlir::Region ®ion, bool isOuterMaskExpr) : loc{loc}, region{region}, isOuterMaskExpr{isOuterMaskExpr} { mlir::Operation &terminator = region.back().back(); - if (auto elementalAddr = - mlir::dyn_cast<hlfir::ElementalOpInterface>(terminator)) { - // Vector subscripted designator (hlfir.elemental_addr terminator). - gatherElementalTree(elementalAddr, elementalParts, /*isOutOfOrder=*/false); - return; - } - // Try if elemental expression. 
- mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity(); - auto maybeElemental = mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>( - entity.getDefiningOp()); - gatherElementalTree(maybeElemental, elementalParts, /*isOutOfOrder=*/false); + elementalParts = hlfir::ElementalTree::buildElementalTree(terminator); } void MaskedArrayExpr::generateNoneElementalPart(fir::FirOpBuilder &builder, diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp index 2712bfb..5889122 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "flang/Optimizer/Analysis/AliasAnalysis.h" +#include "flang/Optimizer/Analysis/ArraySectionAnalyzer.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Dialect/FIROps.h" @@ -88,13 +89,6 @@ private: /// determines if the transformation can be applied to this elemental static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental); - /// Returns the array indices for the given hlfir.designate. - /// It recognizes the computations used to transform the one-based indices - /// into the array's lb-based indices, and returns the one-based indices - /// in these cases. - static llvm::SmallVector<mlir::Value> - getDesignatorIndices(hlfir::DesignateOp designate); - public: using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern; @@ -167,344 +161,6 @@ containsReadOrWriteEffectOn(const mlir::MemoryEffects::EffectInstance &effect, return mlir::AliasResult::NoAlias; } -// Helper class for analyzing two array slices represented -// by two hlfir.designate operations. -class ArraySectionAnalyzer { -public: - // The result of the analyzis is one of the values below. 
- enum class SlicesOverlapKind { - // Slices overlap is unknown. - Unknown, - // Slices are definitely identical. - DefinitelyIdentical, - // Slices are definitely disjoint. - DefinitelyDisjoint, - // Slices may be either disjoint or identical, - // i.e. there is definitely no partial overlap. - EitherIdenticalOrDisjoint - }; - - // Analyzes two hlfir.designate results and returns the overlap kind. - // The callers may use this method when the alias analysis reports - // an alias of some kind, so that we can run Fortran specific analysis - // on the array slices to see if they are identical or disjoint. - // Note that the alias analysis are not able to give such an answer - // about the references. - static SlicesOverlapKind analyze(mlir::Value ref1, mlir::Value ref2); - -private: - struct SectionDesc { - // An array section is described by <lb, ub, stride> tuple. - // If the designator's subscript is not a triple, then - // the section descriptor is constructed as <lb, nullptr, nullptr>. - mlir::Value lb, ub, stride; - - SectionDesc(mlir::Value lb, mlir::Value ub, mlir::Value stride) - : lb(lb), ub(ub), stride(stride) { - assert(lb && "lower bound or index must be specified"); - normalize(); - } - - // Normalize the section descriptor: - // 1. If UB is nullptr, then it is set to LB. - // 2. If LB==UB, then stride does not matter, - // so it is reset to nullptr. - // 3. If STRIDE==1, then it is reset to nullptr. - void normalize() { - if (!ub) - ub = lb; - if (lb == ub) - stride = nullptr; - if (stride) - if (auto val = fir::getIntIfConstant(stride)) - if (*val == 1) - stride = nullptr; - } - - bool operator==(const SectionDesc &other) const { - return lb == other.lb && ub == other.ub && stride == other.stride; - } - }; - - // Given an operand_iterator over the indices operands, - // read the subscript values and return them as SectionDesc - // updating the iterator. If isTriplet is true, - // the subscript is a triplet, and the result is <lb, ub, stride>. 
- // Otherwise, the subscript is a scalar index, and the result - // is <index, nullptr, nullptr>. - static SectionDesc readSectionDesc(mlir::Operation::operand_iterator &it, - bool isTriplet) { - if (isTriplet) - return {*it++, *it++, *it++}; - return {*it++, nullptr, nullptr}; - } - - // Return the ordered lower and upper bounds of the section. - // If stride is known to be non-negative, then the ordered - // bounds match the <lb, ub> of the descriptor. - // If stride is known to be negative, then the ordered - // bounds are <ub, lb> of the descriptor. - // If stride is unknown, we cannot deduce any order, - // so the result is <nullptr, nullptr> - static std::pair<mlir::Value, mlir::Value> - getOrderedBounds(const SectionDesc &desc) { - mlir::Value stride = desc.stride; - // Null stride means stride=1. - if (!stride) - return {desc.lb, desc.ub}; - // Reverse the bounds, if stride is negative. - if (auto val = fir::getIntIfConstant(stride)) { - if (*val >= 0) - return {desc.lb, desc.ub}; - else - return {desc.ub, desc.lb}; - } - - return {nullptr, nullptr}; - } - - // Given two array sections <lb1, ub1, stride1> and - // <lb2, ub2, stride2>, return true only if the sections - // are known to be disjoint. - // - // For example, for any positive constant C: - // X:Y does not overlap with (Y+C):Z - // X:Y does not overlap with Z:(X-C) - static bool areDisjointSections(const SectionDesc &desc1, - const SectionDesc &desc2) { - auto [lb1, ub1] = getOrderedBounds(desc1); - auto [lb2, ub2] = getOrderedBounds(desc2); - if (!lb1 || !lb2) - return false; - // Note that this comparison must be made on the ordered bounds, - // otherwise 'a(x:y:1) = a(z:x-1:-1) + 1' may be incorrectly treated - // as not overlapping (x=2, y=10, z=9). - if (isLess(ub1, lb2) || isLess(ub2, lb1)) - return true; - return false; - } - - // Given two array sections <lb1, ub1, stride1> and - // <lb2, ub2, stride2>, return true only if the sections - // are known to be identical. 
- // - // For example: - // <x, x, stride> - // <x, nullptr, nullptr> - // - // These sections are identical, from the point of which array - // elements are being addresses, even though the shape - // of the array slices might be different. - static bool areIdenticalSections(const SectionDesc &desc1, - const SectionDesc &desc2) { - if (desc1 == desc2) - return true; - return false; - } - - // Return true, if v1 is known to be less than v2. - static bool isLess(mlir::Value v1, mlir::Value v2); -}; - -ArraySectionAnalyzer::SlicesOverlapKind -ArraySectionAnalyzer::analyze(mlir::Value ref1, mlir::Value ref2) { - if (ref1 == ref2) - return SlicesOverlapKind::DefinitelyIdentical; - - auto des1 = ref1.getDefiningOp<hlfir::DesignateOp>(); - auto des2 = ref2.getDefiningOp<hlfir::DesignateOp>(); - // We only support a pair of designators right now. - if (!des1 || !des2) - return SlicesOverlapKind::Unknown; - - if (des1.getMemref() != des2.getMemref()) { - // If the bases are different, then there is unknown overlap. - LLVM_DEBUG(llvm::dbgs() << "No identical base for:\n" - << des1 << "and:\n" - << des2 << "\n"); - return SlicesOverlapKind::Unknown; - } - - // Require all components of the designators to be the same. - // It might be too strict, e.g. we may probably allow for - // different type parameters. - if (des1.getComponent() != des2.getComponent() || - des1.getComponentShape() != des2.getComponentShape() || - des1.getSubstring() != des2.getSubstring() || - des1.getComplexPart() != des2.getComplexPart() || - des1.getTypeparams() != des2.getTypeparams()) { - LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n" - << des1 << "and:\n" - << des2 << "\n"); - return SlicesOverlapKind::Unknown; - } - - // Analyze the subscripts. 
- auto des1It = des1.getIndices().begin(); - auto des2It = des2.getIndices().begin(); - bool identicalTriplets = true; - bool identicalIndices = true; - for (auto [isTriplet1, isTriplet2] : - llvm::zip(des1.getIsTriplet(), des2.getIsTriplet())) { - SectionDesc desc1 = readSectionDesc(des1It, isTriplet1); - SectionDesc desc2 = readSectionDesc(des2It, isTriplet2); - - // See if we can prove that any of the sections do not overlap. - // This is mostly a Polyhedron/nf performance hack that looks for - // particular relations between the lower and upper bounds - // of the array sections, e.g. for any positive constant C: - // X:Y does not overlap with (Y+C):Z - // X:Y does not overlap with Z:(X-C) - if (areDisjointSections(desc1, desc2)) - return SlicesOverlapKind::DefinitelyDisjoint; - - if (!areIdenticalSections(desc1, desc2)) { - if (isTriplet1 || isTriplet2) { - // For example: - // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0) - // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1) - // - // If all the triplets (section speficiers) are the same, then - // we do not care if %0 is equal to %1 - the slices are either - // identical or completely disjoint. 
- // - // Also, treat these as identical sections: - // hlfir.designate %6#0 (%c2:%c2:%c1) - // hlfir.designate %6#0 (%c2) - identicalTriplets = false; - LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n" - << des1 << "and:\n" - << des2 << "\n"); - } else { - identicalIndices = false; - LLVM_DEBUG(llvm::dbgs() << "Indices mismatch for:\n" - << des1 << "and:\n" - << des2 << "\n"); - } - } - } - - if (identicalTriplets) { - if (identicalIndices) - return SlicesOverlapKind::DefinitelyIdentical; - else - return SlicesOverlapKind::EitherIdenticalOrDisjoint; - } - - LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n" - << des1 << "and:\n" - << des2 << "\n"); - return SlicesOverlapKind::Unknown; -} - -bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) { - auto removeConvert = [](mlir::Value v) -> mlir::Operation * { - auto *op = v.getDefiningOp(); - while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op)) - op = conv.getValue().getDefiningOp(); - return op; - }; - - auto isPositiveConstant = [](mlir::Value v) -> bool { - if (auto val = fir::getIntIfConstant(v)) - return *val > 0; - return false; - }; - - auto *op1 = removeConvert(v1); - auto *op2 = removeConvert(v2); - if (!op1 || !op2) - return false; - - // Check if they are both constants. 
- if (auto val1 = fir::getIntIfConstant(op1->getResult(0))) - if (auto val2 = fir::getIntIfConstant(op2->getResult(0))) - return *val1 < *val2; - - // Handle some variable cases (C > 0): - // v2 = v1 + C - // v2 = C + v1 - // v1 = v2 - C - if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2)) - if ((addi.getLhs().getDefiningOp() == op1 && - isPositiveConstant(addi.getRhs())) || - (addi.getRhs().getDefiningOp() == op1 && - isPositiveConstant(addi.getLhs()))) - return true; - if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1)) - if (subi.getLhs().getDefiningOp() == op2 && - isPositiveConstant(subi.getRhs())) - return true; - return false; -} - -llvm::SmallVector<mlir::Value> -ElementalAssignBufferization::getDesignatorIndices( - hlfir::DesignateOp designate) { - mlir::Value memref = designate.getMemref(); - - // If the object is a box, then the indices may be adjusted - // according to the box's lower bound(s). Scan through - // the computations to try to find the one-based indices. - if (mlir::isa<fir::BaseBoxType>(memref.getType())) { - // Look for the following pattern: - // %13 = fir.load %12 : !fir.ref<!fir.box<...> - // %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ... - // %17 = arith.subi %14#0, %c1 : index - // %18 = arith.addi %arg2, %17 : index - // %19 = hlfir.designate %13 (%18) : (!fir.box<...>, index) -> ... - // - // %arg2 is a one-based index. - - auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) { - // Return true, if v and dim are such that: - // %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ... - // %17 = arith.subi %14#0, %c1 : index - // %19 = hlfir.designate %13 (...) : (!fir.box<...>, index) -> ... 
- if (auto subOp = - mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) { - auto cst = fir::getIntIfConstant(subOp.getRhs()); - if (!cst || *cst != 1) - return false; - if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>( - subOp.getLhs().getDefiningOp())) { - if (memref != dimsOp.getVal() || - dimsOp.getResult(0) != subOp.getLhs()) - return false; - auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim()); - return dimsOpDim && dimsOpDim == dim; - } - } - return false; - }; - - llvm::SmallVector<mlir::Value> newIndices; - for (auto index : llvm::enumerate(designate.getIndices())) { - if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>( - index.value().getDefiningOp())) { - for (unsigned opNum = 0; opNum < 2; ++opNum) - if (isNormalizedLb(addOp->getOperand(opNum), index.index())) { - newIndices.push_back(addOp->getOperand((opNum + 1) % 2)); - break; - } - - // If new one-based index was not added, exit early. - if (newIndices.size() <= index.index()) - break; - } - } - - // If any of the indices is not adjusted to the array's lb, - // then return the original designator indices. 
- if (newIndices.size() != designate.getIndices().size()) - return designate.getIndices(); - - return newIndices; - } - - return designate.getIndices(); -} - std::optional<ElementalAssignBufferization::MatchInfo> ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) { mlir::Operation::user_range users = elemental->getUsers(); @@ -627,22 +283,20 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) { if (!res.isPartial()) { if (auto designate = effect.getValue().getDefiningOp<hlfir::DesignateOp>()) { - ArraySectionAnalyzer::SlicesOverlapKind overlap = - ArraySectionAnalyzer::analyze(match.array, designate.getMemref()); + fir::ArraySectionAnalyzer::SlicesOverlapKind overlap = + fir::ArraySectionAnalyzer::analyze(match.array, + designate.getMemref()); if (overlap == - ArraySectionAnalyzer::SlicesOverlapKind::DefinitelyDisjoint) + fir::ArraySectionAnalyzer::SlicesOverlapKind::DefinitelyDisjoint) continue; - if (overlap == ArraySectionAnalyzer::SlicesOverlapKind::Unknown) { + if (overlap == fir::ArraySectionAnalyzer::SlicesOverlapKind::Unknown) { LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate << " at " << elemental.getLoc() << "\n"); return std::nullopt; } - auto indices = getDesignatorIndices(designate); - auto elementalIndices = elemental.getIndices(); - if (indices.size() == elementalIndices.size() && - std::equal(indices.begin(), indices.end(), elementalIndices.begin(), - elementalIndices.end())) + if (fir::ArraySectionAnalyzer::isDesignatingArrayInOrder(designate, + elemental)) continue; LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate @@ -727,9 +381,13 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite( // Assign the element value to the array element for this iteration. 
auto arrayElement = hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); - hlfir::AssignOp::create( + auto newAssign = hlfir::AssignOp::create( builder, loc, elementValue, arrayElement, /*realloc=*/false, /*keep_lhs_length_if_realloc=*/false, match->assign.getTemporaryLhs()); + if (auto accessGroups = + match->assign.getOperation()->getAttrOfType<mlir::ArrayAttr>( + fir::getAccessGroupsAttrName())) + newAssign->setAttr(fir::getAccessGroupsAttrName(), accessGroups); rewriter.eraseOp(match->assign); rewriter.eraseOp(match->destroy); @@ -788,6 +446,11 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite( llvm::SmallVector<mlir::Value> extents = hlfir::getIndexExtents(loc, builder, shape); + mlir::ArrayAttr accessGroups; + if (auto attrs = assign.getOperation()->getAttrOfType<mlir::ArrayAttr>( + fir::getAccessGroupsAttrName())) + accessGroups = attrs; + if (lhs.isSimplyContiguous() && extents.size() > 1) { // Flatten the array to use a single assign loop, that can be better // optimized. 
@@ -824,7 +487,9 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite( mlir::Value arrayElement = hlfir::DesignateOp::create(builder, loc, fir::ReferenceType::get(eleTy), flatArray, loopNest.oneBasedIndices); - hlfir::AssignOp::create(builder, loc, rhs, arrayElement); + auto newAssign = hlfir::AssignOp::create(builder, loc, rhs, arrayElement); + if (accessGroups) + newAssign->setAttr(fir::getAccessGroupsAttrName(), accessGroups); } else { hlfir::LoopNest loopNest = hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true, @@ -832,7 +497,9 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite( builder.setInsertionPointToStart(loopNest.body); auto arrayElement = hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); - hlfir::AssignOp::create(builder, loc, rhs, arrayElement); + auto newAssign = hlfir::AssignOp::create(builder, loc, rhs, arrayElement); + if (accessGroups) + newAssign->setAttr(fir::getAccessGroupsAttrName(), accessGroups); } rewriter.eraseOp(assign); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp index 63a5803..6bc5317 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp @@ -8,6 +8,7 @@ #include "ScheduleOrderedAssignments.h" #include "flang/Optimizer/Analysis/AliasAnalysis.h" +#include "flang/Optimizer/Analysis/ArraySectionAnalyzer.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/Support/FIRContext.h" @@ -23,7 +24,13 @@ /// Log RAW or WAW conflict. [[maybe_unused]] static void logConflict(llvm::raw_ostream &os, mlir::Value writtenOrReadVarA, - mlir::Value writtenVarB); + mlir::Value writtenVarB, + bool isAligned = false); +/// Log when a region must be retroactively saved. 
+[[maybe_unused]] static void +logRetroactiveSave(llvm::raw_ostream &os, mlir::Region &yieldRegion, + hlfir::Run &modifyingRun, + hlfir::RegionAssignOp currentAssign); /// Log when an expression evaluation must be saved. [[maybe_unused]] static void logSaveEvaluation(llvm::raw_ostream &os, unsigned runid, @@ -39,15 +46,129 @@ logStartScheduling(llvm::raw_ostream &os, hlfir::OrderedAssignmentTreeOpInterface root); /// Log op if effect value is not known. [[maybe_unused]] static void -logIfUnkownEffectValue(llvm::raw_ostream &os, - mlir::MemoryEffects::EffectInstance effect, - mlir::Operation &op); +logIfUnknownEffectValue(llvm::raw_ostream &os, + mlir::MemoryEffects::EffectInstance effect, + mlir::Operation &op); //===----------------------------------------------------------------------===// // Scheduling Implementation //===----------------------------------------------------------------------===// +/// Is the apply using all the elemental indices in order? +static bool isInOrderApply(hlfir::ApplyOp apply, + hlfir::ElementalOpInterface elemental) { + mlir::Region::BlockArgListType elementalIndices = elemental.getIndices(); + if (elementalIndices.size() != apply.getIndices().size()) + return false; + for (auto [elementalIdx, applyIdx] : + llvm::zip(elementalIndices, apply.getIndices())) + if (elementalIdx != applyIdx) + return false; + return true; +} + +hlfir::ElementalTree +hlfir::ElementalTree::buildElementalTree(mlir::Operation ®ionTerminator) { + ElementalTree tree; + if (auto elementalAddr = + mlir::dyn_cast<hlfir::ElementalOpInterface>(regionTerminator)) { + // Vector subscripted designator (hlfir.elemental_addr terminator). + tree.gatherElementalTree(elementalAddr, /*isAppliedInOrder=*/true); + return tree; + } + // Try if elemental expression. 
+ if (auto yield = mlir::dyn_cast<hlfir::YieldOp>(regionTerminator)) { + mlir::Value entity = yield.getEntity(); + if (auto maybeElemental = + mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>( + entity.getDefiningOp())) + tree.gatherElementalTree(maybeElemental, /*isAppliedInOrder=*/true); + } + return tree; +} + +// Check if op is an ElementalOpInterface that is part of this elemental tree. +bool hlfir::ElementalTree::contains(mlir::Operation *op) const { + for (auto &p : tree) + if (p.first == op) + return true; + return false; +} + +std::optional<bool> hlfir::ElementalTree::isOrdered(mlir::Operation *op) const { + for (auto &p : tree) + if (p.first == op) + return p.second; + return std::nullopt; +} + +void hlfir::ElementalTree::gatherElementalTree( + hlfir::ElementalOpInterface elemental, bool isAppliedInOrder) { + if (!elemental) + return; + // Only inline an applied elemental that must be executed in order if the + // applying indices are in order. An hlfir::Elemental may have been created + // for a transformational like transpose, and Fortran 2018 standard + // section 10.2.3.2, point 10 imply that impure elemental sub-expression + // evaluations should not be masked if they are the arguments of + // transformational expressions. 
+ if (!isAppliedInOrder && elemental.isOrdered()) + return; + + insert(elemental, isAppliedInOrder); + for (mlir::Operation &op : elemental.getElementalRegion().getOps()) + if (auto apply = mlir::dyn_cast<hlfir::ApplyOp>(op)) { + bool isUnorderedApply = + !isAppliedInOrder || !isInOrderApply(apply, elemental); + auto maybeElemental = mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>( + apply.getExpr().getDefiningOp()); + gatherElementalTree(maybeElemental, !isUnorderedApply); + } +} + +void hlfir::ElementalTree::insert(hlfir::ElementalOpInterface elementalOp, + bool isAppliedInOrder) { + tree.push_back({elementalOp.getOperation(), isAppliedInOrder}); +} + +static bool isInOrderDesignate(hlfir::DesignateOp designate, + hlfir::ElementalTree *tree) { + if (!tree) + return false; + if (auto elemental = + designate->getParentOfType<hlfir::ElementalOpInterface>()) + if (tree->isOrdered(elemental.getOperation())) + return fir::ArraySectionAnalyzer::isDesignatingArrayInOrder(designate, + elemental); + return false; +} + +hlfir::DetailedEffectInstance::DetailedEffectInstance( + mlir::MemoryEffects::Effect *effect, mlir::OpOperand *value, + mlir::Value orderedElementalEffectOn) + : effectInstance(effect, value), + orderedElementalEffectOn(orderedElementalEffectOn) {} + +hlfir::DetailedEffectInstance::DetailedEffectInstance( + mlir::MemoryEffects::EffectInstance effectInst, + mlir::Value orderedElementalEffectOn) + : effectInstance(effectInst), + orderedElementalEffectOn(orderedElementalEffectOn) {} + +hlfir::DetailedEffectInstance +hlfir::DetailedEffectInstance::getArrayReadEffect(mlir::OpOperand *array) { + return DetailedEffectInstance(mlir::MemoryEffects::Read::get(), array, + array->get()); +} + +hlfir::DetailedEffectInstance +hlfir::DetailedEffectInstance::getArrayWriteEffect(mlir::OpOperand *array) { + return DetailedEffectInstance(mlir::MemoryEffects::Write::get(), array, + array->get()); +} + namespace { + /// Structure that is in charge of building the schedule. 
For each /// hlfir.region_assign inside an ordered assignment tree, it is walked through /// the parent operations and their "leaf" regions (that contain expression @@ -99,20 +220,25 @@ public: /// After all the dependent evaluation regions have been analyzed, create the /// action to evaluate the assignment that was being analyzed. - void finishSchedulingAssignment(hlfir::RegionAssignOp assign); + void finishSchedulingAssignment(hlfir::RegionAssignOp assign, + bool leafRegionsMayOnlyRead); /// Once all the assignments have been analyzed and scheduled, return the /// schedule. The scheduler object should not be used after this call. hlfir::Schedule moveSchedule() { return std::move(schedule); } private: + struct EvaluationState { + bool saved = false; + std::optional<hlfir::Schedule::iterator> modifiedInRun; + }; + /// Save a conflicting region that is evaluating an expression that is /// controlling or masking the current assignment, or is evaluating the /// RHS/LHS. - void - saveEvaluation(mlir::Region &yieldRegion, - llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effects, - bool anyWrite); + void saveEvaluation(mlir::Region &yieldRegion, + llvm::ArrayRef<hlfir::DetailedEffectInstance> effects, + bool anyWrite); /// Can the current assignment be schedule with the previous run. This is /// only possible if the assignment and all of its dependencies have no side @@ -120,19 +246,17 @@ private: bool canFuseAssignmentWithPreviousRun(); /// Memory effects of the assignments being lowered. - llvm::SmallVector<mlir::MemoryEffects::EffectInstance> assignEffects; + llvm::SmallVector<hlfir::DetailedEffectInstance> assignEffects; /// Memory effects of the evaluations implied by the assignments /// being lowered. They do not include the implicit writes /// to the LHS of the assignments. 
- llvm::SmallVector<mlir::MemoryEffects::EffectInstance> assignEvaluateEffects; + llvm::SmallVector<hlfir::DetailedEffectInstance> assignEvaluateEffects; /// Memory effects of the unsaved evaluation region that are controlling or /// masking the current assignments. - llvm::SmallVector<mlir::MemoryEffects::EffectInstance> - parentEvaluationEffects; + llvm::SmallVector<hlfir::DetailedEffectInstance> parentEvaluationEffects; /// Same as parentEvaluationEffects, but for the current "leaf group" being /// analyzed scheduled. - llvm::SmallVector<mlir::MemoryEffects::EffectInstance> - independentEvaluationEffects; + llvm::SmallVector<hlfir::DetailedEffectInstance> independentEvaluationEffects; /// Were any region saved for the current assignment? bool savedAnyRegionForCurrentAssignment = false; @@ -140,7 +264,10 @@ private: // Schedule being built. hlfir::Schedule schedule; /// Leaf regions that have been saved so far. - llvm::SmallPtrSet<mlir::Region *, 16> savedRegions; + llvm::DenseMap<mlir::Region *, EvaluationState> regionStates; + /// Regions that have an aligned conflict with the current assignment. + llvm::SmallVector<mlir::Region *> pendingAlignedRegions; + /// Is schedule.back() a schedule that is only saving region with read /// effects? bool currentRunIsReadOnly = false; @@ -171,9 +298,10 @@ static bool isForallIndex(mlir::Value var) { /// side effect interface, or that are writing temporary variables that may be /// hard to identify as such (one would have to prove the write is "local" to /// the region even when the alloca may be outside of the region). -static void gatherMemoryEffects( +static void gatherMemoryEffectsImpl( mlir::Region ®ion, bool mayOnlyRead, - llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects) { + llvm::SmallVectorImpl<hlfir::DetailedEffectInstance> &effects, + hlfir::ElementalTree *tree = nullptr) { /// This analysis is a simple walk of all the operations of the region that is /// evaluating and yielding a value. 
This is a lot simpler and safer than /// trying to walk back the SSA DAG from the yielded value. But if desired, @@ -181,7 +309,7 @@ static void gatherMemoryEffects( for (mlir::Operation &op : region.getOps()) { if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) { for (mlir::Region &subRegion : op.getRegions()) - gatherMemoryEffects(subRegion, mayOnlyRead, effects); + gatherMemoryEffectsImpl(subRegion, mayOnlyRead, effects, tree); // In MLIR, RecursiveMemoryEffects can be combined with // MemoryEffectOpInterface to describe extra effects on top of the // effects of the nested operations. However, the presence of @@ -214,17 +342,45 @@ static void gatherMemoryEffects( interface.getEffects(opEffects); for (auto &effect : opEffects) if (!isForallIndex(effect.getValue())) { + mlir::Value array; + if (effect.getValue()) + if (auto designate = + effect.getValue().getDefiningOp<hlfir::DesignateOp>()) + if (isInOrderDesignate(designate, tree)) + array = designate.getMemref(); + if (mlir::isa<mlir::MemoryEffects::Read>(effect.getEffect())) { - LLVM_DEBUG(logIfUnkownEffectValue(llvm::dbgs(), effect, op);); - effects.push_back(effect); + LLVM_DEBUG(logIfUnknownEffectValue(llvm::dbgs(), effect, op);); + effects.emplace_back(effect, array); } else if (!mayOnlyRead && mlir::isa<mlir::MemoryEffects::Write>(effect.getEffect())) { - LLVM_DEBUG(logIfUnkownEffectValue(llvm::dbgs(), effect, op);); - effects.push_back(effect); + LLVM_DEBUG(logIfUnknownEffectValue(llvm::dbgs(), effect, op);); + effects.emplace_back(effect, array); } } } } +static void gatherMemoryEffects( + mlir::Region ®ion, bool mayOnlyRead, + llvm::SmallVectorImpl<hlfir::DetailedEffectInstance> &effects) { + if (!region.getParentOfType<hlfir::ForallOp>()) { + // TODO: leverage array access analysis for FORALL. + // While FORALL assignments can be array assignments, the iteration space + // is also driven by the FORALL indices, so the way ArraySectionAnalyzer + // results are used is not adequate for it. 
+ // For instance "disjoint" array access cannot be ignored in: + // "forall (i=1:10) x(i+1,:) = x(i,:)". + // While identical access can probably also be accepted, this would deserve + // more thinking, it would probably make sense to also deal with "aligned + // scalar" access for them like in "forall (i=1:10) x(i) = x(i) + 1". For + // now this feature is disabled for inside FORALL. + hlfir::ElementalTree tree = + hlfir::ElementalTree::buildElementalTree(region.back().back()); + gatherMemoryEffectsImpl(region, mayOnlyRead, effects, &tree); + return; + } + gatherMemoryEffectsImpl(region, mayOnlyRead, effects, /*tree=*/nullptr); +} /// Return the entity yielded by a region, or a null value if the region /// is not terminated by a yield. @@ -246,10 +402,14 @@ static mlir::OpOperand *getYieldedEntity(mlir::Region ®ion) { static void gatherAssignEffects( hlfir::RegionAssignOp regionAssign, bool userDefAssignmentMayOnlyWriteToAssignedVariable, - llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &assignEffects) { + llvm::SmallVectorImpl<hlfir::DetailedEffectInstance> &assignEffects) { mlir::OpOperand *assignedVar = getYieldedEntity(regionAssign.getLhsRegion()); assert(assignedVar && "lhs cannot be an empty region"); - assignEffects.emplace_back(mlir::MemoryEffects::Write::get(), assignedVar); + if (regionAssign->getParentOfType<hlfir::ForallOp>()) + assignEffects.emplace_back(mlir::MemoryEffects::Write::get(), assignedVar); + else + assignEffects.emplace_back( + hlfir::DetailedEffectInstance::getArrayWriteEffect(assignedVar)); if (!regionAssign.getUserDefinedAssignment().empty()) { // The write effect on the INTENT(OUT) LHS argument is already taken @@ -273,7 +433,7 @@ static void gatherAssignEffects( static void gatherAssignEvaluationEffects( hlfir::RegionAssignOp regionAssign, bool userDefAssignmentMayOnlyWriteToAssignedVariable, - llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &assignEffects) { + 
llvm::SmallVectorImpl<hlfir::DetailedEffectInstance> &assignEffects) { gatherMemoryEffects(regionAssign.getLhsRegion(), userDefAssignmentMayOnlyWriteToAssignedVariable, assignEffects); @@ -308,12 +468,57 @@ static mlir::Value getStorageSource(mlir::Value var) { return source; } +namespace { + +/// Class to represent conflicts between several accesses (effects) to a memory +/// location (read after write, write after write). +struct ConflictKind { + enum Kind { + // None: The effects are not affecting the same memory location, or they are + // all reads. + None, + // Aligned: There are both read and write effects affecting the same memory + // location, but it is known that these effects are all accessing the memory + // location element by element in array order. This means the conflict does + // not introduce loop-carried dependencies. + Aligned, + // Any: There may be both read and write effects affecting the same memory + // in any way. + Any + }; + Kind kind; + + ConflictKind(Kind k) : kind(k) {} + + static ConflictKind none() { return ConflictKind(None); } + static ConflictKind aligned() { return ConflictKind(Aligned); } + static ConflictKind any() { return ConflictKind(Any); } + + bool isNone() const { return kind == None; } + bool isAligned() const { return kind == Aligned; } + bool isAny() const { return kind == Any; } + + // Merge conflicts: + // none || none -> none + // aligned || <not any> -> aligned + // any || _ -> any + ConflictKind operator||(const ConflictKind &other) const { + if (kind == Any || other.kind == Any) + return any(); + if (kind == Aligned || other.kind == Aligned) + return aligned(); + return none(); + } +}; +} // namespace + /// Could there be any read or write in effectsA on a variable written to in /// effectsB? 
-static bool -anyRAWorWAW(llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effectsA, - llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effectsB, +static ConflictKind +anyRAWorWAW(llvm::ArrayRef<hlfir::DetailedEffectInstance> effectsA, + llvm::ArrayRef<hlfir::DetailedEffectInstance> effectsB, fir::AliasAnalysis &aliasAnalysis) { + ConflictKind result = ConflictKind::none(); for (const auto &effectB : effectsB) if (mlir::isa<mlir::MemoryEffects::Write>(effectB.getEffect())) { mlir::Value writtenVarB = effectB.getValue(); @@ -325,38 +530,66 @@ anyRAWorWAW(llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effectsA, mlir::Value writtenOrReadVarA = effectA.getValue(); if (!writtenVarB || !writtenOrReadVarA) { LLVM_DEBUG( - logConflict(llvm::dbgs(), writtenOrReadVarA, writtenVarB);); - return true; // unknown conflict. + logConflict(llvm::dbgs(), writtenOrReadVarA, writtenVarB)); + return ConflictKind::any(); // unknown conflict. } writtenOrReadVarA = getStorageSource(writtenOrReadVarA); if (!aliasAnalysis.alias(writtenOrReadVarA, writtenVarB).isNo()) { + mlir::Value arrayA = effectA.getOrderedElementalEffectOn(); + mlir::Value arrayB = effectB.getOrderedElementalEffectOn(); + if (arrayA && arrayB) { + if (arrayA == arrayB) { + result = result || ConflictKind::aligned(); + LLVM_DEBUG(logConflict(llvm::dbgs(), writtenOrReadVarA, + writtenVarB, /*isAligned=*/true)); + continue; + } + auto overlap = fir::ArraySectionAnalyzer::analyze(arrayA, arrayB); + if (overlap == fir::ArraySectionAnalyzer::SlicesOverlapKind:: + DefinitelyDisjoint) + continue; + if (overlap == fir::ArraySectionAnalyzer::SlicesOverlapKind:: + DefinitelyIdentical || + overlap == fir::ArraySectionAnalyzer::SlicesOverlapKind:: + EitherIdenticalOrDisjoint) { + result = result || ConflictKind::aligned(); + LLVM_DEBUG(logConflict(llvm::dbgs(), writtenOrReadVarA, + writtenVarB, /*isAligned=*/true)); + continue; + } + LLVM_DEBUG(llvm::dbgs() << "conflicting arrays:" << arrayA + << " and " << arrayB << 
"\n"); + return ConflictKind::any(); + } LLVM_DEBUG( - logConflict(llvm::dbgs(), writtenOrReadVarA, writtenVarB);); - return true; + logConflict(llvm::dbgs(), writtenOrReadVarA, writtenVarB)); + return ConflictKind::any(); } } } - return false; + return result; } /// Could there be any read or write in effectsA on a variable written to in /// effectsB, or any read in effectsB on a variable written to in effectsA? -static bool -conflict(llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effectsA, - llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effectsB) { +static ConflictKind +conflict(llvm::ArrayRef<hlfir::DetailedEffectInstance> effectsA, + llvm::ArrayRef<hlfir::DetailedEffectInstance> effectsB) { fir::AliasAnalysis aliasAnalysis; // (RAW || WAW) || (WAR || WAW). - return anyRAWorWAW(effectsA, effectsB, aliasAnalysis) || - anyRAWorWAW(effectsB, effectsA, aliasAnalysis); + ConflictKind result = anyRAWorWAW(effectsA, effectsB, aliasAnalysis); + if (result.isAny()) + return result; + return result || anyRAWorWAW(effectsB, effectsA, aliasAnalysis); } /// Could there be any write effects in "effects" affecting memory storages /// that are not local to the current region. static bool -anyNonLocalWrite(llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effects, +anyNonLocalWrite(llvm::ArrayRef<hlfir::DetailedEffectInstance> effects, mlir::Region ®ion) { return llvm::any_of( - effects, [®ion](const mlir::MemoryEffects::EffectInstance &effect) { + effects, [®ion](const hlfir::DetailedEffectInstance &effect) { if (mlir::isa<mlir::MemoryEffects::Write>(effect.getEffect())) { if (mlir::Value v = effect.getValue()) { v = getStorageSource(v); @@ -393,9 +626,9 @@ void Scheduler::saveEvaluationIfConflict(mlir::Region &yieldRegion, // If the region evaluation was previously executed and saved, the saved // value will be used when evaluating the current assignment and this has // no effects in the current assignment evaluation. 
- if (savedRegions.contains(&yieldRegion)) + if (regionStates[&yieldRegion].saved) return; - llvm::SmallVector<mlir::MemoryEffects::EffectInstance> effects; + llvm::SmallVector<hlfir::DetailedEffectInstance> effects; gatherMemoryEffects(yieldRegion, leafRegionsMayOnlyRead, effects); // Yield has no effect as such, but in the context of order assignments. // The order assignments will usually read the yielded entity (except for @@ -404,8 +637,13 @@ void Scheduler::saveEvaluationIfConflict(mlir::Region &yieldRegion, // intent(inout)). if (yieldIsImplicitRead) { mlir::OpOperand *entity = getYieldedEntity(yieldRegion); - if (entity && hlfir::isFortranVariableType(entity->get().getType())) - effects.emplace_back(mlir::MemoryEffects::Read::get(), entity); + if (entity && hlfir::isFortranVariableType(entity->get().getType())) { + if (yieldRegion.getParentOfType<hlfir::ForallOp>()) + effects.emplace_back(mlir::MemoryEffects::Read::get(), entity); + else + effects.emplace_back( + hlfir::DetailedEffectInstance::getArrayReadEffect(entity)); + } } if (!leafRegionsMayOnlyRead && anyNonLocalWrite(effects, yieldRegion)) { // Region with write effect must be executed only once (unless all writes @@ -415,33 +653,58 @@ void Scheduler::saveEvaluationIfConflict(mlir::Region &yieldRegion, << "saving eval because write effect prevents re-evaluation" << "\n";); saveEvaluation(yieldRegion, effects, /*anyWrite=*/true); - } else if (conflict(effects, assignEffects)) { - // Region that conflicts with the current assignments must be fully - // evaluated and saved before doing the assignment (Note that it may - // have already have been evaluated without saving it before, but this - // implies that it never conflicted with a prior assignment, so its value - // should be the same.) 
- saveEvaluation(yieldRegion, effects, /*anyWrite=*/false); - } else if (evaluationsMayConflict && - conflict(effects, assignEvaluateEffects)) { - // If evaluations of the assignment may conflict with the yield - // evaluations, we have to save yield evaluation. - // For example, a WHERE mask might be written by the masked assignment - // evaluations, and it has to be saved in this case: - // where (mask) r = f() ! function f modifies mask - saveEvaluation(yieldRegion, effects, - anyNonLocalWrite(effects, yieldRegion)); } else { - // Can be executed while doing the assignment. - independentEvaluationEffects.append(effects.begin(), effects.end()); + ConflictKind conflictKind = conflict(effects, assignEffects); + if (conflictKind.isAny()) { + // Region that conflicts with the current assignments must be fully + // evaluated and saved before doing the assignment (Note that it may + // have already been evaluated without saving it before, but this + // implies that it never conflicted with a prior assignment, so its value + // should be the same.) + saveEvaluation(yieldRegion, effects, /*anyWrite=*/false); + } else { + if (conflictKind.isAligned()) + pendingAlignedRegions.push_back(&yieldRegion); + + if (evaluationsMayConflict && + !conflict(effects, assignEvaluateEffects).isNone()) { + // If evaluations of the assignment may conflict with the yield + // evaluations, we have to save yield evaluation. + // For example, a WHERE mask might be written by the masked assignment + // evaluations, and it has to be saved in this case: + // where (mask) r = f() ! function f modifies mask + saveEvaluation(yieldRegion, effects, + anyNonLocalWrite(effects, yieldRegion)); + } else { + // Can be executed while doing the assignment. 
+ independentEvaluationEffects.append(effects.begin(), effects.end()); + } + } } } void Scheduler::saveEvaluation( mlir::Region &yieldRegion, - llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effects, - bool anyWrite) { + llvm::ArrayRef<hlfir::DetailedEffectInstance> effects, bool anyWrite) { savedAnyRegionForCurrentAssignment = true; + auto &state = regionStates[&yieldRegion]; + if (state.modifiedInRun) { + // The region was modified in a previous run, but we now realize we need its + // value. We must save it before that modification run. + auto &newRun = *schedule.emplace(*state.modifiedInRun, hlfir::Run{}); + newRun.actions.emplace_back(hlfir::SaveEntity{&yieldRegion}); + // We do not have the parent effects from that time easily available here. + // However, since we are saving a parent of the current assignment, its + // parents are also parents of the current assignment. + newRun.memoryEffects.append(parentEvaluationEffects.begin(), + parentEvaluationEffects.end()); + newRun.memoryEffects.append(effects.begin(), effects.end()); + state.saved = true; + LLVM_DEBUG( + logSaveEvaluation(llvm::dbgs(), /*runid=*/0, yieldRegion, anyWrite);); + return; + } + if (anyWrite) { // Create a new run just for regions with side effect. 
Further analysis // could try to prove the effects do not conflict with the previous @@ -465,7 +728,7 @@ void Scheduler::saveEvaluation( schedule.back().memoryEffects.append(parentEvaluationEffects.begin(), parentEvaluationEffects.end()); schedule.back().memoryEffects.append(effects.begin(), effects.end()); - savedRegions.insert(&yieldRegion); + state.saved = true; LLVM_DEBUG( logSaveEvaluation(llvm::dbgs(), schedule.size(), yieldRegion, anyWrite);); } @@ -476,18 +739,78 @@ bool Scheduler::canFuseAssignmentWithPreviousRun() { if (savedAnyRegionForCurrentAssignment || schedule.empty()) return false; auto &previousRunEffects = schedule.back().memoryEffects; - return !conflict(previousRunEffects, assignEffects) && - !conflict(previousRunEffects, parentEvaluationEffects) && - !conflict(previousRunEffects, independentEvaluationEffects); + return !conflict(previousRunEffects, assignEffects).isAny() && + !conflict(previousRunEffects, parentEvaluationEffects).isAny() && + !conflict(previousRunEffects, independentEvaluationEffects).isAny(); +} + +/// Gather the parents of (not included) \p node in reverse execution order. +static void gatherParents( + hlfir::OrderedAssignmentTreeOpInterface node, + llvm::SmallVectorImpl<hlfir::OrderedAssignmentTreeOpInterface> &parents) { + while (node) { + auto parent = + mlir::dyn_cast_or_null<hlfir::OrderedAssignmentTreeOpInterface>( + node->getParentOp()); + if (parent && parent.getSubTreeRegion() == node->getParentRegion()) { + parents.push_back(parent); + node = parent; + } else { + break; + } + } +} + +// Build the list of the parent nodes for this assignment. The list is built +// from the closest parent until the ordered assignment tree root (this is the +// reverse of their execution order). 
+static void gatherAssignmentParents( + hlfir::RegionAssignOp assign, + llvm::SmallVectorImpl<hlfir::OrderedAssignmentTreeOpInterface> &parents) { + gatherParents(mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>( + assign.getOperation()), + parents); } -void Scheduler::finishSchedulingAssignment(hlfir::RegionAssignOp assign) { - // For now, always schedule each assignment in its own run. They could - // be done as part of previous assignment runs if it is proven they have - // no conflicting effects. +void Scheduler::finishSchedulingAssignment(hlfir::RegionAssignOp assign, + bool leafRegionsMayOnlyRead) { + // Schedule the assignment in a new run, unless it can be fused with the + // previous run (if enabled and proven safe). currentRunIsReadOnly = false; - if (!tryFusingAssignments || !canFuseAssignmentWithPreviousRun()) + bool fuse = tryFusingAssignments && canFuseAssignmentWithPreviousRun(); + if (!fuse) { + // If we cannot fuse, we are about to start a new run. + // Check if any parent region was modified in a previous run and needs to be + // saved. + llvm::SmallVector<hlfir::OrderedAssignmentTreeOpInterface> parents; + gatherAssignmentParents(assign, parents); + for (auto parent : parents) { + llvm::SmallVector<mlir::Region *, 4> yieldRegions; + parent.getLeafRegions(yieldRegions); + for (mlir::Region *yieldRegion : yieldRegions) { + if (regionStates[yieldRegion].modifiedInRun && + !regionStates[yieldRegion].saved) { + LLVM_DEBUG(logRetroactiveSave( + llvm::dbgs(), *yieldRegion, + **regionStates[yieldRegion].modifiedInRun, assign)); + llvm::SmallVector<hlfir::DetailedEffectInstance> effects; + gatherMemoryEffects(*yieldRegion, leafRegionsMayOnlyRead, effects); + saveEvaluation(*yieldRegion, effects, + anyNonLocalWrite(effects, *yieldRegion)); + } + } + } schedule.emplace_back(hlfir::Run{}); + } + + // Mark pending aligned regions as modified in the current run (which is the + // last one). 
+ auto runIt = std::prev(schedule.end()); + for (mlir::Region *region : pendingAlignedRegions) + if (!regionStates[region].saved) + regionStates[region].modifiedInRun = runIt; + pendingAlignedRegions.clear(); + schedule.back().actions.emplace_back(assign); // TODO: when fusing, it would probably be best to filter the // parentEvaluationEffects that already in the previous run effects (since @@ -530,34 +853,6 @@ gatherAssignments(hlfir::OrderedAssignmentTreeOpInterface root, } } -/// Gather the parents of (not included) \p node in reverse execution order. -static void gatherParents( - hlfir::OrderedAssignmentTreeOpInterface node, - llvm::SmallVectorImpl<hlfir::OrderedAssignmentTreeOpInterface> &parents) { - while (node) { - auto parent = - mlir::dyn_cast_or_null<hlfir::OrderedAssignmentTreeOpInterface>( - node->getParentOp()); - if (parent && parent.getSubTreeRegion() == node->getParentRegion()) { - parents.push_back(parent); - node = parent; - } else { - break; - } - } -} - -// Build the list of the parent nodes for this assignment. The list is built -// from the closest parent until the ordered assignment tree root (this is the -// revere of their execution order). 
-static void gatherAssignmentParents( - hlfir::RegionAssignOp assign, - llvm::SmallVectorImpl<hlfir::OrderedAssignmentTreeOpInterface> &parents) { - gatherParents(mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>( - assign.getOperation()), - parents); -} - hlfir::Schedule hlfir::buildEvaluationSchedule(hlfir::OrderedAssignmentTreeOpInterface root, bool tryFusingAssignments) { @@ -616,7 +911,7 @@ hlfir::buildEvaluationSchedule(hlfir::OrderedAssignmentTreeOpInterface root, leafRegionsMayOnlyRead, /*yieldIsImplicitRead=*/false); scheduler.finishIndependentEvaluationGroup(); - scheduler.finishSchedulingAssignment(assign); + scheduler.finishSchedulingAssignment(assign, leafRegionsMayOnlyRead); } return scheduler.moveSchedule(); } @@ -704,6 +999,25 @@ static llvm::raw_ostream &printRegionPath(llvm::raw_ostream &os, return printRegionId(os, yieldRegion); } +[[maybe_unused]] static void +logRetroactiveSave(llvm::raw_ostream &os, mlir::Region &yieldRegion, + hlfir::Run &modifyingRun, + hlfir::RegionAssignOp currentAssign) { + printRegionPath(os, yieldRegion) << " is modified in order by "; + bool first = true; + for (auto &action : modifyingRun.actions) { + if (auto *assign = std::get_if<hlfir::RegionAssignOp>(&action)) { + if (!first) + os << ", "; + printNodePath(os, assign->getOperation()); + first = false; + } + } + os << " and is needed by "; + printNodePath(os, currentAssign.getOperation()); + os << " that is scheduled in a later run\n"; +} + [[maybe_unused]] static void logSaveEvaluation(llvm::raw_ostream &os, unsigned runid, mlir::Region &yieldRegion, @@ -721,13 +1035,14 @@ logAssignmentEvaluation(llvm::raw_ostream &os, unsigned runid, [[maybe_unused]] static void logConflict(llvm::raw_ostream &os, mlir::Value writtenOrReadVarA, - mlir::Value writtenVarB) { + mlir::Value writtenVarB, + bool isAligned) { auto printIfValue = [&](mlir::Value var) -> llvm::raw_ostream & { if (!var) return os << "<unknown>"; return os << var; }; - os << "conflict: R/W: "; + os << 
"conflict" << (isAligned ? " (aligned)" : "") << ": R/W: "; printIfValue(writtenOrReadVarA) << " W:"; printIfValue(writtenVarB) << "\n"; } @@ -743,9 +1058,9 @@ logStartScheduling(llvm::raw_ostream &os, } [[maybe_unused]] static void -logIfUnkownEffectValue(llvm::raw_ostream &os, - mlir::MemoryEffects::EffectInstance effect, - mlir::Operation &op) { +logIfUnknownEffectValue(llvm::raw_ostream &os, + mlir::MemoryEffects::EffectInstance effect, + mlir::Operation &op) { if (effect.getValue() != nullptr) return; os << "unknown effected value ("; diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.h b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.h index 2ed242e..7f479ab 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.h +++ b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.h @@ -15,9 +15,30 @@ #define OPTIMIZER_HLFIR_TRANSFORM_SCHEDULEORDEREDASSIGNMENTS_H #include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include <list> namespace hlfir { +struct ElementalTree { + // build an elemental tree given a masked region terminator. + static ElementalTree buildElementalTree(mlir::Operation ®ionTerminator); + // Check if op is an ElementalOpInterface that is part of this elemental tree. + bool contains(mlir::Operation *op) const; + + std::optional<bool> isOrdered(mlir::Operation *op) const; + +private: + void gatherElementalTree(hlfir::ElementalOpInterface elemental, + bool isAppliedInOrder); + void insert(hlfir::ElementalOpInterface elementalOp, bool isAppliedInOrder); + // List of ElementalOpInterface operation forming this tree, as well as a + // Boolean to indicate if they are applied in order (that is, if their + // indexing space is the same as the one for the array yielded by the mask + // region that owns this tree). 
+ llvm::SmallVector<std::pair<mlir::Operation *, bool>> tree; +}; + /// Structure to represent that the value yielded by some region /// must be fully evaluated and saved for all index values at /// a given point of the ordered assignment tree evaluation. @@ -29,6 +50,37 @@ struct SaveEntity { mlir::Value getSavedValue(); }; +/// Wrapper class around mlir::MemoryEffects::EffectInstance that +/// allows providing an extra array value that indicates that the +/// effect is done element by element in array order (one element +/// accessed at each iteration of the ordered assignment iteration +/// space). +class DetailedEffectInstance { +public: + DetailedEffectInstance(mlir::MemoryEffects::Effect *effect, + mlir::OpOperand *value = nullptr, + mlir::Value orderedElementalEffectOn = nullptr); + DetailedEffectInstance(mlir::MemoryEffects::EffectInstance effectInstance, + mlir::Value orderedElementalEffectOn = nullptr); + + static DetailedEffectInstance getArrayReadEffect(mlir::OpOperand *array); + static DetailedEffectInstance getArrayWriteEffect(mlir::OpOperand *array); + + mlir::Value getValue() const { return effectInstance.getValue(); } + mlir::MemoryEffects::Effect *getEffect() const { + return effectInstance.getEffect(); + } + mlir::Value getOrderedElementalEffectOn() const { + return orderedElementalEffectOn; + } + +private: + mlir::MemoryEffects::EffectInstance effectInstance; + // Array whose elements are affected in array order by the + // ordered assignment iterations. Null value otherwise. + mlir::Value orderedElementalEffectOn; +}; + /// A run is a list of actions required to evaluate an ordered assignment tree /// that can be done in the same loop nest. /// The actions can evaluate and saves element values into temporary or evaluate @@ -42,11 +94,11 @@ struct Run { /// the assignment part of an hlfir::RegionAssignOp. 
using Action = std::variant<hlfir::RegionAssignOp, SaveEntity>; llvm::SmallVector<Action> actions; - llvm::SmallVector<mlir::MemoryEffects::EffectInstance> memoryEffects; + llvm::SmallVector<DetailedEffectInstance> memoryEffects; }; /// List of runs to be executed in order to evaluate an order assignment tree. -using Schedule = llvm::SmallVector<Run>; +using Schedule = std::list<Run>; /// Example of schedules and run, and what they mean: /// Fortran: forall (i=i:10) x(i) = y(i) diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index ce8ebaa..cc39652 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -931,6 +931,37 @@ private: mlir::Value genScalarAdd(mlir::Value value1, mlir::Value value2); }; +/// Reduction converter for Product. +class ProductAsElementalConverter + : public NumericReductionAsElementalConverterBase<hlfir::ProductOp> { + using Base = NumericReductionAsElementalConverterBase; + +public: + ProductAsElementalConverter(hlfir::ProductOp op, + mlir::PatternRewriter &rewriter) + : Base{op, rewriter} {} + +private: + virtual llvm::SmallVector<mlir::Value> genReductionInitValues( + [[maybe_unused]] mlir::ValueRange oneBasedIndices, + [[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents) + final { + return {fir::factory::createOneValue(builder, loc, getResultElementType())}; + } + virtual llvm::SmallVector<mlir::Value> + reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> ¤tValue, + hlfir::Entity array, + mlir::ValueRange oneBasedIndices) final { + checkReductions(currentValue); + hlfir::Entity elementValue = + hlfir::loadElementAt(loc, builder, array, oneBasedIndices); + return {genScalarMult(currentValue[0], elementValue)}; + } + + // Generate scalar multiplication of the two values (of the same data type). 
+ mlir::Value genScalarMult(mlir::Value value1, mlir::Value value2); +}; + /// Base class for logical reductions like ALL, ANY, COUNT. /// They do not have MASK and FastMathFlags. template <typename OpT> @@ -1194,6 +1225,20 @@ mlir::Value SumAsElementalConverter::genScalarAdd(mlir::Value value1, llvm_unreachable("unsupported SUM reduction type"); } +mlir::Value ProductAsElementalConverter::genScalarMult(mlir::Value value1, + mlir::Value value2) { + mlir::Type ty = value1.getType(); + assert(ty == value2.getType() && "reduction values' types do not match"); + if (mlir::isa<mlir::FloatType>(ty)) + return mlir::arith::MulFOp::create(builder, loc, value1, value2); + else if (mlir::isa<mlir::ComplexType>(ty)) + return fir::MulcOp::create(builder, loc, value1, value2); + else if (mlir::isa<mlir::IntegerType>(ty)) + return mlir::arith::MulIOp::create(builder, loc, value1, value2); + + llvm_unreachable("unsupported MUL reduction type"); +} + mlir::Value ReductionAsElementalConverter::genMaskValue( mlir::Value mask, mlir::Value isPresentPred, mlir::ValueRange indices) { mlir::OpBuilder::InsertionGuard guard(builder); @@ -1265,6 +1310,9 @@ public: } else if constexpr (std::is_same_v<Op, hlfir::SumOp>) { SumAsElementalConverter converter{op, rewriter}; return converter.convert(); + } else if constexpr (std::is_same_v<Op, hlfir::ProductOp>) { + ProductAsElementalConverter converter{op, rewriter}; + return converter.convert(); } return rewriter.notifyMatchFailure(op, "unexpected reduction operation"); } @@ -1371,15 +1419,12 @@ private: } /// The indices computations for the array shifts are done using I64 type. - /// For CSHIFT, all computations do not overflow signed and unsigned I64. - /// For EOSHIFT, some computations may involve negative shift values, - /// so using no-unsigned wrap flag would be incorrect. + /// For CSHIFT, and EOSHIFT all computations do not overflow signed I64. 
+ /// While no-unsigned wrap could be set on some operation generated for + /// CSHIFT, it is in general unsafe to mix with computations involving + /// user defined bounds that may be negative. static void setArithOverflowFlags(Op op, fir::FirOpBuilder &builder) { - if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>) - builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw); - else - builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw | - mlir::arith::IntegerOverflowFlags::nuw); + builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw); } /// Return the element type of the EOSHIFT boundary that may be omitted @@ -1841,11 +1886,9 @@ private: hlfir::Entity srcArray = array; if (exposeContiguity && mlir::isa<fir::BaseBoxType>(srcArray.getType())) { assert(dimVal == 1 && "can expose contiguity only for dim 1"); - llvm::SmallVector<mlir::Value, maxRank> arrayLbounds = - hlfir::genLowerbounds(loc, builder, arrayShape, array.getRank()); hlfir::Entity section = - hlfir::gen1DSection(loc, builder, srcArray, dimVal, arrayLbounds, - arrayExtents, oneBasedIndices, typeParams); + hlfir::gen1DSection(loc, builder, srcArray, dimVal, arrayExtents, + oneBasedIndices, typeParams); mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, section); mlir::Value shape = hlfir::genShape(loc, builder, section); mlir::Type boxType = fir::wrapInClassOrBoxType( @@ -3158,6 +3201,7 @@ public: mlir::RewritePatternSet patterns(context); patterns.insert<TransposeAsElementalConversion>(context); patterns.insert<ReductionConversion<hlfir::SumOp>>(context); + patterns.insert<ReductionConversion<hlfir::ProductOp>>(context); patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context); patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context); patterns.insert<CmpCharOpConversion>(context); diff --git a/flang/lib/Optimizer/OpenACC/Analysis/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Analysis/CMakeLists.txt new file mode 100644 index 
0000000..d9dda9d --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Analysis/CMakeLists.txt @@ -0,0 +1,24 @@ +add_flang_library(FIROpenACCAnalysis + FIROpenACCSupportAnalysis.cpp + + DEPENDS + FIRAnalysis + FIRDialect + FIROpenACCSupport + HLFIRDialect + + LINK_LIBS + FIRAnalysis + FIRDialect + FIROpenACCSupport + HLFIRDialect + + MLIR_DEPS + MLIROpenACCDialect + MLIROpenACCUtils + + MLIR_LIBS + MLIROpenACCDialect + MLIROpenACCUtils +) + diff --git a/flang/lib/Optimizer/OpenACC/Analysis/FIROpenACCSupportAnalysis.cpp b/flang/lib/Optimizer/OpenACC/Analysis/FIROpenACCSupportAnalysis.cpp new file mode 100644 index 0000000..3ad3188 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Analysis/FIROpenACCSupportAnalysis.cpp @@ -0,0 +1,56 @@ +//===- FIROpenACCSupportAnalysis.cpp - FIR OpenACCSupport Analysis -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the FIR-specific OpenACCSupport analysis. 
+// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/OpenACC/Analysis/FIROpenACCSupportAnalysis.h" + +#include "flang/Optimizer/Builder/Todo.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/OpenACC/Support/FIROpenACCUtils.h" +#include "mlir/Dialect/OpenACC/OpenACCUtils.h" + +using namespace mlir; + +namespace fir { +namespace acc { + +std::string FIROpenACCSupportAnalysis::getVariableName(Value v) { + return fir::acc::getVariableName(v, /*preferDemangledName=*/true); +} + +std::string FIROpenACCSupportAnalysis::getRecipeName(mlir::acc::RecipeKind kind, + Type type, Value var) { + return fir::acc::getRecipeName(kind, type, var); +} + +mlir::InFlightDiagnostic +FIROpenACCSupportAnalysis::emitNYI(Location loc, const Twine &message) { + TODO(loc, message); + // Should be unreachable, but we return an actual diagnostic + // to satisfy the interface. + return mlir::emitError(loc, "not yet implemented: " + message.str()); +} + +bool FIROpenACCSupportAnalysis::isValidValueUse(Value v, Region ®ion) { + // First check using the base utility. + if (mlir::acc::isValidValueUse(v, region)) + return true; + + // FIR-specific: fir.logical is a trivial scalar type that can be + // passed by value. 
+ if (mlir::isa<fir::LogicalType>(v.getType())) + return true; + + return false; +} + +} // namespace acc +} // namespace fir diff --git a/flang/lib/Optimizer/OpenACC/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/CMakeLists.txt index 790b9fd..16a4025 100644 --- a/flang/lib/Optimizer/OpenACC/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenACC/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(Analysis) add_subdirectory(Support) add_subdirectory(Transforms) diff --git a/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt index ef67ab1..9ff46c7 100644 --- a/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt @@ -2,10 +2,14 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) add_flang_library(FIROpenACCSupport FIROpenACCAttributes.cpp + FIROpenACCOpsInterfaces.cpp FIROpenACCTypeInterfaces.cpp + FIROpenACCUtils.cpp RegisterOpenACCExtensions.cpp DEPENDS + CUFAttrs + CUFDialect FIRBuilder FIRDialect FIRDialectSupport @@ -13,6 +17,8 @@ add_flang_library(FIROpenACCSupport HLFIRDialect LINK_LIBS + CUFAttrs + CUFDialect FIRBuilder FIRCodeGenDialect FIRDialect @@ -22,7 +28,9 @@ add_flang_library(FIROpenACCSupport MLIR_DEPS MLIROpenACCDialect + MLIROpenACCUtils MLIR_LIBS MLIROpenACCDialect + MLIROpenACCUtils ) diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp new file mode 100644 index 0000000..fc654e4 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp @@ -0,0 +1,227 @@ +//===-- FIROpenACCOpsInterfaces.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation of external operation interfaces for FIR. +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h" + +#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/Support/InternalNames.h" +#include "mlir/IR/SymbolTable.h" +#include "llvm/ADT/SmallSet.h" + +namespace fir::acc { + +template <> +mlir::Value PartialEntityAccessModel<fir::ArrayCoorOp>::getBaseEntity( + mlir::Operation *op) const { + return mlir::cast<fir::ArrayCoorOp>(op).getMemref(); +} + +template <> +mlir::Value PartialEntityAccessModel<fir::CoordinateOp>::getBaseEntity( + mlir::Operation *op) const { + return mlir::cast<fir::CoordinateOp>(op).getRef(); +} + +template <> +mlir::Value PartialEntityAccessModel<hlfir::DesignateOp>::getBaseEntity( + mlir::Operation *op) const { + return mlir::cast<hlfir::DesignateOp>(op).getMemref(); +} + +mlir::Value PartialEntityAccessModel<fir::DeclareOp>::getBaseEntity( + mlir::Operation *op) const { + auto declareOp = mlir::cast<fir::DeclareOp>(op); + // If storage is present, return it (partial view case) + if (mlir::Value storage = declareOp.getStorage()) + return storage; + // Otherwise return the memref (complete view case) + return declareOp.getMemref(); +} + +bool PartialEntityAccessModel<fir::DeclareOp>::isCompleteView( + mlir::Operation *op) const { + // Complete view if storage is absent + return !mlir::cast<fir::DeclareOp>(op).getStorage(); +} + +mlir::Value PartialEntityAccessModel<hlfir::DeclareOp>::getBaseEntity( + mlir::Operation *op) const { + auto declareOp = mlir::cast<hlfir::DeclareOp>(op); + // If storage is present, return it (partial view case) + if (mlir::Value 
storage = declareOp.getStorage()) + return storage; + // Otherwise return the memref (complete view case) + return declareOp.getMemref(); +} + +bool PartialEntityAccessModel<hlfir::DeclareOp>::isCompleteView( + mlir::Operation *op) const { + // Complete view if storage is absent + return !mlir::cast<hlfir::DeclareOp>(op).getStorage(); +} + +mlir::SymbolRefAttr AddressOfGlobalModel::getSymbol(mlir::Operation *op) const { + return mlir::cast<fir::AddrOfOp>(op).getSymbolAttr(); +} + +bool GlobalVariableModel::isConstant(mlir::Operation *op) const { + auto globalOp = mlir::cast<fir::GlobalOp>(op); + return globalOp.getConstant().has_value(); +} + +mlir::Region *GlobalVariableModel::getInitRegion(mlir::Operation *op) const { + auto globalOp = mlir::cast<fir::GlobalOp>(op); + return globalOp.hasInitializationBody() ? &globalOp.getRegion() : nullptr; +} + +bool GlobalVariableModel::isDeviceData(mlir::Operation *op) const { + if (auto dataAttr = cuf::getDataAttr(op)) + return cuf::isDeviceDataAttribute(dataAttr.getValue()); + return false; +} + +// Helper to recursively process address-of operations in derived type +// descriptors and collect all needed fir.globals. +static void processAddrOfOpInDerivedTypeDescriptor( + fir::AddrOfOp addrOfOp, mlir::SymbolTable &symTab, + llvm::SmallSet<mlir::Operation *, 16> &globalsSet, + llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols) { + if (auto globalOp = symTab.lookup<fir::GlobalOp>( + addrOfOp.getSymbol().getLeafReference().getValue())) { + if (globalsSet.contains(globalOp)) + return; + globalsSet.insert(globalOp); + symbols.push_back(addrOfOp.getSymbolAttr()); + globalOp.walk([&](fir::AddrOfOp op) { + processAddrOfOpInDerivedTypeDescriptor(op, symTab, globalsSet, symbols); + }); + } +} + +// Utility to collect referenced symbols for type descriptors of derived types. +// This is the common logic for operations that may require type descriptor +// globals. 
+static void collectReferencedSymbolsForType( + mlir::Type ty, mlir::Operation *op, + llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols, + mlir::SymbolTable *symbolTable) { + ty = fir::getDerivedType(fir::unwrapRefType(ty)); + + // Look for type descriptor globals only if it's a derived (record) type + if (auto recTy = mlir::dyn_cast_if_present<fir::RecordType>(ty)) { + // If no symbol table provided, simply add the type descriptor name + if (!symbolTable) { + symbols.push_back(mlir::SymbolRefAttr::get( + op->getContext(), + fir::NameUniquer::getTypeDescriptorName(recTy.getName()))); + return; + } + + // Otherwise, do full lookup and recursive processing + llvm::SmallSet<mlir::Operation *, 16> globalsSet; + + fir::GlobalOp globalOp = symbolTable->lookup<fir::GlobalOp>( + fir::NameUniquer::getTypeDescriptorName(recTy.getName())); + if (!globalOp) + globalOp = symbolTable->lookup<fir::GlobalOp>( + fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName())); + + if (globalOp) { + globalsSet.insert(globalOp); + symbols.push_back( + mlir::SymbolRefAttr::get(op->getContext(), globalOp.getSymName())); + globalOp.walk([&](fir::AddrOfOp addrOp) { + processAddrOfOpInDerivedTypeDescriptor(addrOp, *symbolTable, globalsSet, + symbols); + }); + } + } +} + +template <> +void IndirectGlobalAccessModel<fir::AllocaOp>::getReferencedSymbols( + mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols, + mlir::SymbolTable *symbolTable) const { + auto allocaOp = mlir::cast<fir::AllocaOp>(op); + collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable); +} + +template <> +void IndirectGlobalAccessModel<fir::EmboxOp>::getReferencedSymbols( + mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols, + mlir::SymbolTable *symbolTable) const { + auto emboxOp = mlir::cast<fir::EmboxOp>(op); + collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols, + symbolTable); +} + +template <> +void 
IndirectGlobalAccessModel<fir::ReboxOp>::getReferencedSymbols( + mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols, + mlir::SymbolTable *symbolTable) const { + auto reboxOp = mlir::cast<fir::ReboxOp>(op); + collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols, + symbolTable); +} + +template <> +void IndirectGlobalAccessModel<fir::TypeDescOp>::getReferencedSymbols( + mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols, + mlir::SymbolTable *symbolTable) const { + auto typeDescOp = mlir::cast<fir::TypeDescOp>(op); + collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols, + symbolTable); +} + +template <> +bool OperationMoveModel<mlir::acc::LoopOp>::canMoveFromDescendant( + mlir::Operation *op, mlir::Operation *descendant, + mlir::Operation *candidate) const { + // It should be always allowed to move operations from descendants + // of acc.loop into the acc.loop. + return true; +} + +template <> +bool OperationMoveModel<mlir::acc::LoopOp>::canMoveOutOf( + mlir::Operation *op, mlir::Operation *candidate) const { + // Disallow moving operations, which have operands that are referenced + // in the data operands (e.g. in [first]private() etc.) of the acc.loop. + // For example: + // %17 = acc.private var(%16 : !fir.box<!fir.array<?xf32>>) + // acc.loop private(%17 : !fir.box<!fir.array<?xf32>>) ... { + // %19 = fir.box_addr %17 + // } + // We cannot hoist %19 without violating assumptions that OpenACC + // transformations rely on. + + // In general, some movement out of acc.loop is allowed, + // so return true if candidate is nullptr. 
+ if (!candidate) + return true; + + auto loopOp = mlir::cast<mlir::acc::LoopOp>(op); + unsigned numDataOperands = loopOp.getNumDataOperands(); + for (unsigned i = 0; i < numDataOperands; ++i) { + mlir::Value dataOperand = loopOp.getDataOperand(i); + if (llvm::any_of(candidate->getOperands(), + [&](mlir::Value candidateOperand) { + return dataOperand == candidateOperand; + })) + return false; + } + return true; +} + +} // namespace fir::acc diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp index ed9e41c..9ced235 100644 --- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp @@ -15,12 +15,15 @@ #include "flang/Optimizer/Builder/DirectivesCommon.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/HLFIRTools.h" +#include "flang/Optimizer/Builder/IntrinsicCall.h" +#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h" #include "flang/Optimizer/Dialect/FIRCG/CGOps.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/Dialect/Support/KindMapping.h" +#include "flang/Optimizer/OpenACC/Support/FIROpenACCUtils.h" #include "flang/Optimizer/Support/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/OpenACC/OpenACC.h" @@ -193,6 +196,28 @@ OpenACCMappableModel<fir::PointerType>::getOffsetInBytes( mlir::Type type, mlir::Value var, mlir::ValueRange accBounds, const mlir::DataLayout &dataLayout) const; +template <typename Ty> +bool OpenACCMappableModel<Ty>::hasUnknownDimensions(mlir::Type type) const { + assert(fir::isa_ref_type(type) && "expected FIR reference type"); + return fir::hasDynamicSize(fir::unwrapRefType(type)); +} + +template bool 
OpenACCMappableModel<fir::ReferenceType>::hasUnknownDimensions( + mlir::Type type) const; + +template bool OpenACCMappableModel<fir::HeapType>::hasUnknownDimensions( + mlir::Type type) const; + +template bool OpenACCMappableModel<fir::PointerType>::hasUnknownDimensions( + mlir::Type type) const; + +template <> +bool OpenACCMappableModel<fir::BaseBoxType>::hasUnknownDimensions( + mlir::Type type) const { + // Descriptor-based entities have dimensions encoded. + return false; +} + static llvm::SmallVector<mlir::Value> generateSeqTyAccBounds(fir::SequenceType seqType, mlir::Value var, mlir::OpBuilder &builder) { @@ -202,48 +227,53 @@ generateSeqTyAccBounds(fir::SequenceType seqType, mlir::Value var, fir::FirOpBuilder firBuilder(builder, var.getDefiningOp()); mlir::Location loc = var.getLoc(); - if (seqType.hasDynamicExtents() || seqType.hasUnknownShape()) { - if (auto boxAddr = - mlir::dyn_cast_if_present<fir::BoxAddrOp>(var.getDefiningOp())) { - mlir::Value box = boxAddr.getVal(); - auto res = - hlfir::translateToExtendedValue(loc, firBuilder, hlfir::Entity(box)); - fir::ExtendedValue exv = res.first; - mlir::Value boxRef = box; - if (auto boxPtr = mlir::cast<mlir::acc::MappableType>(box.getType()) - .getVarPtr(box)) { - boxRef = boxPtr; + // If [hl]fir.declare is visible, extract the bounds from the declaration's + // shape (if it is provided). 
+ if (mlir::isa<hlfir::DeclareOp, fir::DeclareOp>(var.getDefiningOp())) { + mlir::Value zero = + firBuilder.createIntegerConstant(loc, builder.getIndexType(), 0); + mlir::Value one = + firBuilder.createIntegerConstant(loc, builder.getIndexType(), 1); + + mlir::Value shape; + if (auto declareOp = + mlir::dyn_cast_if_present<fir::DeclareOp>(var.getDefiningOp())) + shape = declareOp.getShape(); + else if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>( + var.getDefiningOp())) + shape = declareOp.getShape(); + + const bool strideIncludeLowerExtent = true; + + llvm::SmallVector<mlir::Value> accBounds; + mlir::Operation *anyShapeOp = shape ? shape.getDefiningOp() : nullptr; + if (auto shapeOp = mlir::dyn_cast_if_present<fir::ShapeOp>(anyShapeOp)) { + mlir::Value cummulativeExtent = one; + for (auto extent : shapeOp.getExtents()) { + mlir::Value upperbound = + mlir::arith::SubIOp::create(builder, loc, extent, one); + mlir::Value stride = one; + if (strideIncludeLowerExtent) { + stride = cummulativeExtent; + cummulativeExtent = mlir::arith::MulIOp::create( + builder, loc, cummulativeExtent, extent); + } + auto accBound = mlir::acc::DataBoundsOp::create( + builder, loc, mlir::acc::DataBoundsType::get(builder.getContext()), + /*lowerbound=*/zero, /*upperbound=*/upperbound, + /*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false, + /*startIdx=*/one); + accBounds.push_back(accBound); } - // TODO: Handle Fortran optional. 
- const mlir::Value isPresent; - fir::factory::AddrAndBoundsInfo info(box, boxRef, isPresent, - box.getType()); - return fir::factory::genBoundsOpsFromBox<mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>( - firBuilder, loc, exv, info); - } - - if (mlir::isa<hlfir::DeclareOp, fir::DeclareOp>(var.getDefiningOp())) { - mlir::Value zero = - firBuilder.createIntegerConstant(loc, builder.getIndexType(), 0); - mlir::Value one = - firBuilder.createIntegerConstant(loc, builder.getIndexType(), 1); - - mlir::Value shape; - if (auto declareOp = - mlir::dyn_cast_if_present<fir::DeclareOp>(var.getDefiningOp())) - shape = declareOp.getShape(); - else if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>( - var.getDefiningOp())) - shape = declareOp.getShape(); - - const bool strideIncludeLowerExtent = true; - - llvm::SmallVector<mlir::Value> accBounds; - if (auto shapeOp = - mlir::dyn_cast_if_present<fir::ShapeOp>(shape.getDefiningOp())) { - mlir::Value cummulativeExtent = one; - for (auto extent : shapeOp.getExtents()) { + } else if (auto shapeShiftOp = + mlir::dyn_cast_if_present<fir::ShapeShiftOp>(anyShapeOp)) { + mlir::Value lowerbound; + mlir::Value cummulativeExtent = one; + for (auto [idx, val] : llvm::enumerate(shapeShiftOp.getPairs())) { + if (idx % 2 == 0) { + lowerbound = val; + } else { + mlir::Value extent = val; mlir::Value upperbound = mlir::arith::SubIOp::create(builder, loc, extent, one); mlir::Value stride = one; @@ -257,40 +287,48 @@ generateSeqTyAccBounds(fir::SequenceType seqType, mlir::Value var, mlir::acc::DataBoundsType::get(builder.getContext()), /*lowerbound=*/zero, /*upperbound=*/upperbound, /*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false, - /*startIdx=*/one); + /*startIdx=*/lowerbound); accBounds.push_back(accBound); } - } else if (auto shapeShiftOp = - mlir::dyn_cast_if_present<fir::ShapeShiftOp>( - shape.getDefiningOp())) { - mlir::Value lowerbound; - mlir::Value cummulativeExtent = one; - for (auto [idx, val] : 
llvm::enumerate(shapeShiftOp.getPairs())) { - if (idx % 2 == 0) { - lowerbound = val; - } else { - mlir::Value extent = val; - mlir::Value upperbound = - mlir::arith::SubIOp::create(builder, loc, extent, one); - mlir::Value stride = one; - if (strideIncludeLowerExtent) { - stride = cummulativeExtent; - cummulativeExtent = mlir::arith::MulIOp::create( - builder, loc, cummulativeExtent, extent); - } - auto accBound = mlir::acc::DataBoundsOp::create( - builder, loc, - mlir::acc::DataBoundsType::get(builder.getContext()), - /*lowerbound=*/zero, /*upperbound=*/upperbound, - /*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false, - /*startIdx=*/lowerbound); - accBounds.push_back(accBound); - } - } } + } + + if (!accBounds.empty()) + return accBounds; + } + + if (seqType.hasDynamicExtents() || seqType.hasUnknownShape()) { + mlir::Value box; + bool mayBeOptional = false; + if (auto boxAddr = + mlir::dyn_cast_if_present<fir::BoxAddrOp>(var.getDefiningOp())) { + box = boxAddr.getVal(); + // Since fir.box_addr already accesses the box, we do not care + // checking if it is optional. + } else if (mlir::isa<fir::BaseBoxType>(var.getType())) { + box = var; + mayBeOptional = fir::mayBeAbsentBox(box); + } + + if (box) { + auto res = + hlfir::translateToExtendedValue(loc, firBuilder, hlfir::Entity(box)); + fir::ExtendedValue exv = res.first; + mlir::Value boxRef = box; + if (auto boxPtr = + mlir::cast<mlir::acc::MappableType>(box.getType()).getVarPtr(box)) + boxRef = boxPtr; + + mlir::Value isPresent = + !mayBeOptional ? 
mlir::Value{} + : fir::IsPresentOp::create(builder, loc, + builder.getI1Type(), box); - if (!accBounds.empty()) - return accBounds; + fir::factory::AddrAndBoundsInfo info(box, boxRef, isPresent, + box.getType()); + return fir::factory::genBoundsOpsFromBox<mlir::acc::DataBoundsOp, + mlir::acc::DataBoundsType>( + firBuilder, loc, exv, info); } assert(false && "array with unknown dimension expected to have descriptor"); @@ -353,7 +391,7 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) { // calculation op. mlir::Value baseRef = llvm::TypeSwitch<mlir::Operation *, mlir::Value>(op) - .Case<fir::DeclareOp>([&](auto op) { + .Case([&](fir::DeclareOp op) { // If this declare binds a view with an underlying storage operand, // treat that storage as the base reference. Otherwise, fall back // to the declared memref. @@ -361,7 +399,7 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) { return storage; return mlir::Value(varPtr); }) - .Case<hlfir::DesignateOp>([&](auto op) { + .Case([&](hlfir::DesignateOp op) { // Get the base object. return op.getMemref(); }) @@ -369,12 +407,12 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) { // Get the base array on which the coordinate is being applied. return op.getMemref(); }) - .Case<fir::CoordinateOp>([&](auto op) { + .Case([&](fir::CoordinateOp op) { // For coordinate operation which is applied on derived type // object, get the base object. 
return op.getRef(); }) - .Case<fir::ConvertOp>([&](auto op) -> mlir::Value { + .Case([&](fir::ConvertOp op) -> mlir::Value { // Strip the conversion and recursively check the operand if (auto ptrLikeOperand = mlir::dyn_cast_if_present< mlir::TypedValue<mlir::acc::PointerLikeType>>( @@ -543,30 +581,141 @@ OpenACCPointerLikeModel<fir::LLVMPointerType>::getPointeeTypeCategory( return categorizePointee(pointer, varPtr, varType); } -static fir::ShapeOp genShapeOp(mlir::OpBuilder &builder, - fir::SequenceType seqTy, mlir::Location loc) { +static hlfir::Entity +genDesignateWithTriplets(fir::FirOpBuilder &builder, mlir::Location loc, + hlfir::Entity &entity, + hlfir::DesignateOp::Subscripts &triplets, + mlir::Value shape, mlir::ValueRange extents) { + llvm::SmallVector<mlir::Value> lenParams; + hlfir::genLengthParameters(loc, builder, entity, lenParams); + + // Compute result type of array section. + fir::SequenceType::Shape resultTypeShape; + bool shapeIsConstant = true; + for (mlir::Value extent : extents) { + if (std::optional<std::int64_t> cst_extent = + fir::getIntIfConstant(extent)) { + resultTypeShape.push_back(*cst_extent); + } else { + resultTypeShape.push_back(fir::SequenceType::getUnknownExtent()); + shapeIsConstant = false; + } + } + assert(!resultTypeShape.empty() && + "expect private sections to always represented as arrays"); + mlir::Type eleTy = entity.getFortranElementType(); + auto seqTy = fir::SequenceType::get(resultTypeShape, eleTy); + bool isVolatile = fir::isa_volatile_type(entity.getType()); + bool resultNeedsBox = + llvm::isa<fir::BaseBoxType>(entity.getType()) || !shapeIsConstant; + bool isPolymorphic = fir::isPolymorphicType(entity.getType()); + mlir::Type resultType; + if (isPolymorphic) { + resultType = fir::ClassType::get(seqTy, isVolatile); + } else if (resultNeedsBox) { + resultType = fir::BoxType::get(seqTy, isVolatile); + } else { + resultType = fir::ReferenceType::get(seqTy, isVolatile); + } + + // Generate section with hlfir.designate. 
+ auto designate = hlfir::DesignateOp::create( + builder, loc, resultType, entity, /*component=*/"", + /*componentShape=*/mlir::Value{}, triplets, + /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, shape, + lenParams); + return hlfir::Entity{designate.getResult()}; +} + +// Designate uses triplets based on object lower bounds while acc.bounds are +// zero based. This helper shift the bounds to create the designate triplets. +static hlfir::DesignateOp::Subscripts +genTripletsFromAccBounds(fir::FirOpBuilder &builder, mlir::Location loc, + const llvm::SmallVector<mlir::Value> &accBounds, + hlfir::Entity entity) { + assert(entity.getRank() * 3 == static_cast<int>(accBounds.size()) && + "must get lb,ub,step for each dimension"); + hlfir::DesignateOp::Subscripts triplets; + for (unsigned i = 0; i < accBounds.size(); i += 3) { + mlir::Value lb = hlfir::genLBound(loc, builder, entity, i / 3); + lb = builder.createConvert(loc, accBounds[i].getType(), lb); + assert(accBounds[i].getType() == accBounds[i + 1].getType() && + "mix of integer types in triplets"); + mlir::Value sliceLB = + builder.createOrFold<mlir::arith::AddIOp>(loc, accBounds[i], lb); + mlir::Value sliceUB = + builder.createOrFold<mlir::arith::AddIOp>(loc, accBounds[i + 1], lb); + triplets.emplace_back( + hlfir::DesignateOp::Triplet{sliceLB, sliceUB, accBounds[i + 2]}); + } + return triplets; +} + +static std::pair<mlir::Value, llvm::SmallVector<mlir::Value>> +computeSectionShapeAndExtents(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::ValueRange bounds) { llvm::SmallVector<mlir::Value> extents; + // Compute the fir.shape of the array section and the triplets to create + // hlfir.designate. 
mlir::Type idxTy = builder.getIndexType(); - for (auto extent : seqTy.getShape()) - extents.push_back(mlir::arith::ConstantOp::create( - builder, loc, idxTy, builder.getIntegerAttr(idxTy, extent))); - return fir::ShapeOp::create(builder, loc, extents); + for (unsigned i = 0; i + 2 < bounds.size(); i += 3) + extents.push_back(builder.genExtentFromTriplet( + loc, bounds[i], bounds[i + 1], bounds[i + 2], idxTy, /*fold=*/true)); + mlir::Value shape = fir::ShapeOp::create(builder, loc, extents); + return {shape, extents}; +} + +static std::pair<hlfir::Entity, hlfir::Entity> +genArraySectionsInRecipe(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::ValueRange bounds, hlfir::Entity lhs, + hlfir::Entity rhs) { + assert(lhs.getRank() * 3 == static_cast<int>(bounds.size()) && + "must get lb,ub,step for each dimension"); + lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs); + rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs); + // Get the list of lb,ub,step values for the sections that can be used inside + // the recipe region. + auto [shape, extents] = computeSectionShapeAndExtents(builder, loc, bounds); + hlfir::DesignateOp::Subscripts rhsTriplets = + genTripletsFromAccBounds(builder, loc, bounds, rhs); + hlfir::DesignateOp::Subscripts lhsTriplets; + // Share the bounds when both rhs/lhs are known to be 1-based to avoid noise + // in the IR for the most common cases. 
+ if (!lhs.mayHaveNonDefaultLowerBounds() && + !rhs.mayHaveNonDefaultLowerBounds()) + lhsTriplets = rhsTriplets; + else + lhsTriplets = genTripletsFromAccBounds(builder, loc, bounds, lhs); + hlfir::Entity leftSection = + genDesignateWithTriplets(builder, loc, lhs, lhsTriplets, shape, extents); + hlfir::Entity rightSection = + genDesignateWithTriplets(builder, loc, rhs, rhsTriplets, shape, extents); + return {leftSection, rightSection}; +} + +static bool boundsAreAllConstants(mlir::ValueRange bounds) { + for (mlir::Value bound : bounds) + if (!fir::getIntIfConstant(bound).has_value()) + return false; + return true; } template <typename Ty> mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit( - mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::OpBuilder &mlirBuilder, mlir::Location loc, mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName, - mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const { - needsDestroy = false; - mlir::Value retVal; - mlir::Type unwrappedTy = fir::unwrapRefType(type); - mlir::ModuleOp mod = builder.getInsertionBlock() + mlir::ValueRange bounds, mlir::Value initVal, bool &needsDestroy) const { + mlir::ModuleOp mod = mlirBuilder.getInsertionBlock() ->getParent() ->getParentOfType<mlir::ModuleOp>(); - - if (auto recType = llvm::dyn_cast<fir::RecordType>( - fir::getFortranElementType(unwrappedTy))) { + assert(mod && "failed to retrieve ModuleOp"); + fir::FirOpBuilder builder(mlirBuilder, mod); + + hlfir::Entity inputVar = hlfir::Entity{var}; + if (inputVar.isPolymorphic()) + TODO(loc, "OpenACC: polymorphic variable privatization"); + if (auto recType = + llvm::dyn_cast<fir::RecordType>(inputVar.getFortranElementType())) { // Need to make deep copies of allocatable components. 
if (fir::isRecordWithAllocatableMember(recType)) TODO(loc, @@ -575,117 +724,161 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit( if (fir::isRecordWithFinalRoutine(recType, mod).value_or(false)) TODO(loc, "OpenACC: privatizing derived type with user assignment or " "final routine "); + // Pointer components needs to be initialized to NULL() for private-like + // recipes. + if (fir::isRecordWithDescriptorMember(recType)) + TODO(loc, "OpenACC: privatizing derived type with pointer components"); + } + bool isPointerOrAllocatable = inputVar.isMutableBox(); + hlfir::Entity dereferencedVar = + hlfir::derefPointersAndAllocatables(loc, builder, inputVar); + + // Step 1: Gather the address, shape, extents, and lengths parameters of the + // entity being privatized. Designate the array section if only a section is + // privatized, otherwise just use the original variable. + hlfir::Entity privatizedVar = dereferencedVar; + mlir::Value tempShape; + llvm::SmallVector<mlir::Value> tempExtents; + // TODO: while it seems best to allocate as little memory as possible and + // allocate only the storage for the section, this may actually have drawbacks + // when the array has static size and can be privatized with an alloca while + // the section size is dynamic and requires an dynamic allocmem. Hence, we + // currently allocate the full array storage in such cases. This could be + // improved via some kind of threshold if the base array size is large enough + // to justify doing a dynamic allocation with the hope that it is much + // smaller. 
+ bool allocateSection = false; + bool isDynamicSectionOfStaticSizeArray = + !bounds.empty() && + !fir::hasDynamicSize(dereferencedVar.getElementOrSequenceType()) && + !boundsAreAllConstants(bounds); + if (!bounds.empty() && !isDynamicSectionOfStaticSizeArray) { + allocateSection = true; + hlfir::DesignateOp::Subscripts triplets; + std::tie(tempShape, tempExtents) = + computeSectionShapeAndExtents(builder, loc, bounds); + triplets = genTripletsFromAccBounds(builder, loc, bounds, dereferencedVar); + privatizedVar = genDesignateWithTriplets(builder, loc, dereferencedVar, + triplets, tempShape, tempExtents); + } else if (privatizedVar.getRank() > 0) { + mlir::Value shape = hlfir::genShape(loc, builder, privatizedVar); + tempExtents = hlfir::getExplicitExtentsFromShape(shape, builder); + tempShape = fir::ShapeOp::create(builder, loc, tempExtents); + } + llvm::SmallVector<mlir::Value> typeParams; + hlfir::genLengthParameters(loc, builder, privatizedVar, typeParams); + mlir::Type baseType = privatizedVar.getElementOrSequenceType(); + // Step2: Create a temporary allocation for the privatized part. + mlir::Value alloc; + if (fir::hasDynamicSize(baseType) || + (isPointerOrAllocatable && bounds.empty())) { + // Note: heap allocation is forced for whole pointers/allocatable so that + // the private POINTER/ALLOCATABLE can be deallocated/reallocated on the + // device inside the compute region. It may not be a requirement, and this + // could be revisited. In practice, this only matters for scalars since + // array POINTER and ALLOCATABLE always have dynamic size. Constant sections + // of POINTER/ALLOCATABLE can use alloca since only part of the data is + // privatized (it makes no sense to deallocate them). 
+ alloc = builder.createHeapTemporary(loc, baseType, varName, tempExtents, + typeParams); + needsDestroy = true; + } else { + alloc = builder.createTemporary(loc, baseType, varName, tempExtents, + typeParams); + } + // Step3: Assign the initial value to the privatized part if any. + if (initVal) { + mlir::Value tempEntity = alloc; + if (fir::hasDynamicSize(baseType)) + tempEntity = + fir::EmboxOp::create(builder, loc, fir::BoxType::get(baseType), alloc, + tempShape, /*slice=*/mlir::Value{}, typeParams); + hlfir::genNoAliasAssignment( + loc, builder, hlfir::Entity{initVal}, hlfir::Entity{tempEntity}, + /*emitWorkshareLoop=*/false, /*temporaryLHS=*/true); } - fir::FirOpBuilder firBuilder(builder, mod); - auto getDeclareOpForType = [&](mlir::Type ty) -> hlfir::DeclareOp { - auto alloca = fir::AllocaOp::create(firBuilder, loc, ty); - return hlfir::DeclareOp::create(firBuilder, loc, alloca, varName); - }; + // Making a dynamic allocation of the size of the whole base instead of the + // section in case of section would lead to improper deallocation because + // generatePrivateDestroy always deallocates the start of the section when + // there is a section. + assert(!(needsDestroy && !bounds.empty() && !allocateSection) && + "dynamic allocation of the whole base in case of section is not " + "expected"); + + if (inputVar.getType() == alloc.getType() && !allocateSection) + return alloc; + + // Step4: reconstruct the input variable from the privatized part: + // - get a mock base address if the privatized part is a section (so that any + // addressing of the input variable can be replaced by the same addressing of + // the privatized part even though the allocated part for the private does not + // cover all the input variable storage. This is relying on OpenACC + // constraint that any addressing of such privatized variable inside the + // construct region can only address the variable inside the privatized + // section). 
+ // - reconstruct a descriptor with the same bounds and type parameters as the + // input if needed. + // - store this new descriptor in a temporary allocation if the input variable + // is a POINTER/ALLOCATABLE. + llvm::SmallVector<mlir::Value> inputVarLowerBounds, inputVarExtents; + if (dereferencedVar.isArray()) { + for (int dim = 0; dim < dereferencedVar.getRank(); ++dim) { + inputVarLowerBounds.push_back( + hlfir::genLBound(loc, builder, dereferencedVar, dim)); + inputVarExtents.push_back( + hlfir::genExtent(loc, builder, dereferencedVar, dim)); + } + } - if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(unwrappedTy)) { - if (fir::isa_trivial(seqTy.getEleTy())) { - mlir::Value shape; - if (seqTy.hasDynamicExtents()) { - shape = fir::ShapeOp::create(firBuilder, loc, llvm::to_vector(extents)); - } else { - shape = genShapeOp(firBuilder, seqTy, loc); - } - auto alloca = fir::AllocaOp::create( - firBuilder, loc, seqTy, /*typeparams=*/mlir::ValueRange{}, extents); - auto declareOp = - hlfir::DeclareOp::create(firBuilder, loc, alloca, varName, shape); - - if (initVal) { - mlir::Type idxTy = firBuilder.getIndexType(); - mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); - llvm::SmallVector<fir::DoLoopOp> loops; - llvm::SmallVector<mlir::Value> ivs; - - if (seqTy.hasDynamicExtents()) { - hlfir::AssignOp::create(firBuilder, loc, initVal, - declareOp.getBase()); - } else { - // Generate loop nest from slowest to fastest running dimension - for (auto ext : llvm::reverse(seqTy.getShape())) { - auto lb = firBuilder.createIntegerConstant(loc, idxTy, 0); - auto ub = firBuilder.createIntegerConstant(loc, idxTy, ext - 1); - auto step = firBuilder.createIntegerConstant(loc, idxTy, 1); - auto loop = fir::DoLoopOp::create(firBuilder, loc, lb, ub, step, - /*unordered=*/false); - firBuilder.setInsertionPointToStart(loop.getBody()); - loops.push_back(loop); - ivs.push_back(loop.getInductionVar()); - } - // Reverse IVs to match CoordinateOp's canonical index 
order. - std::reverse(ivs.begin(), ivs.end()); - auto coord = fir::CoordinateOp::create(firBuilder, loc, refTy, - declareOp.getBase(), ivs); - fir::StoreOp::create(firBuilder, loc, initVal, coord); - firBuilder.setInsertionPointAfter(loops[0]); - } - } - retVal = declareOp.getBase(); + mlir::Value privateVarBaseAddr = alloc; + if (allocateSection) { + // To compute the mock base address without doing pointer arithmetic, + // compute: TYPE, TEMP(ZERO_BASED_SECTION_LB:) MOCK_BASE = TEMP(0) + // This addresses the section "backwards" (0 <= ZERO_BASED_SECTION_LB). This + // is currently OK, but care should be taken to avoid tripping bound checks + // if added in the future. + mlir::Type inputBaseAddrType = + dereferencedVar.getBoxType().getBaseAddressType(); + mlir::Value tempBaseAddr = + builder.createConvert(loc, inputBaseAddrType, alloc); + mlir::Value zero = + builder.createIntegerConstant(loc, builder.getIndexType(), 0); + llvm::SmallVector<mlir::Value> lowerBounds; + llvm::SmallVector<mlir::Value> zeros; + for (unsigned i = 0; i < bounds.size(); i += 3) { + lowerBounds.push_back(bounds[i]); + zeros.push_back(zero); } - } else if (auto boxTy = - mlir::dyn_cast_or_null<fir::BaseBoxType>(unwrappedTy)) { - mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy()); - if (fir::isa_trivial(innerTy)) { - retVal = getDeclareOpForType(unwrappedTy).getBase(); - mlir::Value allocatedScalar = - fir::AllocMemOp::create(builder, loc, innerTy); - mlir::Value firClass = - fir::EmboxOp::create(builder, loc, boxTy, allocatedScalar); - fir::StoreOp::create(builder, loc, firClass, retVal); - needsDestroy = true; - } else if (mlir::isa<fir::SequenceType>(innerTy)) { - hlfir::Entity source = hlfir::Entity{var}; - auto [temp, cleanupFlag] = - hlfir::createTempFromMold(loc, firBuilder, source); - if (fir::isa_ref_type(type)) { - // When the temp is created - it is not a reference - thus we can - // end up with a type inconsistency. Therefore ensure storage is created - // for it. 
- retVal = getDeclareOpForType(unwrappedTy).getBase(); - mlir::Value storeDst = retVal; - if (fir::unwrapRefType(retVal.getType()) != temp.getType()) { - // `createTempFromMold` makes the unfortunate choice to lose the - // `fir.heap` and `fir.ptr` types when wrapping with a box. Namely, - // when wrapping a `fir.heap<fir.array>`, it will create instead a - // `fir.box<fir.array>`. Cast here to deal with this inconsistency. - storeDst = firBuilder.createConvert( - loc, firBuilder.getRefType(temp.getType()), retVal); - } - fir::StoreOp::create(builder, loc, temp, storeDst); - } else { - retVal = temp; - } - // If heap was allocated, a destroy is required later. - if (cleanupFlag) - needsDestroy = true; + mlir::Value offsetShapeShift = + builder.genShape(loc, lowerBounds, inputVarExtents); + mlir::Type eleRefType = + builder.getRefType(privatizedVar.getFortranElementType()); + mlir::Value mockBase = fir::ArrayCoorOp::create( + builder, loc, eleRefType, tempBaseAddr, offsetShapeShift, + /*slice=*/mlir::Value{}, /*indices=*/zeros, + /*typeParams=*/mlir::ValueRange{}); + privateVarBaseAddr = + builder.createConvert(loc, inputBaseAddrType, mockBase); + } + + mlir::Value retVal = privateVarBaseAddr; + if (inputVar.isBoxAddressOrValue()) { + // Recreate descriptor with same bounds as the input variable. 
+ mlir::Value shape; + if (!inputVarExtents.empty()) + shape = builder.genShape(loc, inputVarLowerBounds, inputVarExtents); + mlir::Value box = fir::EmboxOp::create(builder, loc, inputVar.getBoxType(), + privateVarBaseAddr, shape, + /*slice=*/mlir::Value{}, typeParams); + if (inputVar.isMutableBox()) { + mlir::Value boxAlloc = + fir::AllocaOp::create(builder, loc, inputVar.getBoxType()); + fir::StoreOp::create(builder, loc, box, boxAlloc); + retVal = boxAlloc; } else { - TODO(loc, "Unsupported boxed type for OpenACC private-like recipe"); - } - if (initVal) { - hlfir::AssignOp::create(builder, loc, initVal, retVal); + retVal = box; } - } else if (llvm::isa<fir::BoxCharType, fir::CharacterType>(unwrappedTy)) { - TODO(loc, "Character type for OpenACC private-like recipe"); - } else { - assert((fir::isa_trivial(unwrappedTy) || - llvm::isa<fir::RecordType>(unwrappedTy)) && - "expected numerical, logical, and derived type without length " - "parameters"); - auto declareOp = getDeclareOpForType(unwrappedTy); - if (initVal && fir::isa_trivial(unwrappedTy)) { - auto convert = firBuilder.createConvert(loc, unwrappedTy, initVal); - fir::StoreOp::create(firBuilder, loc, convert, declareOp.getBase()); - } else if (initVal) { - // hlfir.assign with temporary LHS flag should just do it. Not implemented - // because not clear it is needed, so cannot be tested. 
- TODO(loc, "initial value for derived type in private-like recipe"); - } - retVal = declareOp.getBase(); } return retVal; } @@ -714,42 +907,249 @@ OpenACCMappableModel<fir::PointerType>::generatePrivateInit( mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const; template <typename Ty> +bool OpenACCMappableModel<Ty>::generateCopy( + mlir::Type type, mlir::OpBuilder &mlirBuilder, mlir::Location loc, + mlir::TypedValue<mlir::acc::MappableType> src, + mlir::TypedValue<mlir::acc::MappableType> dest, + mlir::ValueRange bounds) const { + mlir::ModuleOp mod = + mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); + assert(mod && "failed to retrieve parent module"); + fir::FirOpBuilder builder(mlirBuilder, mod); + hlfir::Entity source{src}; + hlfir::Entity destination{dest}; + + source = hlfir::derefPointersAndAllocatables(loc, builder, source); + destination = hlfir::derefPointersAndAllocatables(loc, builder, destination); + + if (!bounds.empty()) + std::tie(source, destination) = + genArraySectionsInRecipe(builder, loc, bounds, source, destination); + // The source and the destination of the firstprivate copy cannot alias, + // the destination is already properly allocated, so a simple assignment + // can be generated right away to avoid ending-up with runtime calls + // for arrays of numerical, logical and, character types. + // + // The temporary_lhs flag allows indicating that user defined assignments + // should not be called while copying components, and that the LHS and RHS + // are known to not alias since the LHS is a created object. + // + // TODO: detect cases where user defined assignment is needed and add a TODO. + // using temporary_lhs allows more aggressive optimizations of simple derived + // types. Existing compilers supporting OpenACC do not call user defined + // assignments, some use case is needed to decide what to do. 
+ source = hlfir::loadTrivialScalar(loc, builder, source); + hlfir::AssignOp::create(builder, loc, source, destination, /*realloc=*/false, + /*keep_lhs_length_if_realloc=*/false, + /*temporary_lhs=*/true); + return true; +} + +template bool OpenACCMappableModel<fir::BaseBoxType>::generateCopy( + mlir::Type, mlir::OpBuilder &, mlir::Location, + mlir::TypedValue<mlir::acc::MappableType>, + mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange) const; +template bool OpenACCMappableModel<fir::ReferenceType>::generateCopy( + mlir::Type, mlir::OpBuilder &, mlir::Location, + mlir::TypedValue<mlir::acc::MappableType>, + mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange) const; +template bool OpenACCMappableModel<fir::PointerType>::generateCopy( + mlir::Type, mlir::OpBuilder &, mlir::Location, + mlir::TypedValue<mlir::acc::MappableType>, + mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange) const; +template bool OpenACCMappableModel<fir::HeapType>::generateCopy( + mlir::Type, mlir::OpBuilder &, mlir::Location, + mlir::TypedValue<mlir::acc::MappableType>, + mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange) const; + +template <typename Op> +static mlir::Value genLogicalCombiner(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value value1, + mlir::Value value2) { + mlir::Type i1 = builder.getI1Type(); + mlir::Value v1 = fir::ConvertOp::create(builder, loc, i1, value1); + mlir::Value v2 = fir::ConvertOp::create(builder, loc, i1, value2); + mlir::Value combined = Op::create(builder, loc, v1, v2); + return fir::ConvertOp::create(builder, loc, value1.getType(), combined); +} + +static mlir::Value genComparisonCombiner(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::arith::CmpIPredicate pred, + mlir::Value value1, + mlir::Value value2) { + mlir::Type i1 = builder.getI1Type(); + mlir::Value v1 = fir::ConvertOp::create(builder, loc, i1, value1); + mlir::Value v2 = fir::ConvertOp::create(builder, loc, i1, value2); + mlir::Value add 
= mlir::arith::CmpIOp::create(builder, loc, pred, v1, v2); + return fir::ConvertOp::create(builder, loc, value1.getType(), add); +} + +static mlir::Value genScalarCombiner(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::acc::ReductionOperator op, + mlir::Type ty, mlir::Value value1, + mlir::Value value2) { + value1 = builder.loadIfRef(loc, value1); + value2 = builder.loadIfRef(loc, value2); + if (op == mlir::acc::ReductionOperator::AccAdd) { + if (ty.isIntOrIndex()) + return mlir::arith::AddIOp::create(builder, loc, value1, value2); + if (mlir::isa<mlir::FloatType>(ty)) + return mlir::arith::AddFOp::create(builder, loc, value1, value2); + if (auto cmplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) + return fir::AddcOp::create(builder, loc, value1, value2); + TODO(loc, "reduction add type"); + } + + if (op == mlir::acc::ReductionOperator::AccMul) { + if (ty.isIntOrIndex()) + return mlir::arith::MulIOp::create(builder, loc, value1, value2); + if (mlir::isa<mlir::FloatType>(ty)) + return mlir::arith::MulFOp::create(builder, loc, value1, value2); + if (mlir::isa<mlir::ComplexType>(ty)) + return fir::MulcOp::create(builder, loc, value1, value2); + TODO(loc, "reduction mul type"); + } + + if (op == mlir::acc::ReductionOperator::AccMin) + return fir::genMin(builder, loc, {value1, value2}); + + if (op == mlir::acc::ReductionOperator::AccMax) + return fir::genMax(builder, loc, {value1, value2}); + + if (op == mlir::acc::ReductionOperator::AccIand) + return mlir::arith::AndIOp::create(builder, loc, value1, value2); + + if (op == mlir::acc::ReductionOperator::AccIor) + return mlir::arith::OrIOp::create(builder, loc, value1, value2); + + if (op == mlir::acc::ReductionOperator::AccXor) + return mlir::arith::XOrIOp::create(builder, loc, value1, value2); + + if (op == mlir::acc::ReductionOperator::AccLand) + return genLogicalCombiner<mlir::arith::AndIOp>(builder, loc, value1, + value2); + + if (op == mlir::acc::ReductionOperator::AccLor) + return 
genLogicalCombiner<mlir::arith::OrIOp>(builder, loc, value1, value2); + + if (op == mlir::acc::ReductionOperator::AccEqv) + return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::eq, + value1, value2); + + if (op == mlir::acc::ReductionOperator::AccNeqv) + return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::ne, + value1, value2); + + TODO(loc, "reduction operator"); +} + +template <typename Ty> +bool OpenACCMappableModel<Ty>::generateCombiner( + mlir::Type type, mlir::OpBuilder &mlirBuilder, mlir::Location loc, + mlir::TypedValue<mlir::acc::MappableType> dest, + mlir::TypedValue<mlir::acc::MappableType> source, mlir::ValueRange bounds, + mlir::acc::ReductionOperator op, mlir::Attribute fastmathFlags) const { + mlir::ModuleOp mod = + mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); + assert(mod && "failed to retrieve parent module"); + fir::FirOpBuilder builder(mlirBuilder, mod); + if (fastmathFlags) + if (auto fastMathAttr = + mlir::dyn_cast<mlir::arith::FastMathFlagsAttr>(fastmathFlags)) + builder.setFastMathFlags(fastMathAttr.getValue()); + // Generate loops that combine and assign the inputs into dest (or array + // section of the inputs when there are bounds). 
+ hlfir::Entity srcSection{source}; + hlfir::Entity destSection{dest}; + if (!bounds.empty()) { + std::tie(srcSection, destSection) = + genArraySectionsInRecipe(builder, loc, bounds, srcSection, destSection); + } + + mlir::Type elementType = fir::getFortranElementType(dest.getType()); + auto genKernel = [&](mlir::Location l, fir::FirOpBuilder &b, + hlfir::Entity srcElementValue, + hlfir::Entity destElementValue) -> hlfir::Entity { + return hlfir::Entity{genScalarCombiner(builder, loc, op, elementType, + srcElementValue, destElementValue)}; + }; + hlfir::genNoAliasAssignment(loc, builder, srcSection, destSection, + /*emitWorkshareLoop=*/false, + /*temporaryLHS=*/false, genKernel); + return true; +} + +template bool OpenACCMappableModel<fir::BaseBoxType>::generateCombiner( + mlir::Type, mlir::OpBuilder &, mlir::Location, + mlir::TypedValue<mlir::acc::MappableType>, + mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange, + mlir::acc::ReductionOperator op, mlir::Attribute) const; +template bool OpenACCMappableModel<fir::ReferenceType>::generateCombiner( + mlir::Type, mlir::OpBuilder &, mlir::Location, + mlir::TypedValue<mlir::acc::MappableType>, + mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange, + mlir::acc::ReductionOperator op, mlir::Attribute) const; +template bool OpenACCMappableModel<fir::PointerType>::generateCombiner( + mlir::Type, mlir::OpBuilder &, mlir::Location, + mlir::TypedValue<mlir::acc::MappableType>, + mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange, + mlir::acc::ReductionOperator op, mlir::Attribute) const; +template bool OpenACCMappableModel<fir::HeapType>::generateCombiner( + mlir::Type, mlir::OpBuilder &, mlir::Location, + mlir::TypedValue<mlir::acc::MappableType>, + mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange, + mlir::acc::ReductionOperator op, mlir::Attribute) const; + +template <typename Ty> bool OpenACCMappableModel<Ty>::generatePrivateDestroy( - mlir::Type type, mlir::OpBuilder &builder, 
mlir::Location loc, - mlir::Value privatized) const { - mlir::Type unwrappedTy = fir::unwrapRefType(type); - // For boxed scalars allocated with AllocMem during init, free the heap. - if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(unwrappedTy)) { - mlir::Value boxVal = privatized; - if (fir::isa_ref_type(boxVal.getType())) - boxVal = fir::LoadOp::create(builder, loc, boxVal); - mlir::Value addr = fir::BoxAddrOp::create(builder, loc, boxVal); - // FreeMem only accepts fir.heap and this may not be represented in the box - // type if the privatized entity is not an allocatable. + mlir::Type type, mlir::OpBuilder &mlirBuilder, mlir::Location loc, + mlir::Value privatized, mlir::ValueRange bounds) const { + hlfir::Entity inputVar = hlfir::Entity{privatized}; + mlir::ModuleOp mod = + mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); + assert(mod && "failed to retrieve parent module"); + fir::FirOpBuilder builder(mlirBuilder, mod); + auto genFreeRawAddress = [&](hlfir::Entity entity) { + mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, entity); mlir::Type heapType = fir::HeapType::get(fir::unwrapRefType(addr.getType())); if (heapType != addr.getType()) addr = fir::ConvertOp::create(builder, loc, heapType, addr); fir::FreeMemOp::create(builder, loc, addr); + }; + if (bounds.empty()) { + genFreeRawAddress(inputVar); return true; } - - // Nothing to do for other categories by default, they are stack allocated. + // The input variable is an array section, the base address is not the real + // allocation. Compute the section base address and deallocate that. 
+ hlfir::Entity dereferencedVar = + hlfir::derefPointersAndAllocatables(loc, builder, inputVar); + hlfir::DesignateOp::Subscripts triplets; + auto [tempShape, tempExtents] = + computeSectionShapeAndExtents(builder, loc, bounds); + (void)tempExtents; + triplets = genTripletsFromAccBounds(builder, loc, bounds, dereferencedVar); + hlfir::Entity arraySection = genDesignateWithTriplets( + builder, loc, dereferencedVar, triplets, tempShape, tempExtents); + genFreeRawAddress(arraySection); return true; } template bool OpenACCMappableModel<fir::BaseBoxType>::generatePrivateDestroy( mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, - mlir::Value privatized) const; + mlir::Value privatized, mlir::ValueRange bounds) const; template bool OpenACCMappableModel<fir::ReferenceType>::generatePrivateDestroy( mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, - mlir::Value privatized) const; + mlir::Value privatized, mlir::ValueRange bounds) const; template bool OpenACCMappableModel<fir::HeapType>::generatePrivateDestroy( mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, - mlir::Value privatized) const; + mlir::Value privatized, mlir::ValueRange bounds) const; template bool OpenACCMappableModel<fir::PointerType>::generatePrivateDestroy( mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, - mlir::Value privatized) const; + mlir::Value privatized, mlir::ValueRange bounds) const; template <typename Ty> mlir::Value OpenACCPointerLikeModel<Ty>::genAllocate( @@ -825,41 +1225,6 @@ template mlir::Value OpenACCPointerLikeModel<fir::LLVMPointerType>::genAllocate( llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar, bool &needsFree) const; -static mlir::Value stripCasts(mlir::Value value, bool stripDeclare = true) { - mlir::Value currentValue = value; - - while (currentValue) { - auto *definingOp = currentValue.getDefiningOp(); - if (!definingOp) - break; - - if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(definingOp)) { - 
currentValue = convertOp.getValue(); - continue; - } - - if (auto viewLike = mlir::dyn_cast<mlir::ViewLikeOpInterface>(definingOp)) { - currentValue = viewLike.getViewSource(); - continue; - } - - if (stripDeclare) { - if (auto declareOp = mlir::dyn_cast<hlfir::DeclareOp>(definingOp)) { - currentValue = declareOp.getMemref(); - continue; - } - - if (auto declareOp = mlir::dyn_cast<fir::DeclareOp>(definingOp)) { - currentValue = declareOp.getMemref(); - continue; - } - } - break; - } - - return currentValue; -} - template <typename Ty> bool OpenACCPointerLikeModel<Ty>::genFree( mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, @@ -887,7 +1252,7 @@ bool OpenACCPointerLikeModel<Ty>::genFree( mlir::Value valueToInspect = allocRes ? allocRes : varToFree; // Strip casts and declare operations to find the original allocation - mlir::Value strippedValue = stripCasts(valueToInspect); + mlir::Value strippedValue = fir::acc::getOriginalDef(valueToInspect); mlir::Operation *originalAlloc = strippedValue.getDefiningOp(); // If we found an AllocMemOp (heap allocation), free it @@ -992,4 +1357,232 @@ template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genCopy( mlir::TypedValue<mlir::acc::PointerLikeType> source, mlir::Type varType) const; +template <typename Ty> +mlir::Value OpenACCPointerLikeModel<Ty>::genLoad( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> srcPtr, + mlir::Type valueType) const { + + // Unwrap to get the pointee type. + mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer); + assert(pointeeTy && "expected pointee type to be extractable"); + + // Box types contain both a descriptor and referenced data. The genLoad API + // handles simple loads and cannot properly manage both parts. + if (fir::isa_box_type(pointeeTy)) + return {}; + + // Unlimited polymorphic (class(*)) cannot be handled because type is unknown. 
+ if (fir::isUnlimitedPolymorphicType(pointeeTy)) + return {}; + + // Return empty for dynamic size types because the load logic + // cannot be determined simply from the type. + if (fir::hasDynamicSize(pointeeTy)) + return {}; + + mlir::Value loadedValue = fir::LoadOp::create(builder, loc, srcPtr); + + // If valueType is provided and differs from the loaded type, insert a convert + if (valueType && loadedValue.getType() != valueType) + return fir::ConvertOp::create(builder, loc, valueType, loadedValue); + + return loadedValue; +} + +template mlir::Value OpenACCPointerLikeModel<fir::ReferenceType>::genLoad( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> srcPtr, + mlir::Type valueType) const; + +template mlir::Value OpenACCPointerLikeModel<fir::PointerType>::genLoad( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> srcPtr, + mlir::Type valueType) const; + +template mlir::Value OpenACCPointerLikeModel<fir::HeapType>::genLoad( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> srcPtr, + mlir::Type valueType) const; + +template mlir::Value OpenACCPointerLikeModel<fir::LLVMPointerType>::genLoad( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> srcPtr, + mlir::Type valueType) const; + +template <typename Ty> +bool OpenACCPointerLikeModel<Ty>::genStore( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value valueToStore, + mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const { + + // Unwrap to get the pointee type. + mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer); + assert(pointeeTy && "expected pointee type to be extractable"); + + // Box types contain both a descriptor and referenced data. 
The genStore API + // handles simple stores and cannot properly manage both parts. + if (fir::isa_box_type(pointeeTy)) + return false; + + // Unlimited polymorphic (class(*)) cannot be handled because type is unknown. + if (fir::isUnlimitedPolymorphicType(pointeeTy)) + return false; + + // Return false for dynamic size types because the store logic + // cannot be determined simply from the type. + if (fir::hasDynamicSize(pointeeTy)) + return false; + + // Get the type from the value being stored + mlir::Type valueType = valueToStore.getType(); + mlir::Value convertedValue = valueToStore; + + // If the value type differs from the pointee type, insert a convert + if (valueType != pointeeTy) + convertedValue = + fir::ConvertOp::create(builder, loc, pointeeTy, valueToStore); + + fir::StoreOp::create(builder, loc, convertedValue, destPtr); + return true; +} + +template bool OpenACCPointerLikeModel<fir::ReferenceType>::genStore( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value valueToStore, + mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const; + +template bool OpenACCPointerLikeModel<fir::PointerType>::genStore( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value valueToStore, + mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const; + +template bool OpenACCPointerLikeModel<fir::HeapType>::genStore( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value valueToStore, + mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const; + +template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genStore( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value valueToStore, + mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const; + +/// Check CUDA attributes on a function argument. 
+static bool hasCUDADeviceAttrOnFuncArg(mlir::BlockArgument blockArg) { + auto *owner = blockArg.getOwner(); + if (!owner) + return false; + + auto *parentOp = owner->getParentOp(); + if (!parentOp) + return false; + + if (auto funcLike = mlir::dyn_cast<mlir::FunctionOpInterface>(parentOp)) { + unsigned argIndex = blockArg.getArgNumber(); + if (argIndex < funcLike.getNumArguments()) + if (auto attr = funcLike.getArgAttr(argIndex, cuf::getDataAttrName())) + if (auto cudaAttr = mlir::dyn_cast<cuf::DataAttributeAttr>(attr)) + return cuf::isDeviceDataAttribute(cudaAttr.getValue()); + } + return false; +} + +/// Shared implementation for checking if a value represents device data. +static bool isDeviceDataImpl(mlir::Value var) { + // Strip casts to find the underlying value. + mlir::Value currentVal = + fir::acc::getOriginalDef(var, /*stripDeclare=*/false); + + if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(currentVal)) + return hasCUDADeviceAttrOnFuncArg(blockArg); + + mlir::Operation *defOp = currentVal.getDefiningOp(); + assert(defOp && "expected defining op for non-block-argument value"); + + // Check for CUDA attributes on the defining operation. + if (cuf::hasDeviceDataAttr(defOp)) + return true; + + // Handle operations that access a partial entity - check if the base entity + // is device data. + if (auto partialAccess = + mlir::dyn_cast<mlir::acc::PartialEntityAccessOpInterface>(defOp)) + if (mlir::Value base = partialAccess.getBaseEntity()) + return isDeviceDataImpl(base); + + // Handle fir.embox, fir.rebox, and similar ops via + // FortranObjectViewOpInterface to check if the underlying source is device + // data. + if (auto viewOp = mlir::dyn_cast<fir::FortranObjectViewOpInterface>(defOp)) + if (mlir::Value source = viewOp.getViewSource(defOp->getResult(0))) + return isDeviceDataImpl(source); + + // Handle address_of - check the referenced global. 
+ if (auto addrOfIface = + mlir::dyn_cast<mlir::acc::AddressOfGlobalOpInterface>(defOp)) { + auto symbol = addrOfIface.getSymbol(); + if (auto global = mlir::SymbolTable::lookupNearestSymbolFrom< + mlir::acc::GlobalVariableOpInterface>(defOp, symbol)) + return global.isDeviceData(); + return false; + } + + return false; +} + +template <typename Ty> +bool OpenACCPointerLikeModel<Ty>::isDeviceData(mlir::Type pointer, + mlir::Value var) const { + return isDeviceDataImpl(var); +} + +template bool OpenACCPointerLikeModel<fir::ReferenceType>::isDeviceData( + mlir::Type, mlir::Value) const; +template bool + OpenACCPointerLikeModel<fir::PointerType>::isDeviceData(mlir::Type, + mlir::Value) const; +template bool + OpenACCPointerLikeModel<fir::HeapType>::isDeviceData(mlir::Type, + mlir::Value) const; +template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::isDeviceData( + mlir::Type, mlir::Value) const; + +template <typename Ty> +bool OpenACCMappableModel<Ty>::isDeviceData(mlir::Type type, + mlir::Value var) const { + return isDeviceDataImpl(var); +} + +template bool + OpenACCMappableModel<fir::BaseBoxType>::isDeviceData(mlir::Type, + mlir::Value) const; +template bool + OpenACCMappableModel<fir::ReferenceType>::isDeviceData(mlir::Type, + mlir::Value) const; +template bool + OpenACCMappableModel<fir::HeapType>::isDeviceData(mlir::Type, + mlir::Value) const; +template bool + OpenACCMappableModel<fir::PointerType>::isDeviceData(mlir::Type, + mlir::Value) const; + +std::optional<mlir::arith::AtomicRMWKind> +OpenACCReducibleLogicalModel::getAtomicRMWKind( + mlir::Type type, mlir::acc::ReductionOperator redOp) const { + switch (redOp) { + case mlir::acc::ReductionOperator::AccLand: + return mlir::arith::AtomicRMWKind::andi; + case mlir::acc::ReductionOperator::AccLor: + return mlir::arith::AtomicRMWKind::ori; + case mlir::acc::ReductionOperator::AccEqv: + case mlir::acc::ReductionOperator::AccNeqv: + // Eqv and Neqv are valid for logical types but don't have a direct + // 
AtomicRMWKind mapping yet. + return std::nullopt; + default: + // Other reduction operators are not valid for logical types. + return std::nullopt; + } +} + } // namespace fir::acc diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCUtils.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCUtils.cpp new file mode 100644 index 0000000..a53ea92 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCUtils.cpp @@ -0,0 +1,655 @@ +//===- FIROpenACCUtils.cpp - FIR OpenACC Utilities ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utility functions for FIR OpenACC support. +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/OpenACC/Support/FIROpenACCUtils.h" +#include "flang/Optimizer/Builder/BoxValue.h" +#include "flang/Optimizer/Builder/Complex.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Dialect/Support/FIRContext.h" +#include "flang/Optimizer/Dialect/Support/KindMapping.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/Support/InternalNames.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Dialect/OpenACC/OpenACCUtils.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +static constexpr llvm::StringRef accPrivateInitName = "acc.private.init"; +static constexpr llvm::StringRef accReductionInitName = 
"acc.reduction.init"; + +std::string fir::acc::getVariableName(Value v, bool preferDemangledName) { + std::string srcName; + std::string prefix; + llvm::SmallVector<std::string, 4> arrayIndices; + bool iterate = true; + mlir::Operation *defOp; + + // For integer constants, no need to further iterate - print their value + // immediately. + if (v.getDefiningOp()) { + IntegerAttr::ValueType val; + if (matchPattern(v.getDefiningOp(), m_ConstantInt(&val))) { + llvm::raw_string_ostream os(prefix); + val.print(os, /*isSigned=*/true); + return prefix; + } + } + + while (v && (defOp = v.getDefiningOp()) && iterate) { + iterate = + llvm::TypeSwitch<mlir::Operation *, bool>(defOp) + .Case([&v](mlir::ViewLikeOpInterface op) { + v = op.getViewSource(); + return true; + }) + .Case([&v](fir::ReboxOp op) { + v = op.getBox(); + return true; + }) + .Case([&v](fir::EmboxOp op) { + v = op.getMemref(); + return true; + }) + .Case([&v](fir::ConvertOp op) { + v = op.getValue(); + return true; + }) + .Case([&v](fir::LoadOp op) { + v = op.getMemref(); + return true; + }) + .Case([&v](fir::BoxAddrOp op) { + // The box holds the name of the variable. + v = op.getVal(); + return true; + }) + .Case([&](fir::AddrOfOp op) { + // Only use address_of symbol if mangled name is preferred + if (!preferDemangledName) { + auto symRef = op.getSymbol(); + srcName = symRef.getLeafReference().getValue().str(); + } + return false; + }) + .Case([&](fir::ArrayCoorOp op) { + v = op.getMemref(); + for (auto coor : op.getIndices()) { + auto idxName = getVariableName(coor, preferDemangledName); + arrayIndices.push_back(idxName.empty() ? "?" 
: idxName); + } + return true; + }) + .Case([&](fir::CoordinateOp op) { + std::optional<llvm::ArrayRef<int32_t>> fieldIndices = + op.getFieldIndices(); + if (fieldIndices && fieldIndices->size() > 0 && + (*fieldIndices)[0] != fir::CoordinateOp::kDynamicIndex) { + int fieldId = (*fieldIndices)[0]; + mlir::Type baseType = + fir::getFortranElementType(op.getRef().getType()); + if (auto recType = llvm::dyn_cast<fir::RecordType>(baseType)) { + srcName = recType.getTypeList()[fieldId].first; + } + } + if (!srcName.empty()) { + // If the field name is known - attempt to continue building + // name by looking at its parents. + prefix = + getVariableName(op.getRef(), preferDemangledName) + "%"; + } + return false; + }) + .Case([&](hlfir::DesignateOp op) { + if (op.getComponent()) { + srcName = op.getComponent().value().str(); + prefix = + getVariableName(op.getMemref(), preferDemangledName) + "%"; + return false; + } + for (auto coor : op.getIndices()) { + auto idxName = getVariableName(coor, preferDemangledName); + arrayIndices.push_back(idxName.empty() ? "?" : idxName); + } + v = op.getMemref(); + return true; + }) + .Case<fir::DeclareOp, hlfir::DeclareOp>([&](auto op) { + srcName = op.getUniqName().str(); + return false; + }) + .Case([&](fir::AllocaOp op) { + if (preferDemangledName) { + // Prefer demangled name (bindc_name over uniq_name) + srcName = op.getBindcName() ? *op.getBindcName() + : op.getUniqName() ? *op.getUniqName() + : ""; + } else { + // Prefer mangled name (uniq_name over bindc_name) + srcName = op.getUniqName() ? *op.getUniqName() + : op.getBindcName() ? *op.getBindcName() + : ""; + } + return false; + }) + .Default([](mlir::Operation *) { return false; }); + } + + // Fallback to the default implementation. 
+ if (srcName.empty()) + return mlir::acc::getVariableName(v); + + // Build array index suffix if present + std::string suffix; + if (!arrayIndices.empty()) { + llvm::raw_string_ostream os(suffix); + os << "("; + llvm::interleaveComma(arrayIndices, os); + os << ")"; + } + + // Names from FIR operations may be mangled. + // When the demangled name is requested - demangle it. + if (preferDemangledName) { + auto [kind, deconstructed] = fir::NameUniquer::deconstruct(srcName); + if (kind != fir::NameUniquer::NameKind::NOT_UNIQUED) + return prefix + deconstructed.name + suffix; + } + + return prefix + srcName + suffix; +} + +bool fir::acc::areAllBoundsConstant(llvm::ArrayRef<Value> bounds) { + for (auto bound : bounds) { + auto dataBound = + mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); + if (!dataBound) + return false; + + // Check if this bound has constant values + bool hasConstant = false; + if (dataBound.getLowerbound() && dataBound.getUpperbound()) + hasConstant = + fir::getIntIfConstant(dataBound.getLowerbound()).has_value() && + fir::getIntIfConstant(dataBound.getUpperbound()).has_value(); + else if (dataBound.getExtent()) + hasConstant = fir::getIntIfConstant(dataBound.getExtent()).has_value(); + + if (!hasConstant) + return false; + } + return true; +} + +static std::string getBoundsString(llvm::ArrayRef<Value> bounds) { + if (bounds.empty()) + return ""; + + std::string boundStr; + llvm::raw_string_ostream os(boundStr); + os << "_section_"; + + llvm::interleave( + bounds, + [&](Value bound) { + auto boundsOp = + mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); + if (boundsOp.getLowerbound() && + fir::getIntIfConstant(boundsOp.getLowerbound()) && + boundsOp.getUpperbound() && + fir::getIntIfConstant(boundsOp.getUpperbound())) { + os << "lb" << *fir::getIntIfConstant(boundsOp.getLowerbound()) + << ".ub" << *fir::getIntIfConstant(boundsOp.getUpperbound()); + } else if (boundsOp.getExtent() && + 
fir::getIntIfConstant(boundsOp.getExtent())) { + os << "ext" << *fir::getIntIfConstant(boundsOp.getExtent()); + } else { + os << "?"; + } + }, + [&] { os << "x"; }); + + return os.str(); +} + +static std::string getRecipeName(mlir::acc::RecipeKind kind, Type type, + const fir::KindMapping &kindMap, + llvm::ArrayRef<Value> bounds, + mlir::acc::ReductionOperator reductionOp = + mlir::acc::ReductionOperator::AccNone) { + assert(fir::isa_fir_type(type) && "getRecipeName expects a FIR type"); + + // Build the complete prefix with all components before calling + // getTypeAsString + std::string prefixStr; + llvm::raw_string_ostream prefixOS(prefixStr); + + switch (kind) { + case mlir::acc::RecipeKind::private_recipe: + prefixOS << "privatization"; + break; + case mlir::acc::RecipeKind::firstprivate_recipe: + prefixOS << "firstprivatization"; + break; + case mlir::acc::RecipeKind::reduction_recipe: + prefixOS << "reduction"; + // Embed the reduction operator in the prefix + if (reductionOp != mlir::acc::ReductionOperator::AccNone) + prefixOS << "_" + << mlir::acc::stringifyReductionOperator(reductionOp).str(); + break; + } + + if (!bounds.empty()) + prefixOS << getBoundsString(bounds); + + return fir::getTypeAsString(type, kindMap, prefixOS.str()); +} + +std::string fir::acc::getRecipeName(mlir::acc::RecipeKind kind, Type type, + Value var, llvm::ArrayRef<Value> bounds, + mlir::acc::ReductionOperator reductionOp) { + auto kindMap = var && var.getDefiningOp() + ? fir::getKindMapping(var.getDefiningOp()) + : fir::KindMapping(type.getContext()); + return ::getRecipeName(kind, type, kindMap, bounds, reductionOp); +} + +/// Get the initial value for reduction operator. 
+template <typename R> +static R getReductionInitValue(mlir::acc::ReductionOperator op, mlir::Type ty) { + if (op == mlir::acc::ReductionOperator::AccMin) { + // min init value -> largest + if constexpr (std::is_same_v<R, llvm::APInt>) { + assert(ty.isIntOrIndex() && "expect integer or index type"); + return llvm::APInt::getSignedMaxValue(ty.getIntOrFloatBitWidth()); + } + if constexpr (std::is_same_v<R, llvm::APFloat>) { + auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty); + assert(floatTy && "expect float type"); + return llvm::APFloat::getLargest(floatTy.getFloatSemantics(), + /*negative=*/false); + } + } else if (op == mlir::acc::ReductionOperator::AccMax) { + // max init value -> smallest + if constexpr (std::is_same_v<R, llvm::APInt>) { + assert(ty.isIntOrIndex() && "expect integer or index type"); + return llvm::APInt::getSignedMinValue(ty.getIntOrFloatBitWidth()); + } + if constexpr (std::is_same_v<R, llvm::APFloat>) { + auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty); + assert(floatTy && "expect float type"); + return llvm::APFloat::getSmallest(floatTy.getFloatSemantics(), + /*negative=*/true); + } + } else if (op == mlir::acc::ReductionOperator::AccIand) { + if constexpr (std::is_same_v<R, llvm::APInt>) { + assert(ty.isIntOrIndex() && "expect integer type"); + unsigned bits = ty.getIntOrFloatBitWidth(); + return llvm::APInt::getAllOnes(bits); + } + } else { + assert(op != mlir::acc::ReductionOperator::AccNone); + // +, ior, ieor init value -> 0 + // * init value -> 1 + int64_t value = (op == mlir::acc::ReductionOperator::AccMul) ? 
1 : 0; + if constexpr (std::is_same_v<R, llvm::APInt>) { + assert(ty.isIntOrIndex() && "expect integer or index type"); + return llvm::APInt(ty.getIntOrFloatBitWidth(), value, true); + } + + if constexpr (std::is_same_v<R, llvm::APFloat>) { + assert(mlir::isa<mlir::FloatType>(ty) && "expect float type"); + auto floatTy = mlir::dyn_cast<mlir::FloatType>(ty); + return llvm::APFloat(floatTy.getFloatSemantics(), value); + } + + if constexpr (std::is_same_v<R, int64_t>) + return value; + } + llvm_unreachable("OpenACC reduction unsupported type"); +} + +/// Return a constant with the initial value for the reduction operator and +/// type combination. +static mlir::Value getReductionInitValue(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Type varType, + mlir::acc::ReductionOperator op) { + mlir::Type ty = fir::getFortranElementType(varType); + if (op == mlir::acc::ReductionOperator::AccLand || + op == mlir::acc::ReductionOperator::AccLor || + op == mlir::acc::ReductionOperator::AccEqv || + op == mlir::acc::ReductionOperator::AccNeqv) { + assert(mlir::isa<fir::LogicalType>(ty) && "expect fir.logical type"); + bool value = true; // .true. for .and. and .eqv. + if (op == mlir::acc::ReductionOperator::AccLor || + op == mlir::acc::ReductionOperator::AccNeqv) + value = false; // .false. for .or. and .neqv. 
+ return builder.createBool(loc, value); + } + if (ty.isIntOrIndex()) + return mlir::arith::ConstantOp::create( + builder, loc, ty, + builder.getIntegerAttr(ty, getReductionInitValue<llvm::APInt>(op, ty))); + if (op == mlir::acc::ReductionOperator::AccMin || + op == mlir::acc::ReductionOperator::AccMax) { + if (mlir::isa<mlir::ComplexType>(ty)) + llvm::report_fatal_error( + "min/max reduction not supported for complex type"); + if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty)) + return mlir::arith::ConstantOp::create( + builder, loc, ty, + builder.getFloatAttr(ty, + getReductionInitValue<llvm::APFloat>(op, ty))); + } else if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty)) { + return mlir::arith::ConstantOp::create( + builder, loc, ty, + builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty))); + } else if (auto cmplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) { + mlir::Type floatTy = cmplxTy.getElementType(); + mlir::Value realInit = builder.createRealConstant( + loc, floatTy, getReductionInitValue<int64_t>(op, cmplxTy)); + mlir::Value imagInit = builder.createRealConstant(loc, floatTy, 0.0); + return fir::factory::Complex{builder, loc}.createComplex(cmplxTy, realInit, + imagInit); + } + llvm::report_fatal_error("Unsupported OpenACC reduction type"); +} + +static llvm::SmallVector<mlir::Value> +getRecipeBounds(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::ValueRange dataBoundOps, + mlir::ValueRange blockBoundArgs) { + if (dataBoundOps.empty()) + return {}; + mlir::Type idxTy = builder.getIndexType(); + mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); + llvm::SmallVector<mlir::Value> bounds; + if (!blockBoundArgs.empty()) { + for (unsigned i = 0; i + 2 < blockBoundArgs.size(); i += 3) { + bounds.push_back(blockBoundArgs[i]); + bounds.push_back(blockBoundArgs[i + 1]); + // acc data bound strides is the inner size in bytes or elements, but + // sections are always 1-based, so there is no need to try 
to compute + // that back from the acc bounds. + bounds.push_back(one); + } + return bounds; + } + for (auto bound : dataBoundOps) { + auto dataBound = llvm::dyn_cast_if_present<mlir::acc::DataBoundsOp>( + bound.getDefiningOp()); + assert(dataBound && "expect acc bounds to be produced by DataBoundsOp"); + assert( + dataBound.getLowerbound() && dataBound.getUpperbound() && + "expect acc bounds for Fortran to always have lower and upper bounds"); + std::optional<std::int64_t> lb = + fir::getIntIfConstant(dataBound.getLowerbound()); + std::optional<std::int64_t> ub = + fir::getIntIfConstant(dataBound.getUpperbound()); + assert(lb.has_value() && ub.has_value() && + "must get constant bounds when there are no bound block arguments"); + bounds.push_back(builder.createIntegerConstant(loc, idxTy, *lb)); + bounds.push_back(builder.createIntegerConstant(loc, idxTy, *ub)); + bounds.push_back(one); + } + return bounds; +} + +static void addRecipeBoundsArgs(llvm::SmallVector<mlir::Value> &bounds, + bool allConstantBound, + llvm::SmallVector<mlir::Type> &argsTy, + llvm::SmallVector<mlir::Location> &argsLoc) { + if (!allConstantBound) { + for (mlir::Value bound : llvm::reverse(bounds)) { + auto dataBound = + mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); + argsTy.push_back(dataBound.getLowerbound().getType()); + argsLoc.push_back(dataBound.getLowerbound().getLoc()); + argsTy.push_back(dataBound.getUpperbound().getType()); + argsLoc.push_back(dataBound.getUpperbound().getLoc()); + argsTy.push_back(dataBound.getStartIdx().getType()); + argsLoc.push_back(dataBound.getStartIdx().getLoc()); + } + } +} + +using MappableValue = mlir::TypedValue<mlir::acc::MappableType>; + +// Generate the combiner or copy region block and block arguments and return the +// source and destination entities. 
+static std::pair<MappableValue, MappableValue>
+genRecipeCombinerOrCopyRegion(fir::FirOpBuilder &builder, mlir::Location loc,
+                              mlir::Type ty, mlir::Region &region,
+                              llvm::SmallVector<mlir::Value> &bounds,
+                              bool allConstantBound) {
+  llvm::SmallVector<mlir::Type> argsTy{ty, ty};
+  llvm::SmallVector<mlir::Location> argsLoc{loc, loc};
+  addRecipeBoundsArgs(bounds, allConstantBound, argsTy, argsLoc);
+  mlir::Block *block =
+      builder.createBlock(&region, region.end(), argsTy, argsLoc);
+  builder.setInsertionPointToEnd(&region.back());
+  auto firstArg = mlir::cast<MappableValue>(block->getArgument(0));
+  auto secondArg = mlir::cast<MappableValue>(block->getArgument(1));
+  return {firstArg, secondArg};
+}
+
+template <typename RecipeOp>
+static RecipeOp genRecipeOp(
+    fir::FirOpBuilder &builder, mlir::ModuleOp mod, llvm::StringRef recipeName,
+    mlir::Location loc, mlir::Type ty,
+    llvm::SmallVector<mlir::Value> &dataOperationBounds, bool allConstantBound,
+    mlir::acc::ReductionOperator op = mlir::acc::ReductionOperator::AccNone) {
+  mlir::OpBuilder modBuilder(mod.getBodyRegion());
+  RecipeOp recipe;
+  if constexpr (std::is_same_v<RecipeOp, mlir::acc::ReductionRecipeOp>) {
+    recipe = mlir::acc::ReductionRecipeOp::create(modBuilder, loc, recipeName,
+                                                  ty, op);
+  } else {
+    recipe = RecipeOp::create(modBuilder, loc, recipeName, ty);
+  }
+
+  assert(hlfir::isFortranVariableType(ty) && "expect Fortran variable type");
+
+  llvm::SmallVector<mlir::Type> argsTy{ty};
+  llvm::SmallVector<mlir::Location> argsLoc{loc};
+  if (!dataOperationBounds.empty())
+    addRecipeBoundsArgs(dataOperationBounds, allConstantBound, argsTy, argsLoc);
+
+  auto initBlock = builder.createBlock(
+      &recipe.getInitRegion(), recipe.getInitRegion().end(), argsTy, argsLoc);
+  builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
+  mlir::Value initValue;
+  if constexpr (std::is_same_v<RecipeOp, mlir::acc::ReductionRecipeOp>) {
+    assert(op != mlir::acc::ReductionOperator::AccNone);
+    initValue =
getReductionInitValue(builder, loc, ty, op); + } + + // Since we reuse the same recipe for all variables of the same type - we + // cannot use the actual variable name. Thus use a temporary name. + llvm::StringRef initName; + if constexpr (std::is_same_v<RecipeOp, mlir::acc::ReductionRecipeOp>) + initName = accReductionInitName; + else + initName = accPrivateInitName; + + auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(ty); + assert(mappableTy && + "Expected that all variable types are considered mappable"); + bool needsDestroy = false; + llvm::SmallVector<mlir::Value> initBounds = + getRecipeBounds(builder, loc, dataOperationBounds, + initBlock->getArguments().drop_front(1)); + mlir::Value retVal = mappableTy.generatePrivateInit( + builder, loc, mlir::cast<MappableValue>(initBlock->getArgument(0)), + initName, initBounds, initValue, needsDestroy); + mlir::acc::YieldOp::create(builder, loc, retVal); + // Create destroy region and generate destruction if requested. + if (needsDestroy) { + llvm::SmallVector<mlir::Type> destroyArgsTy; + llvm::SmallVector<mlir::Location> destroyArgsLoc; + // original and privatized/reduction value + destroyArgsTy.push_back(ty); + destroyArgsTy.push_back(ty); + destroyArgsLoc.push_back(loc); + destroyArgsLoc.push_back(loc); + // Append bounds arguments (if any) in the same order as init region + if (argsTy.size() > 1) { + destroyArgsTy.append(argsTy.begin() + 1, argsTy.end()); + destroyArgsLoc.insert(destroyArgsLoc.end(), argsTy.size() - 1, loc); + } + + mlir::Block *destroyBlock = builder.createBlock( + &recipe.getDestroyRegion(), recipe.getDestroyRegion().end(), + destroyArgsTy, destroyArgsLoc); + builder.setInsertionPointToEnd(destroyBlock); + + llvm::SmallVector<mlir::Value> destroyBounds = + getRecipeBounds(builder, loc, dataOperationBounds, + destroyBlock->getArguments().drop_front(2)); + [[maybe_unused]] bool success = mappableTy.generatePrivateDestroy( + builder, loc, destroyBlock->getArgument(1), destroyBounds); + 
assert(success && "failed to generate destroy region"); + mlir::acc::TerminatorOp::create(builder, loc); + } + return recipe; +} + +mlir::SymbolRefAttr +fir::acc::createOrGetPrivateRecipe(mlir::OpBuilder &mlirBuilder, + mlir::Location loc, mlir::Type ty, + llvm::SmallVector<mlir::Value> &bounds) { + mlir::ModuleOp mod = + mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(mlirBuilder, mod); + std::string recipeName = ::getRecipeName( + mlir::acc::RecipeKind::private_recipe, ty, builder.getKindMap(), bounds); + if (auto recipe = mod.lookupSymbol<mlir::acc::PrivateRecipeOp>(recipeName)) + return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName()); + + mlir::OpBuilder::InsertionGuard guard(builder); + bool allConstantBound = fir::acc::areAllBoundsConstant(bounds); + auto recipe = genRecipeOp<mlir::acc::PrivateRecipeOp>( + builder, mod, recipeName, loc, ty, bounds, allConstantBound); + return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName()); +} + +mlir::SymbolRefAttr fir::acc::createOrGetFirstprivateRecipe( + mlir::OpBuilder &mlirBuilder, mlir::Location loc, mlir::Type ty, + llvm::SmallVector<mlir::Value> &dataBoundOps) { + mlir::ModuleOp mod = + mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(mlirBuilder, mod); + std::string recipeName = + ::getRecipeName(mlir::acc::RecipeKind::firstprivate_recipe, ty, + builder.getKindMap(), dataBoundOps); + if (auto recipe = + mod.lookupSymbol<mlir::acc::FirstprivateRecipeOp>(recipeName)) + return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName()); + + mlir::OpBuilder::InsertionGuard guard(builder); + bool allConstantBound = fir::acc::areAllBoundsConstant(dataBoundOps); + auto recipe = genRecipeOp<mlir::acc::FirstprivateRecipeOp>( + builder, mod, recipeName, loc, ty, dataBoundOps, allConstantBound); + auto [source, destination] = genRecipeCombinerOrCopyRegion( + builder, loc, ty, 
recipe.getCopyRegion(), dataBoundOps, allConstantBound); + llvm::SmallVector<mlir::Value> copyBounds = + getRecipeBounds(builder, loc, dataBoundOps, + recipe.getCopyRegion().getArguments().drop_front(2)); + + auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(ty); + assert(mappableTy && + "Expected that all variable types are considered mappable"); + [[maybe_unused]] bool success = + mappableTy.generateCopy(builder, loc, source, destination, copyBounds); + assert(success && "failed to generate copy"); + mlir::acc::TerminatorOp::create(builder, loc); + return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName()); +} + +mlir::SymbolRefAttr fir::acc::createOrGetReductionRecipe( + mlir::OpBuilder &mlirBuilder, mlir::Location loc, mlir::Type ty, + mlir::acc::ReductionOperator op, + llvm::SmallVector<mlir::Value> &dataBoundOps, + mlir::Attribute fastMathAttr) { + mlir::ModuleOp mod = + mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(mlirBuilder, mod); + std::string recipeName = + ::getRecipeName(mlir::acc::RecipeKind::reduction_recipe, ty, + builder.getKindMap(), dataBoundOps, op); + if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName)) + return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName()); + + mlir::OpBuilder::InsertionGuard guard(builder); + bool allConstantBound = fir::acc::areAllBoundsConstant(dataBoundOps); + auto recipe = genRecipeOp<mlir::acc::ReductionRecipeOp>( + builder, mod, recipeName, loc, ty, dataBoundOps, allConstantBound, op); + + auto [dest, source] = genRecipeCombinerOrCopyRegion( + builder, loc, ty, recipe.getCombinerRegion(), dataBoundOps, + allConstantBound); + llvm::SmallVector<mlir::Value> combinerBounds = + getRecipeBounds(builder, loc, dataBoundOps, + recipe.getCombinerRegion().getArguments().drop_front(2)); + + auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(ty); + assert(mappableTy && + "Expected that all variable 
types are considered mappable"); + [[maybe_unused]] bool success = mappableTy.generateCombiner( + builder, loc, dest, source, combinerBounds, op, fastMathAttr); + assert(success && "failed to generate combiner"); + mlir::acc::YieldOp::create(builder, loc, dest); + return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName()); +} + +mlir::Value fir::acc::getOriginalDef(mlir::Value value, bool stripDeclare) { + mlir::Value currentValue = value; + + while (currentValue) { + auto *definingOp = currentValue.getDefiningOp(); + if (!definingOp) + break; + + if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(definingOp)) { + currentValue = convertOp.getValue(); + continue; + } + + if (auto viewLike = mlir::dyn_cast<mlir::ViewLikeOpInterface>(definingOp)) { + currentValue = viewLike.getViewSource(); + continue; + } + + if (stripDeclare) { + if (auto declareOp = mlir::dyn_cast<hlfir::DeclareOp>(definingOp)) { + currentValue = declareOp.getMemref(); + continue; + } + + if (auto declareOp = mlir::dyn_cast<fir::DeclareOp>(definingOp)) { + currentValue = declareOp.getMemref(); + continue; + } + } + break; + } + + return currentValue; +} diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp index 717bf34..c0be247 100644 --- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp @@ -11,8 +11,15 @@ //===----------------------------------------------------------------------===// #include "flang/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.h" + +#include "flang/Optimizer/Dialect/CUF/CUFDialect.h" +#include "flang/Optimizer/Dialect/CUF/CUFOps.h" #include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/HLFIR/HLFIRDialect.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include 
"flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h" #include "flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h" namespace fir::acc { @@ -37,7 +44,62 @@ void registerOpenACCExtensions(mlir::DialectRegistry ®istry) { fir::LLVMPointerType::attachInterface< OpenACCPointerLikeModel<fir::LLVMPointerType>>(*ctx); + + fir::LogicalType::attachInterface<OpenACCReducibleLogicalModel>(*ctx); + + fir::ArrayCoorOp::attachInterface< + PartialEntityAccessModel<fir::ArrayCoorOp>>(*ctx); + fir::CoordinateOp::attachInterface< + PartialEntityAccessModel<fir::CoordinateOp>>(*ctx); + fir::DeclareOp::attachInterface<PartialEntityAccessModel<fir::DeclareOp>>( + *ctx); + + fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx); + fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx); + + fir::AllocaOp::attachInterface<IndirectGlobalAccessModel<fir::AllocaOp>>( + *ctx); + fir::EmboxOp::attachInterface<IndirectGlobalAccessModel<fir::EmboxOp>>( + *ctx); + fir::ReboxOp::attachInterface<IndirectGlobalAccessModel<fir::ReboxOp>>( + *ctx); + fir::TypeDescOp::attachInterface< + IndirectGlobalAccessModel<fir::TypeDescOp>>(*ctx); + + // Attach OutlineRematerializationOpInterface to FIR operations that + // produce synthetic types (shapes, field indices) which cannot be passed + // as arguments to outlined regions and must be rematerialized inside. 
+ fir::ShapeOp::attachInterface<OutlineRematerializationModel<fir::ShapeOp>>( + *ctx); + fir::ShapeShiftOp::attachInterface< + OutlineRematerializationModel<fir::ShapeShiftOp>>(*ctx); + fir::ShiftOp::attachInterface<OutlineRematerializationModel<fir::ShiftOp>>( + *ctx); + fir::FieldIndexOp::attachInterface< + OutlineRematerializationModel<fir::FieldIndexOp>>(*ctx); + }); + + // Register HLFIR operation interfaces + registry.addExtension( + +[](mlir::MLIRContext *ctx, hlfir::hlfirDialect *dialect) { + hlfir::DesignateOp::attachInterface< + PartialEntityAccessModel<hlfir::DesignateOp>>(*ctx); + hlfir::DeclareOp::attachInterface< + PartialEntityAccessModel<hlfir::DeclareOp>>(*ctx); + }); + + // Register CUF operation interfaces + registry.addExtension(+[](mlir::MLIRContext *ctx, cuf::CUFDialect *dialect) { + cuf::KernelOp::attachInterface<OffloadRegionModel<cuf::KernelOp>>(*ctx); }); + + // Attach FIR dialect interfaces to OpenACC operations. + registry.addExtension(+[](mlir::MLIRContext *ctx, + mlir::acc::OpenACCDialect *dialect) { + mlir::acc::LoopOp::attachInterface<OperationMoveModel<mlir::acc::LoopOp>>( + *ctx); + }); + registerAttrsExtensions(registry); } diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCInitializeFIRAnalyses.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCInitializeFIRAnalyses.cpp new file mode 100644 index 0000000..679b29b --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCInitializeFIRAnalyses.cpp @@ -0,0 +1,56 @@ +//===- ACCInitializeFIRAnalyses.cpp - Initialize FIR analyses ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass initializes analyses that can be reused by subsequent OpenACC +// passes in the pipeline. 
+// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Analysis/AliasAnalysis.h" +#include "flang/Optimizer/OpenACC/Analysis/FIROpenACCSupportAnalysis.h" +#include "flang/Optimizer/OpenACC/Passes.h" +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" + +namespace fir { +namespace acc { +#define GEN_PASS_DEF_ACCINITIALIZEFIRANALYSES +#include "flang/Optimizer/OpenACC/Passes.h.inc" +} // namespace acc +} // namespace fir + +#define DEBUG_TYPE "acc-initialize-fir-analyses" + +namespace { + +/// This pass initializes analyses for reuse by subsequent OpenACC passes in the +/// pipeline. It creates and caches analyses like OpenACCSupport so they can be +/// retrieved by later passes using getAnalysis() or getCachedAnalysis(). +class ACCInitializeFIRAnalysesPass + : public fir::acc::impl::ACCInitializeFIRAnalysesBase< + ACCInitializeFIRAnalysesPass> { +public: + void runOnOperation() override { + // Initialize OpenACCSupport with FIR-specific implementation. + auto &openACCSupport = getAnalysis<mlir::acc::OpenACCSupport>(); + openACCSupport.setImplementation(fir::acc::FIROpenACCSupportAnalysis()); + + // Initialize AliasAnalysis with FIR-specific implementation. 
+ auto &aliasAnalysis = getAnalysis<mlir::AliasAnalysis>(); + aliasAnalysis.addAnalysisImplementation(fir::AliasAnalysis()); + + // Mark all analyses as preserved since this pass only initializes them + markAllAnalysesPreserved(); + } +}; + +} // namespace + +std::unique_ptr<mlir::Pass> fir::acc::createACCInitializeFIRAnalysesPass() { + return std::make_unique<ACCInitializeFIRAnalysesPass>(); +} diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCOptimizeFirstprivateMap.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCOptimizeFirstprivateMap.cpp new file mode 100644 index 0000000..ec40e12 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCOptimizeFirstprivateMap.cpp @@ -0,0 +1,193 @@ +//===- ACCOptimizeFirstprivateMap.cpp -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass optimizes firstprivate mapping operations (acc.firstprivate_map). +// The optimization hoists loads from the firstprivate variable to before the +// compute region, effectively converting the firstprivate copy to a +// pass-by-value pattern. This eliminates the need for runtime copying into +// global memory. +// +// Example transformation: +// +// Before: +// %decl = fir.declare %alloca : !fir.ref<i32> +// %fp = acc.firstprivate_map varPtr(%decl) -> !fir.ref<i32> +// acc.parallel { +// %val = fir.load %fp : !fir.ref<i32> // load inside region +// ... +// } +// +// After: +// %decl = fir.declare %alloca : !fir.ref<i32> +// %val = fir.load %decl : !fir.ref<i32> // load hoisted before region +// acc.parallel { +// ... 
// uses %val directly +// } +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Dialect/FortranVariableInterface.h" +#include "flang/Optimizer/OpenACC/Passes.h" +#include "flang/Optimizer/OpenACC/Support/FIROpenACCUtils.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "llvm/ADT/SmallVector.h" + +namespace fir::acc { +#define GEN_PASS_DEF_ACCOPTIMIZEFIRSTPRIVATEMAP +#include "flang/Optimizer/OpenACC/Passes.h.inc" +} // namespace fir::acc + +using namespace mlir; + +namespace { + +/// Returns the enclosing offload region interface, or nullptr if not inside +/// one. +static acc::OffloadRegionOpInterface getEnclosingOffloadRegion(Operation *op) { + return op->getParentOfType<acc::OffloadRegionOpInterface>(); +} + +/// Returns true if the value is defined by an OpenACC data clause operation. +static bool isDefinedByDataClause(Value value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) + return false; + return acc::getDataClause(defOp).has_value(); +} + +/// Returns true if the value is defined inside the given offload region. +/// This handles both operation results and block arguments. +static bool isDefinedInsideRegion(Value value, + acc::OffloadRegionOpInterface offloadOp) { + Region *valueRegion = value.getParentRegion(); + if (!valueRegion) + return false; + return offloadOp.getOffloadRegion().isAncestor(valueRegion); +} + +/// Returns true if the variable may be optional. +static bool mayBeOptionalVariable(Value var) { + // Don't strip declare ops - we need to check the optional attribute on them. + Value originalDef = fir::acc::getOriginalDef(var, /*stripDeclare=*/false); + if (auto varIface = dyn_cast_or_null<fir::FortranVariableOpInterface>( + originalDef.getDefiningOp())) + return varIface.isOptional(); + // If the defining op is an alloca, it's a local variable and not optional. 
+ if (isa_and_nonnull<fir::AllocaOp, fir::AllocMemOp>( + originalDef.getDefiningOp())) + return false; + // Conservative: if we can't determine, assume it may be optional. + return true; +} + +/// Returns true if the type is a reference to a trivial type. +/// Note that this does not allow fir.heap, fir.ptr, or fir.llvm_ptr +/// types - since we would need to check if the load is valid via +/// a null-check to enable the optimization. +static bool isRefToTrivialType(Type type) { + if (!mlir::isa<fir::ReferenceType>(type)) + return false; + return fir::isa_trivial(fir::unwrapRefType(type)); +} + +/// Attempts to hoist loads from accVar to before firstprivateInitOp. +/// Returns true if all uses of accVar are loads and they were hoisted. +static bool hoistLoads(acc::FirstprivateMapInitialOp firstprivateInitOp, + Value var, Value accVar) { + // Check if all uses are loads - only hoist if we can optimize all uses. + bool allLoads = llvm::all_of(accVar.getUsers(), [](Operation *user) { + return isa<fir::LoadOp>(user); + }); + if (!allLoads) + return false; + + // Hoist all loads before the firstprivate_map operation. + for (Operation *user : llvm::make_early_inc_range(accVar.getUsers())) { + auto loadOp = cast<fir::LoadOp>(user); + loadOp.getMemrefMutable().assign(var); + loadOp->moveBefore(firstprivateInitOp); + } + return true; +} + +class ACCOptimizeFirstprivateMap + : public fir::acc::impl::ACCOptimizeFirstprivateMapBase< + ACCOptimizeFirstprivateMap> { +public: + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + + // Collect all firstprivate_map ops first to avoid modifying IR during walk. 
+ llvm::SmallVector<acc::FirstprivateMapInitialOp> firstprivateOps; + funcOp.walk([&](acc::FirstprivateMapInitialOp op) { + firstprivateOps.push_back(op); + }); + + llvm::SmallVector<acc::FirstprivateMapInitialOp> opsToErase; + + for (acc::FirstprivateMapInitialOp firstprivateInitOp : firstprivateOps) { + Value var = firstprivateInitOp.getVar(); + + if (auto offloadOp = getEnclosingOffloadRegion(firstprivateInitOp)) { + // Inside an offload region. + if (isDefinedByDataClause(var) || + isDefinedInsideRegion(var, offloadOp)) { + // The variable is already mapped or defined locally - just replace + // uses and erase. + firstprivateInitOp.getAccVar().replaceAllUsesWith(var); + opsToErase.push_back(firstprivateInitOp); + } else { + // Variable is defined outside - hoist the op out of the region, + // then apply optimization. + firstprivateInitOp->moveBefore(offloadOp); + if (optimizeFirstprivateMapping(firstprivateInitOp)) + opsToErase.push_back(firstprivateInitOp); + } + } else { + // Outside offload region, apply type-restricted optimization + // to pre-load before the compute region. + if (optimizeFirstprivateMapping(firstprivateInitOp)) + opsToErase.push_back(firstprivateInitOp); + } + } + + for (auto op : opsToErase) + op.erase(); + } + +private: + /// Returns true if the operation was optimized and can be erased. + static bool optimizeFirstprivateMapping( + acc::FirstprivateMapInitialOp firstprivateInitOp) { + Value var = firstprivateInitOp.getVar(); + Value accVar = firstprivateInitOp.getAccVar(); + + // If there are no uses, we can erase the operation. + if (accVar.use_empty()) + return true; + + // Only optimize references to trivial types. + if (!isRefToTrivialType(var.getType())) + return false; + + // Avoid hoisting optional variables as they may be + // null and thus not safe to access. 
+    if (mayBeOptionalVariable(var))
+      return false;
+
+    return hoistLoads(firstprivateInitOp, var, accVar);
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> fir::acc::createACCOptimizeFirstprivateMapPass() {
+  return std::make_unique<ACCOptimizeFirstprivateMap>();
+}
diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp
index 4840a99..ad0cfa3 100644
--- a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp
+++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp
@@ -39,13 +39,13 @@ public:
 
   static mlir::Operation *load(mlir::OpBuilder &builder, mlir::Location loc,
                                mlir::Value value) {
-    return builder.create<fir::LoadOp>(loc, value);
+    return fir::LoadOp::create(builder, loc, value);
   }
 
   static mlir::Value placeInMemory(mlir::OpBuilder &builder, mlir::Location loc,
                                    mlir::Value value) {
-    auto alloca = builder.create<fir::AllocaOp>(loc, value.getType());
-    builder.create<fir::StoreOp>(loc, value, alloca);
+    auto alloca = fir::AllocaOp::create(builder, loc, value.getType());
+    fir::StoreOp::create(builder, loc, value, alloca);
     return alloca;
   }
 };
@@ -87,30 +87,26 @@ static void bufferizeRegionArgsAndYields(mlir::Region &region,
   }
 }
 
-static void updateRecipeUse(mlir::ArrayAttr recipes, mlir::ValueRange operands,
+template <typename OpTy>
+static void updateRecipeUse(mlir::ValueRange operands,
                             llvm::StringRef recipeSymName,
                             mlir::Operation *computeOp) {
-  if (!recipes)
-    return;
-  for (auto [recipeSym, oldRes] : llvm::zip(recipes, operands)) {
-    if (llvm::cast<mlir::SymbolRefAttr>(recipeSym).getLeafReference() !=
-        recipeSymName)
+  for (auto operand : operands) {
+    auto op = operand.getDefiningOp<OpTy>();
+    if (!op || !op.getRecipe().has_value() ||
+        op.getRecipeAttr().getLeafReference() != recipeSymName)
       continue;
-    mlir::Operation *dataOp = oldRes.getDefiningOp();
-    assert(dataOp && "dataOp must be paired with computeOp");
-    mlir::Location loc =
dataOp->getLoc(); - mlir::OpBuilder builder(dataOp); - llvm::TypeSwitch<mlir::Operation *, void>(dataOp) - .Case<mlir::acc::PrivateOp, mlir::acc::FirstprivateOp, - mlir::acc::ReductionOp>([&](auto privateOp) { - builder.setInsertionPointAfterValue(privateOp.getVar()); - mlir::Value alloca = BufferizeInterface::placeInMemory( - builder, loc, privateOp.getVar()); - privateOp.getVarMutable().assign(alloca); - privateOp.getAccVar().setType(alloca.getType()); - }); + mlir::Location loc = op->getLoc(); + + mlir::OpBuilder builder(op); + builder.setInsertionPointAfterValue(op.getVar()); + mlir::Value alloca = + BufferizeInterface::placeInMemory(builder, loc, op.getVar()); + op.getVarMutable().assign(alloca); + op.getAccVar().setType(alloca.getType()); + mlir::Value oldRes = op.getAccVar(); llvm::SmallVector<mlir::Operation *> users(oldRes.getUsers().begin(), oldRes.getUsers().end()); for (mlir::Operation *useOp : users) { @@ -166,18 +162,15 @@ public: .Case<mlir::acc::LoopOp, mlir::acc::ParallelOp, mlir::acc::SerialOp>( [&](auto computeOp) { for (llvm::StringRef recipeName : recipeNames) { - if (computeOp.getPrivatizationRecipes()) - updateRecipeUse(computeOp.getPrivatizationRecipesAttr(), - computeOp.getPrivateOperands(), recipeName, - op); - if (computeOp.getFirstprivatizationRecipes()) - updateRecipeUse( - computeOp.getFirstprivatizationRecipesAttr(), + if (!computeOp.getPrivateOperands().empty()) + updateRecipeUse<mlir::acc::PrivateOp>( + computeOp.getPrivateOperands(), recipeName, op); + if (!computeOp.getFirstprivateOperands().empty()) + updateRecipeUse<mlir::acc::FirstprivateOp>( computeOp.getFirstprivateOperands(), recipeName, op); - if (computeOp.getReductionRecipes()) - updateRecipeUse(computeOp.getReductionRecipesAttr(), - computeOp.getReductionOperands(), - recipeName, op); + if (!computeOp.getReductionOperands().empty()) + updateRecipeUse<mlir::acc::ReductionOp>( + computeOp.getReductionOperands(), recipeName, op); } }); }); diff --git 
a/flang/lib/Optimizer/OpenACC/Transforms/ACCUseDeviceCanonicalizer.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCUseDeviceCanonicalizer.cpp new file mode 100644 index 0000000..51ab7960 --- /dev/null +++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCUseDeviceCanonicalizer.cpp @@ -0,0 +1,400 @@ +//===- ACCUseDeviceCanonicalizer.cpp --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass canonicalizes the use_device clause on a host_data construct such +// that use_device(x) can be lowered to a simple runtime call that takes the +// actual host pointer as argument. +// +// For a use_device operand that is a box type or a reference to a box, the +// pass: +// 1. Extracts the host base address for mapping to a device address using +// acc.use_device. +// 2. Creates a new boxed descriptor with the device address as the base +// address for use inside the host_data region. +// +// The pass also removes unused use_device clauses, reducing the number of +// runtime calls. +// +// Supported use_device operand types: +// +// Scalars: +// - !fir.ref<i32>, !fir.ref<f64>, etc. 
+// +// Arrays: +// - Explicit shape (no descriptor): !fir.ref<!fir.array<100xi32>> +// - Adjustable size: !fir.ref<!fir.array<?xi32>> +// - Assumed shape (handled by hoistBox): !fir.box<!fir.array<?xi32>> +// - Assumed size: !fir.ref<!fir.array<?xi32>> +// - Deferred shape (handled by hoistRefToBox): +// - Allocatable: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> +// - Pointer: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>> +// - Subarray specification (handled by hoistBox): +// !fir.box<!fir.array<?xi32>> +// +// Not yet supported: +// - Assumed rank arrays +// - Composite variables: !fir.ref<!fir.type<...>> +// - Array elements (device pointer arithmetic in host_data region) +// - Composite variable members +// - Fortran common blocks: use_device(/cm_block/) +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/BoxValue.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/OpenACC/Passes.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" +#include <cassert> + +namespace fir::acc { +#define GEN_PASS_DEF_ACCUSEDEVICECANONICALIZER +#include "flang/Optimizer/OpenACC/Passes.h.inc" +} // namespace fir::acc + +#define DEBUG_TYPE "acc-use-device-canonicalizer" + +using namespace mlir; + +namespace { + +struct UseDeviceHostDataHoisting : public OpRewritePattern<acc::HostDataOp> { + using OpRewritePattern<acc::HostDataOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::HostDataOp op, + PatternRewriter &rewriter) const override { + SmallVector<Value> usedOperands; + SmallVector<Value> unusedUseDeviceOperands; + SmallVector<acc::UseDeviceOp> refToBoxUseDeviceOps; + SmallVector<acc::UseDeviceOp> 
boxUseDeviceOps; + + for (Value operand : op.getDataClauseOperands()) { + if (acc::UseDeviceOp useDeviceOp = + operand.getDefiningOp<acc::UseDeviceOp>()) { + if (fir::isBoxAddress(useDeviceOp.getVar().getType())) { + if (!llvm::hasSingleElement(useDeviceOp->getUsers())) + refToBoxUseDeviceOps.push_back(useDeviceOp); + } else if (isa<fir::BoxType>(useDeviceOp.getVar().getType())) { + if (!llvm::hasSingleElement(useDeviceOp->getUsers())) + boxUseDeviceOps.push_back(useDeviceOp); + } + + // host_data is the only user of this use_device operand - mark for + // removal + if (llvm::hasSingleElement(useDeviceOp->getUsers())) + unusedUseDeviceOperands.push_back(useDeviceOp.getResult()); + else + usedOperands.push_back(useDeviceOp.getResult()); + } else { + // Operand is not an `acc.use_device` result, keep it as is. + usedOperands.push_back(operand); + } + } + + assert(!usedOperands.empty() && "Host_data operation has no used operands"); + + if (!unusedUseDeviceOperands.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "ACCUseDeviceCanonicalizer: Removing " + << unusedUseDeviceOperands.size() + << " unused use_device operands from host_data operation\n"); + + // Update the host_data operation to have only used operands + rewriter.modifyOpInPlace(op, [&]() { + op.getDataClauseOperandsMutable().assign(usedOperands); + }); + + // Remove unused use_device operations + for (Value operand : unusedUseDeviceOperands) { + acc::UseDeviceOp useDeviceOp = + operand.getDefiningOp<acc::UseDeviceOp>(); + LLVM_DEBUG(llvm::dbgs() << "ACCUseDeviceCanonicalizer: Erasing: " + << *useDeviceOp << "\n"); + rewriter.eraseOp(useDeviceOp); + } + return success(); + } + + // Handle references to box types + bool modified = false; + for (acc::UseDeviceOp useDeviceOp : refToBoxUseDeviceOps) + modified |= + hoistRefToBox(rewriter, useDeviceOp.getResult(), useDeviceOp, op); + + // Handle box types + for (acc::UseDeviceOp useDeviceOp : boxUseDeviceOps) + modified |= hoistBox(rewriter, useDeviceOp.getResult(), 
useDeviceOp, op); + + return modified ? success() : failure(); + } + +private: + /// Collect users of `acc.use_device` operation inside the `acc.host_data` + /// region that need to be updated with the final replacement value. + void collectUseDeviceUsersToUpdate( + acc::UseDeviceOp useDeviceOp, acc::HostDataOp hostDataOp, + SmallVectorImpl<Operation *> &usersToUpdate) const { + for (mlir::Operation *user : useDeviceOp->getUsers()) + if (hostDataOp.getRegion().isAncestor(user->getParentRegion())) + usersToUpdate.push_back(user); + } + + /// Create new `acc.use_device` operation with the given box address as + /// operand. Updates the `acc.host_data` operation to use the new + /// `acc.use_device` result. + acc::UseDeviceOp createNewUseDeviceOp(PatternRewriter &rewriter, + acc::UseDeviceOp useDeviceOp, + acc::HostDataOp hostDataOp, + fir::BoxAddrOp boxAddr) const { + // Create use_device on the raw pointer + acc::UseDeviceOp newUseDeviceOp = acc::UseDeviceOp::create( + rewriter, useDeviceOp.getLoc(), boxAddr.getType(), boxAddr.getResult(), + useDeviceOp.getVarTypeAttr(), useDeviceOp.getVarPtrPtr(), + useDeviceOp.getBounds(), useDeviceOp.getAsyncOperands(), + useDeviceOp.getAsyncOperandsDeviceTypeAttr(), + useDeviceOp.getAsyncOnlyAttr(), useDeviceOp.getDataClauseAttr(), + useDeviceOp.getStructuredAttr(), useDeviceOp.getImplicitAttr(), + useDeviceOp.getModifiersAttr(), useDeviceOp.getNameAttr(), + useDeviceOp.getRecipeAttr()); + + LLVM_DEBUG(llvm::dbgs() << "Created new hoisted pattern for box access:\n" + << " box_addr: " << *boxAddr << "\n" + << " new use_device: " << *newUseDeviceOp << "\n"); + + // Replace the old `acc.use_device` operand in the `acc.host_data` operation + // with the new one + rewriter.modifyOpInPlace(hostDataOp, [&]() { + hostDataOp->replaceUsesOfWith(useDeviceOp.getResult(), + newUseDeviceOp.getResult()); + }); + + return newUseDeviceOp; + } + + /// Canonicalize use_device operand that is a reference to a box. 
  /// Transforms:
  ///   %3 = fir.address_of(@_QFEtgt) : !fir.ref<i32>
  ///   %5 = fir.embox %3 : (!fir.ref<i32>) -> !fir.box<!fir.ptr<i32>>
  ///   fir.store %5 to %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
  ///   %9 = acc.use_device varPtr(%0 : !fir.ref<!fir.box<!fir.ptr<i32>>>)
  ///       -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "ptr"}
  ///   acc.host_data dataOperands(%9 : !fir.ref<!fir.box<!fir.ptr<i32>>>) {
  ///     %loaded = fir.load %9 : !fir.ref<!fir.box<!fir.ptr<i32>>>
  ///     %addr = fir.box_addr %loaded : (!fir.box<!fir.ptr<i32>>) ->
  ///         !fir.ptr<i32>
  ///     %conv = fir.convert %addr : (!fir.ptr<i32>) -> i64
  ///     fir.call @foo(%conv) : (i64) -> ()
  ///     acc.terminator
  ///   }
  /// into:
  ///   %loaded = fir.load %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
  ///   %addr = fir.box_addr %loaded : (!fir.box<!fir.ptr<i32>>) ->
  ///       !fir.ptr<i32>
  ///   %dev_ptr = acc.use_device varPtr(%addr : !fir.ptr<i32>)
  ///       -> !fir.ptr<i32> {name = "ptr"}
  ///   acc.host_data dataOperands(%dev_ptr : !fir.ptr<i32>) {
  ///     %embox = fir.embox %dev_ptr : (!fir.ptr<i32>) ->
  ///         !fir.box<!fir.ptr<i32>>
  ///     %alloca = fir.alloca !fir.box<!fir.ptr<i32>>
  ///     fir.store %embox to %alloca : !fir.ref<!fir.box<!fir.ptr<i32>>>
  ///     %loaded2 = fir.load %alloca : !fir.ref<!fir.box<!fir.ptr<i32>>>
  ///     %addr2 = fir.box_addr %loaded2 : (!fir.box<!fir.ptr<i32>>) ->
  ///         !fir.ptr<i32>
  ///     %conv = fir.convert %addr2 : (!fir.ptr<i32>) -> i64
  ///     fir.call @foo(%conv) : (i64) -> ()
  ///     acc.terminator
  ///   }
  ///
  /// Returns true if the IR was modified.
  /// NOTE(review): `operand` is unused in this body; the visible caller passes
  /// useDeviceOp.getResult() for it — confirm before removing the parameter.
  bool hoistRefToBox(PatternRewriter &rewriter, Value operand,
                     acc::UseDeviceOp useDeviceOp,
                     acc::HostDataOp hostDataOp) const {

    // Safety check: if the use_device operation is already using a box_addr
    // result, it means it has already been processed, so skip to avoid an
    // infinite rewrite loop under the greedy driver.
    if (useDeviceOp.getVar().getDefiningOp<fir::BoxAddrOp>()) {
      LLVM_DEBUG(llvm::dbgs() << "ACCUseDeviceCanonicalizer: Skipping "
                                 "already processed use_device operation\n");
      return false;
    }
    // Get the ModuleOp before we erase useDeviceOp to avoid an invalid
    // reference later.
    ModuleOp mod = useDeviceOp->getParentOfType<ModuleOp>();

    // Collect users of the original `acc.use_device` operation (inside the
    // host_data region) that need to be updated after the rewrite.
    SmallVector<Operation *> usersToUpdate;
    collectUseDeviceUsersToUpdate(useDeviceOp, hostDataOp, usersToUpdate);

    rewriter.setInsertionPoint(useDeviceOp);
    // Create a load operation to get the box from the variable.
    fir::LoadOp box = fir::LoadOp::create(rewriter, useDeviceOp.getLoc(),
                                          useDeviceOp.getVar());
    // Create a box_addr operation to get the raw address out of the box.
    fir::BoxAddrOp boxAddr =
        fir::BoxAddrOp::create(rewriter, useDeviceOp.getLoc(), box);

    acc::UseDeviceOp newUseDeviceOp =
        createNewUseDeviceOp(rewriter, useDeviceOp, hostDataOp, boxAddr);

    LLVM_DEBUG(llvm::dbgs()
               << "Created new hoisted pattern for pointer access:\n"
               << " load box: " << *box << "\n"
               << " box_addr: " << *boxAddr << "\n"
               << " new use_device: " << *newUseDeviceOp << "\n");

    // Set insertion point to the first op inside the host_data region.
    rewriter.setInsertionPoint(&hostDataOp.getRegion().front().front());

    // Create a FirOpBuilder from the PatternRewriter using the module we got
    // earlier; required by the fir::factory helper below.
    fir::FirOpBuilder builder(rewriter, mod);
    // Rebuild a host descriptor whose base address is the device pointer.
    Value newBoxwithDevicePtr = fir::factory::getDescriptorWithNewBaseAddress(
        builder, useDeviceOp.getLoc(), box.getResult(),
        newUseDeviceOp.getResult());

    // Create a new memory location and store the newBoxwithDevicePtr into it,
    // so region users that expect a reference-to-box keep type-checking.
    fir::AllocaOp newMemLoc = fir::AllocaOp::create(
        rewriter, useDeviceOp.getLoc(), newBoxwithDevicePtr.getType());
    [[maybe_unused]] fir::StoreOp newStoreOp = fir::StoreOp::create(
        rewriter, useDeviceOp.getLoc(), newBoxwithDevicePtr, newMemLoc);

    LLVM_DEBUG(llvm::dbgs()
               << "host_data region updated with new host descriptor "
                  "containing device pointer:\n"
               << " box with device pointer: "
               << *newBoxwithDevicePtr.getDefiningOp() << "\n"
               << " mem loc: " << *newMemLoc << "\n"
               << " store op: " << *newStoreOp << "\n");

    // Replace all uses of the original `acc.use_device` operation inside the
    // `acc.host_data` region with the new memory location containing the box
    // with the device pointer.
    for (mlir::Operation *user : usersToUpdate)
      user->replaceUsesOfWith(useDeviceOp.getResult(), newMemLoc);

    assert(useDeviceOp.getResult().use_empty() &&
           "expected all uses of use_device to be replaced");
    rewriter.eraseOp(useDeviceOp);
    return true;
  }

  /// Canonicalize a use_device operand that is a box type.
  /// Transforms:
  ///   %box = ... : !fir.box<!fir.array<?xi32>>
  ///   %dev_box = acc.use_device varPtr(%box : !fir.box<!fir.array<?xi32>>)
  ///       -> !fir.box<!fir.array<?xi32>>
  ///   acc.host_data dataOperands(%dev_box : !fir.box<!fir.array<?xi32>>) {
  ///     %addr = fir.box_addr %dev_box : (!fir.box<!fir.array<?xi32>>) ->
  ///         !fir.heap<!fir.array<?xi32>>
  ///     // use %addr
  ///   }
  /// into:
  ///   %box = ... : !fir.box<!fir.array<?xi32>>
  ///   %addr = fir.box_addr %box : (!fir.box<!fir.array<?xi32>>) ->
  ///       !fir.heap<!fir.array<?xi32>>
  ///   %dev_ptr = acc.use_device varPtr(%addr : !fir.heap<!fir.array<?xi32>>)
  ///       -> !fir.heap<!fir.array<?xi32>>
  ///   acc.host_data dataOperands(%dev_ptr : !fir.heap<!fir.array<?xi32>>) {
  ///     %new_box = fir.embox %dev_ptr ...
: !fir.box<!fir.array<?xi32>> + /// %new_addr = fir.box_addr %new_box : (!fir.box<!fir.array<?xi32>>) -> + /// !fir.heap<!fir.array<?xi32>> + /// // use %new_addr instead of %addr + /// } + bool hoistBox(PatternRewriter &rewriter, Value operand, + acc::UseDeviceOp useDeviceOp, + acc::HostDataOp hostDataOp) const { + + // Safety check: if the use_device operation is already using a box_addr + // result, it means it has already been processed, so skip to avoid infinite + // loop + if (useDeviceOp.getVar().getDefiningOp<fir::BoxAddrOp>()) { + LLVM_DEBUG(llvm::dbgs() + << "ACCUseDeviceCanonicalizer: Skipping " + "already processed box use_device operation\n"); + return false; + } + + // Collect users of the original `acc.use_device` operation that need to be + // updated + SmallVector<Operation *> usersToUpdate; + collectUseDeviceUsersToUpdate(useDeviceOp, hostDataOp, usersToUpdate); + + // Get the ModuleOp before we erase useDeviceOp to avoid invalid reference + ModuleOp mod = useDeviceOp->getParentOfType<ModuleOp>(); + + rewriter.setInsertionPoint(useDeviceOp); + // Extract the raw pointer from the box descriptor + fir::BoxAddrOp boxAddr = fir::BoxAddrOp::create( + rewriter, useDeviceOp.getLoc(), useDeviceOp.getVar()); + + acc::UseDeviceOp newUseDeviceOp = + createNewUseDeviceOp(rewriter, useDeviceOp, hostDataOp, boxAddr); + + // Set insertion point to the first op inside the host_data region + rewriter.setInsertionPoint(&hostDataOp.getRegion().front().front()); + + // Create a FirOpBuilder from the PatternRewriter using the module we got + // earlier + fir::FirOpBuilder builder(rewriter, mod); + + // Create a new host descriptor at the start of the host_data region + // with the device pointer as the base address + Value newBoxWithDevicePtr = fir::factory::getDescriptorWithNewBaseAddress( + builder, useDeviceOp.getLoc(), useDeviceOp.getVar(), + newUseDeviceOp.getResult()); + + LLVM_DEBUG(llvm::dbgs() + << "host_data region updated with new host descriptor " + 
"containing device pointer:\n" + << " box with device pointer: " + << *newBoxWithDevicePtr.getDefiningOp() << "\n"); + + // Replace all uses of the original `acc.use_device` operation inside the + // `acc.host_data` region with the new box containing device pointer + for (mlir::Operation *user : usersToUpdate) + user->replaceUsesOfWith(useDeviceOp.getResult(), newBoxWithDevicePtr); + + assert(useDeviceOp.getResult().use_empty() && + "expected all uses of use_device to be replaced"); + rewriter.eraseOp(useDeviceOp); + return true; + } +}; + +class ACCUseDeviceCanonicalizer + : public fir::acc::impl::ACCUseDeviceCanonicalizerBase< + ACCUseDeviceCanonicalizer> { +public: + void runOnOperation() override { + MLIRContext *context = getOperation()->getContext(); + + RewritePatternSet patterns(context); + + // Add the custom use_device canonicalization patterns + patterns.insert<UseDeviceHostDataHoisting>(context); + + // Apply patterns greedily + GreedyRewriteConfig config; + // Prevent the pattern driver from merging blocks. 
+ config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); + config.setUseTopDownTraversal(true); + + (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); + } +}; + +} // namespace + +std::unique_ptr<mlir::Pass> fir::acc::createACCUseDeviceCanonicalizerPass() { + return std::make_unique<ACCUseDeviceCanonicalizer>(); +} diff --git a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt index ed177ba..27c5ee6 100644 --- a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt @@ -1,14 +1,25 @@ add_flang_library(FIROpenACCTransforms + ACCInitializeFIRAnalyses.cpp + ACCOptimizeFirstprivateMap.cpp ACCRecipeBufferization.cpp + ACCUseDeviceCanonicalizer.cpp DEPENDS FIROpenACCPassesIncGen LINK_LIBS + FIRAnalysis + FIRBuilder FIRDialect + FIRDialectSupport + FIROpenACCAnalysis + FIROpenACCSupport + HLFIRDialect MLIR_LIBS MLIRIR MLIRPass MLIROpenACCDialect + MLIROpenACCUtils + MLIRTransformUtils ) diff --git a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp index 8b99913..5793d46 100644 --- a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp +++ b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp @@ -20,8 +20,6 @@ #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" - namespace flangomp { #define GEN_PASS_DEF_AUTOMAPTOTARGETDATAPASS #include "flang/Optimizer/OpenMP/Passes.h.inc" @@ -120,12 +118,9 @@ class AutomapToTargetDataPass builder, memOp.getLoc(), memOp.getMemref().getType(), memOp.getMemref(), TypeAttr::get(fir::unwrapRefType(memOp.getMemref().getType())), - builder.getIntegerAttr( - builder.getIntegerType(64, false), - static_cast<unsigned>( - isa<fir::StoreOp>(memOp) - ? 
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO - : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)), + builder.getAttr<omp::ClauseMapFlagsAttr>( + isa<fir::StoreOp>(memOp) ? omp::ClauseMapFlags::to + : omp::ClauseMapFlags::del), builder.getAttr<omp::VariableCaptureKindAttr>( omp::VariableCaptureKind::ByCopy), /*var_ptr_ptr=*/mlir::Value{}, @@ -135,8 +130,8 @@ class AutomapToTargetDataPass builder.getBoolAttr(false)); clauses.mapVars.push_back(mapInfo); isa<fir::StoreOp>(memOp) - ? builder.create<omp::TargetEnterDataOp>(memOp.getLoc(), clauses) - : builder.create<omp::TargetExitDataOp>(memOp.getLoc(), clauses); + ? omp::TargetEnterDataOp::create(builder, memOp.getLoc(), clauses) + : omp::TargetExitDataOp::create(builder, memOp.getLoc(), clauses); }; for (fir::GlobalOp globalOp : automapGlobals) { diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index 03ff163..ff346e7 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -22,7 +22,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" namespace flangomp { #define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS @@ -484,6 +483,8 @@ private: } loopNestClauseOps.loopInclusive = rewriter.getUnitAttr(); + loopNestClauseOps.collapseNumLoops = + rewriter.getI64IntegerAttr(loopNestClauseOps.loopLowerBounds.size()); } std::pair<mlir::omp::LoopNestOp, mlir::omp::WsloopOp> @@ -568,16 +569,15 @@ private: if (auto refType = mlir::dyn_cast<fir::ReferenceType>(liveInType)) eleType = refType.getElementType(); - llvm::omp::OpenMPOffloadMappingFlags mapFlag = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::implicit; mlir::omp::VariableCaptureKind captureKind = 
mlir::omp::VariableCaptureKind::ByRef; if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) { captureKind = mlir::omp::VariableCaptureKind::ByCopy; } else if (!fir::isa_builtin_cptr_type(eleType)) { - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapFlag |= mlir::omp::ClauseMapFlags::to; + mapFlag |= mlir::omp::ClauseMapFlags::from; } llvm::SmallVector<mlir::Value> boundsOps; @@ -587,11 +587,8 @@ private: builder, liveIn.getLoc(), rawAddr, /*varPtrPtr=*/{}, name.str(), boundsOps, /*members=*/{}, - /*membersIndex=*/mlir::ArrayAttr{}, - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - mapFlag), - captureKind, rawAddr.getType()); + /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, + rawAddr.getType()); } mlir::omp::TargetOp @@ -600,7 +597,7 @@ private: mlir::omp::TargetOperands &clauseOps, mlir::omp::LoopNestOperands &loopNestClauseOps, const LiveInShapeInfoMap &liveInShapeInfoMap) const { - auto targetOp = rewriter.create<mlir::omp::TargetOp>(loc, clauseOps); + auto targetOp = mlir::omp::TargetOp::create(rewriter, loc, clauseOps); auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp); mlir::Region ®ion = targetOp.getRegion(); @@ -677,7 +674,7 @@ private: // temporary. 
Fortran::utils::openmp::cloneOrMapRegionOutsiders(builder, targetOp); rewriter.setInsertionPoint( - rewriter.create<mlir::omp::TerminatorOp>(targetOp.getLoc())); + mlir::omp::TerminatorOp::create(rewriter, targetOp.getLoc())); return targetOp; } @@ -697,9 +694,6 @@ private: if (!targetShapeCreationInfo.isShapedValue()) return {}; - llvm::SmallVector<mlir::Value> extentOperands; - llvm::SmallVector<mlir::Value> startIndexOperands; - if (targetShapeCreationInfo.isShapeShiftedValue()) { llvm::SmallVector<mlir::Value> shapeShiftOperands; @@ -720,8 +714,8 @@ private: auto shapeShiftType = fir::ShapeShiftType::get( builder.getContext(), shapeShiftOperands.size() / 2); - return builder.create<fir::ShapeShiftOp>( - liveInArg.getLoc(), shapeShiftType, shapeShiftOperands); + return fir::ShapeShiftOp::create(builder, liveInArg.getLoc(), + shapeShiftType, shapeShiftOperands); } llvm::SmallVector<mlir::Value> shapeOperands; @@ -733,11 +727,11 @@ private: ++shapeIdx; } - return builder.create<fir::ShapeOp>(liveInArg.getLoc(), shapeOperands); + return fir::ShapeOp::create(builder, liveInArg.getLoc(), shapeOperands); }(); - return builder.create<hlfir::DeclareOp>(liveInArg.getLoc(), liveInArg, - liveInName, shape); + return hlfir::DeclareOp::create(builder, liveInArg.getLoc(), liveInArg, + liveInName, shape); } mlir::omp::TeamsOp genTeamsOp(mlir::ConversionPatternRewriter &rewriter, @@ -747,13 +741,13 @@ private: genReductions(rewriter, mapper, loop, teamsOps); mlir::Location loc = loop.getLoc(); - auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(loc, teamsOps); + auto teamsOp = mlir::omp::TeamsOp::create(rewriter, loc, teamsOps); Fortran::common::openmp::EntryBlockArgs teamsArgs; teamsArgs.reduction.vars = teamsOps.reductionVars; Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs, teamsOp.getRegion()); - rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc)); + rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc)); for (auto 
[loopVar, teamsArg] : llvm::zip_equal( loop.getReduceVars(), teamsOp.getRegion().getArguments())) { @@ -766,8 +760,8 @@ private: mlir::omp::DistributeOp genDistributeOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter) const { - auto distOp = rewriter.create<mlir::omp::DistributeOp>( - loc, /*clauses=*/mlir::omp::DistributeOperands{}); + auto distOp = mlir::omp::DistributeOp::create( + rewriter, loc, /*clauses=*/mlir::omp::DistributeOperands{}); rewriter.createBlock(&distOp.getRegion()); return distOp; @@ -856,7 +850,8 @@ private: if (!ompReducer) { ompReducer = mlir::omp::DeclareReductionOp::create( rewriter, firReducer.getLoc(), ompReducerName, - firReducer.getTypeAttr().getValue()); + firReducer.getTypeAttr().getValue(), + firReducer.getByrefElementTypeAttr()); cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(), ompReducer.getAllocRegion()); diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp index 3031bb5..0acee89 100644 --- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp +++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp @@ -11,15 +11,19 @@ // //===----------------------------------------------------------------------===// +#include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/OpenMP/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" #include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" namespace flangomp { #define GEN_PASS_DEF_FUNCTIONFILTERINGPASS @@ -28,6 +32,77 @@ namespace flangomp { using namespace mlir; +/// This function triggers TODO errors and halts compilation if it detects +/// patterns representing unimplemented features. 
///
/// It exclusively checks situations that cannot be detected after all of the
/// MLIR pipeline has run (i.e. at the MLIR to LLVM IR translation stage, where
/// the preferred location for these types of checks is), and it only checks for
/// features that have not been implemented for target offload, but are
/// supported on host execution.
static void
checkDeviceImplementationStatus(omp::OffloadModuleInterface offloadModule) {
  // Only GPU (device) modules need these checks.
  if (!offloadModule.getIsGPU())
    return;

  offloadModule->walk<WalkOrder::PreOrder>([&](omp::DeclareReductionOp redOp) {
    // Skip reduction declarations that are never referenced in the module.
    if (redOp.symbolKnownUseEmpty(offloadModule))
      return WalkResult::advance();

    // Only by-ref reductions (those carrying a byref element type) can hit
    // the unsupported case below.
    if (!redOp.getByrefElementType())
      return WalkResult::advance();

    auto seqTy =
        mlir::dyn_cast<fir::SequenceType>(*redOp.getByrefElementType());

    // Supported unless the element is an array whose shape is not a
    // compile-time constant.
    bool isByRefReductionSupported =
        !seqTy || !fir::sequenceWithNonConstantShape(seqTy);

    if (!isByRefReductionSupported) {
      // TODO() emits a fatal "not yet implemented" diagnostic and halts
      // compilation.
      TODO(redOp.getLoc(),
           "Reduction of dynamically-shaped arrays are not supported yet "
           "on the GPU.");
    }

    return WalkResult::advance();
  });
}

/// Add an operation to one of the output sets to be later rewritten.
template <typename OpTy>
static void collectRewrite(OpTy op, llvm::SetVector<OpTy> &rewrites) {
  rewrites.insert(op);
}

/// Add an \c omp.map.info operation and all its members recursively to the
/// output set to be later rewritten.
///
/// Dependencies across \c omp.map.info are maintained by ensuring dependencies
/// are added to the output sets before operations based on them.
template <>
void collectRewrite(omp::MapInfoOp mapOp,
                    llvm::SetVector<omp::MapInfoOp> &rewrites) {
  // Members are themselves omp.map.info results; insert them first so that,
  // in the SetVector's insertion order, dependencies precede their dependents.
  for (Value member : mapOp.getMembers())
    collectRewrite(cast<omp::MapInfoOp>(member.getDefiningOp()), rewrites);

  rewrites.insert(mapOp);
}

/// Add the given value to a sorted set if it should be replaced by a
/// placeholder when used as an operand that must remain for the device.
+/// +/// Values that are block arguments of \c func.func operations are skipped, +/// since they will still be available after all rewrites are completed. +static void collectRewrite(Value value, llvm::SetVector<Value> &rewrites) { + if ((isa<BlockArgument>(value) && + isa<func::FuncOp>( + cast<BlockArgument>(value).getOwner()->getParentOp())) || + rewrites.contains(value)) + return; + + rewrites.insert(value); +} + namespace { class FunctionFilteringPass : public flangomp::impl::FunctionFilteringPassBase<FunctionFilteringPass> { @@ -90,10 +165,17 @@ public: // Remove the callOp callOp->erase(); } + if (!hasTargetRegion) { funcOp.erase(); return WalkResult::skip(); } + + if (failed(rewriteHostFunction(funcOp))) { + funcOp.emitOpError() << "could not be rewritten for target device"; + return WalkResult::interrupt(); + } + if (declareTargetOp) declareTargetOp.setDeclareTarget( declareType, omp::DeclareTargetCaptureClause::to, @@ -101,6 +183,311 @@ public: } return WalkResult::advance(); }); + + checkDeviceImplementationStatus(op); + } + +private: + /// Rewrite the given host device function containing \c omp.target + /// operations, to remove host-only operations that are not used by device + /// codegen. + /// + /// It is based on the expected form of the MLIR module as produced by Flang + /// lowering and it performs the following mutations: + /// - Replace all values returned by the function with \c fir.undefined. + /// - \c omp.target operations are moved to the end of the function. If they + /// are nested inside of any other operations, they are hoisted out of + /// them. + /// - \c depend, \c device and \c if clauses are removed from these target + /// functions. Values used to initialize other clauses are replaced by + /// placeholders as follows: + /// - Values defined by block arguments are replaced by placeholders only + /// if they are not attached to the parent \c func.func operation. In + /// that case, they are passed unmodified. 
+ /// - \c arith.constant and \c fir.address_of ops are maintained. + /// - Values of type \c fir.boxchar are replaced with a combination of + /// \c fir.alloca for a single bit and a \c fir.emboxchar. + /// - Other values are replaced by a combination of an \c fir.alloca for a + /// single bit and an \c fir.convert to the original type of the value. + /// This can be done because the code eventually generated for these + /// operations will be discarded, as they aren't runnable by the target + /// device. + /// - \c omp.map.info operations associated to these target regions are + /// preserved. These are moved above all \c omp.target and sorted to + /// satisfy dependencies among them. + /// - \c bounds arguments are removed from \c omp.map.info operations. + /// - \c var_ptr and \c var_ptr_ptr arguments of \c omp.map.info are + /// handled as follows: + /// - \c var_ptr_ptr is expected to be defined by a \c fir.box_offset + /// operation which is preserved. Otherwise, the pass will fail. + /// - \c var_ptr can be defined by an \c hlfir.declare which is also + /// preserved. Its \c memref argument is replaced by a placeholder or + /// maintained, similarly to non-map clauses of target operations + /// described above. If it has \c shape or \c typeparams arguments, they + /// are replaced by applicable constants. \c dummy_scope arguments + /// are discarded. + /// - Every other operation not located inside of an \c omp.target is + /// removed. + LogicalResult rewriteHostFunction(func::FuncOp funcOp) { + Region ®ion = funcOp.getRegion(); + + // Collect target operations inside of the function. + llvm::SmallVector<omp::TargetOp> targetOps; + region.walk<WalkOrder::PreOrder>([&](Operation *op) { + // Skip the inside of omp.target regions, since these contain device code. 
+ if (auto targetOp = dyn_cast<omp::TargetOp>(op)) { + targetOps.push_back(targetOp); + return WalkResult::skip(); + } + + // Replace omp.target_data entry block argument uses with the value used + // to initialize the associated omp.map.info operation. This way, + // references are still valid once the omp.target operation has been + // extracted out of the omp.target_data region. + if (auto targetDataOp = dyn_cast<omp::TargetDataOp>(op)) { + llvm::SmallVector<std::pair<Value, BlockArgument>> argPairs; + cast<omp::BlockArgOpenMPOpInterface>(*targetDataOp) + .getBlockArgsPairs(argPairs); + for (auto [operand, blockArg] : argPairs) { + auto mapInfo = cast<omp::MapInfoOp>(operand.getDefiningOp()); + Value varPtr = mapInfo.getVarPtr(); + + // If the var_ptr operand of the omp.map.info op defining this entry + // block argument is an hlfir.declare, the uses of all users of that + // entry block argument that are themselves hlfir.declare are replaced + // by values produced by the outer one. + // + // This prevents this pass from producing chains of hlfir.declare of + // the type: + // %0 = ... + // %1:2 = hlfir.declare %0 + // %2:2 = hlfir.declare %1#1... + // %3 = omp.map.info var_ptr(%2#1 ... + if (auto outerDeclare = varPtr.getDefiningOp<hlfir::DeclareOp>()) + for (Operation *user : blockArg.getUsers()) + if (isa<hlfir::DeclareOp>(user)) + user->replaceAllUsesWith(outerDeclare); + + // All remaining uses of the entry block argument are replaced with + // the var_ptr initialization value. + blockArg.replaceAllUsesWith(varPtr); + } + } + return WalkResult::advance(); + }); + + // Make a temporary clone of the parent operation with an empty region, + // and update all references to entry block arguments to those of the new + // region. Users will later either be moved to the new region or deleted + // when the original region is replaced by the new. 
+ OpBuilder builder(&getContext()); + builder.setInsertionPointAfter(funcOp); + Operation *newOp = builder.cloneWithoutRegions(funcOp); + Block &block = newOp->getRegion(0).emplaceBlock(); + + llvm::SmallVector<Location> locs; + locs.reserve(region.getNumArguments()); + llvm::transform(region.getArguments(), std::back_inserter(locs), + [](const BlockArgument &arg) { return arg.getLoc(); }); + block.addArguments(region.getArgumentTypes(), locs); + + for (auto [oldArg, newArg] : + llvm::zip_equal(region.getArguments(), block.getArguments())) + oldArg.replaceAllUsesWith(newArg); + + // Collect omp.map.info ops while satisfying interdependencies and remove + // operands that aren't used by target device codegen. + // + // This logic must be updated whenever operands to omp.target change. + llvm::SetVector<Value> rewriteValues; + llvm::SetVector<omp::MapInfoOp> mapInfos; + for (omp::TargetOp targetOp : targetOps) { + assert(targetOp.getHostEvalVars().empty() && + "unexpected host_eval in target device module"); + + // Variables unused by the device. + targetOp.getDependVarsMutable().clear(); + targetOp.setDependKindsAttr(nullptr); + targetOp.getDeviceMutable().clear(); + targetOp.getIfExprMutable().clear(); + + // TODO: Clear some of these operands rather than rewriting them, + // depending on whether they are needed by device codegen once support for + // them is fully implemented. 
+ for (Value allocVar : targetOp.getAllocateVars()) + collectRewrite(allocVar, rewriteValues); + for (Value allocVar : targetOp.getAllocatorVars()) + collectRewrite(allocVar, rewriteValues); + for (Value inReduction : targetOp.getInReductionVars()) + collectRewrite(inReduction, rewriteValues); + for (Value isDevPtr : targetOp.getIsDevicePtrVars()) + collectRewrite(isDevPtr, rewriteValues); + for (Value mapVar : targetOp.getHasDeviceAddrVars()) + collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), mapInfos); + for (Value mapVar : targetOp.getMapVars()) + collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), mapInfos); + for (Value privateVar : targetOp.getPrivateVars()) + collectRewrite(privateVar, rewriteValues); + for (Value threadLimit : targetOp.getThreadLimitVars()) + collectRewrite(threadLimit, rewriteValues); + } + + // Move omp.map.info ops to the new block and collect dependencies. + llvm::SetVector<hlfir::DeclareOp> declareOps; + llvm::SetVector<fir::BoxOffsetOp> boxOffsets; + for (omp::MapInfoOp mapOp : mapInfos) { + if (auto declareOp = dyn_cast_if_present<hlfir::DeclareOp>( + mapOp.getVarPtr().getDefiningOp())) + collectRewrite(declareOp, declareOps); + else + collectRewrite(mapOp.getVarPtr(), rewriteValues); + + if (Value varPtrPtr = mapOp.getVarPtrPtr()) { + if (auto boxOffset = llvm::dyn_cast_if_present<fir::BoxOffsetOp>( + varPtrPtr.getDefiningOp())) + collectRewrite(boxOffset, boxOffsets); + else + return mapOp->emitOpError() << "var_ptr_ptr rewrite only supported " + "if defined by fir.box_offset"; + } + + // Bounds are not used during target device codegen. + mapOp.getBoundsMutable().clear(); + mapOp->moveBefore(&block, block.end()); + } + + // Create a temporary marker to simplify the op moving process below. 
+ builder.setInsertionPointToStart(&block); + auto marker = fir::UndefOp::create(builder, builder.getUnknownLoc(), + builder.getNoneType()); + builder.setInsertionPoint(marker); + + // Handle dependencies of hlfir.declare ops. + for (hlfir::DeclareOp declareOp : declareOps) { + collectRewrite(declareOp.getMemref(), rewriteValues); + + if (declareOp.getStorage()) + collectRewrite(declareOp.getStorage(), rewriteValues); + + // Shape and typeparams aren't needed for target device codegen, but + // removing them would break verifiers. + Value zero; + if (declareOp.getShape() || !declareOp.getTypeparams().empty()) + zero = arith::ConstantOp::create(builder, declareOp.getLoc(), + builder.getI64IntegerAttr(0)); + + if (auto shape = declareOp.getShape()) { + // The pre-cg rewrite pass requires the shape to be defined by one of + // fir.shape, fir.shapeshift or fir.shift, so we need to make sure it's + // still defined by one of these after this pass. + Operation *shapeOp = shape.getDefiningOp(); + llvm::SmallVector<Value> extents(shapeOp->getNumOperands(), zero); + Value newShape = + llvm::TypeSwitch<Operation *, Value>(shapeOp) + .Case([&](fir::ShapeOp op) { + return fir::ShapeOp::create(builder, op.getLoc(), extents); + }) + .Case([&](fir::ShapeShiftOp op) { + auto type = fir::ShapeShiftType::get(op.getContext(), + extents.size() / 2); + return fir::ShapeShiftOp::create(builder, op.getLoc(), type, + extents); + }) + .Case([&](fir::ShiftOp op) { + auto type = + fir::ShiftType::get(op.getContext(), extents.size()); + return fir::ShiftOp::create(builder, op.getLoc(), type, + extents); + }) + .Default([](Operation *op) { + op->emitOpError() + << "hlfir.declare shape expected to be one of: " + "fir.shape, fir.shapeshift or fir.shift"; + return nullptr; + }); + + if (!newShape) + return failure(); + + declareOp.getShapeMutable().assign(newShape); + } + + for (OpOperand &typeParam : declareOp.getTypeparamsMutable()) + typeParam.assign(zero); + + 
declareOp.getDummyScopeMutable().clear(); + } + + // We don't actually need the proper initialization, but rather just + // maintain the basic form of these operands. Generally, we create 1-bit + // placeholder allocas that we "typecast" to the expected type and replace + // all uses. Using fir.undefined here instead is not possible because these + // variables cannot be constants, as that would trigger different codegen + // for target regions. + for (Value value : rewriteValues) { + Location loc = value.getLoc(); + Value rewriteValue; + if (isa_and_present<arith::ConstantOp, fir::AddrOfOp>( + value.getDefiningOp())) { + // If it's defined by fir.address_of, then we need to keep that op as + // well because it might be pointing to a 'declare target' global. + // Constants can also trigger different codegen paths, so we keep them + // as well. + rewriteValue = builder.clone(*value.getDefiningOp())->getResult(0); + } else if (auto boxCharType = + dyn_cast<fir::BoxCharType>(value.getType())) { + // !fir.boxchar types cannot be directly obtained by converting a + // !fir.ref<i1>, as they aren't reference types. Since they can appear + // representing some `target firstprivate` clauses, we need to create + // a special case here based on creating a placeholder fir.emboxchar op. + MLIRContext *ctx = &getContext(); + fir::KindTy kind = boxCharType.getKind(); + auto placeholder = fir::AllocaOp::create( + builder, loc, fir::CharacterType::getSingleton(ctx, kind)); + auto one = arith::ConstantOp::create(builder, loc, builder.getI32Type(), + builder.getI32IntegerAttr(1)); + rewriteValue = fir::EmboxCharOp::create(builder, loc, boxCharType, + placeholder, one); + } else { + Value placeholder = + fir::AllocaOp::create(builder, loc, builder.getI1Type()); + rewriteValue = + fir::ConvertOp::create(builder, loc, value.getType(), placeholder); + } + value.replaceAllUsesWith(rewriteValue); + } + + // Move omp.map.info dependencies. 
+ for (hlfir::DeclareOp declareOp : declareOps) + declareOp->moveBefore(marker); + + // The box_ref argument of fir.box_offset is expected to be the same value + // that was passed as var_ptr to the corresponding omp.map.info, so we don't + // need to handle its defining op here. + for (fir::BoxOffsetOp boxOffset : boxOffsets) + boxOffset->moveBefore(marker); + + marker->erase(); + + // Move target operations to the end of the new block. + for (omp::TargetOp targetOp : targetOps) + targetOp->moveBefore(&block, block.end()); + + // Add terminator to the new block. + builder.setInsertionPointToEnd(&block); + llvm::SmallVector<Value> returnValues; + returnValues.reserve(funcOp.getNumResults()); + for (auto type : funcOp.getResultTypes()) + returnValues.push_back( + fir::UndefOp::create(builder, funcOp.getLoc(), type)); + + func::ReturnOp::create(builder, funcOp.getLoc(), returnValues); + + // Replace old region (now missing ops) with the new one and remove the + // temporary operation clone. 
+ region.takeBody(newOp->getRegion(0)); + newOp->erase(); + return success(); } }; } // namespace diff --git a/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp b/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp index 5aa1273..be0bdb7 100644 --- a/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp @@ -41,7 +41,7 @@ class LowerNontemporalPass operand = op.getMemref(); defOp = operand.getDefiningOp(); }) - .Case<fir::BoxAddrOp>([&](auto op) { + .Case([&](fir::BoxAddrOp op) { operand = op.getVal(); defOp = operand.getDefiningOp(); }) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 9278e17..2c79800 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -282,14 +282,14 @@ fissionWorkdistribute(omp::WorkdistributeOp workdistribute) { &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {}); for (auto arg : teamsBlock->getArguments()) newTeamsBlock->addArgument(arg.getType(), arg.getLoc()); - auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc); - rewriter.create<omp::TerminatorOp>(loc); + auto newWorkdistribute = omp::WorkdistributeOp::create(rewriter, loc); + omp::TerminatorOp::create(rewriter, loc); rewriter.createBlock(&newWorkdistribute.getRegion(), newWorkdistribute.getRegion().begin(), {}, {}); auto *cloned = rewriter.clone(*parallelize); parallelize->replaceAllUsesWith(cloned); parallelize->erase(); - rewriter.create<omp::TerminatorOp>(loc); + omp::TerminatorOp::create(rewriter, loc); changed = true; } } @@ -298,10 +298,10 @@ fissionWorkdistribute(omp::WorkdistributeOp workdistribute) { /// Generate omp.parallel operation with an empty region. 
static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
-  auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
+  auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc);
   parallelOp.setComposite(composite);
   rewriter.createBlock(&parallelOp.getRegion());
-  rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+  rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc));
   return;
 }
@@ -309,7 +309,7 @@ static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
 static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
   mlir::omp::DistributeOperands distributeClauseOps;
   auto distributeOp =
-      rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps);
+      mlir::omp::DistributeOp::create(rewriter, loc, distributeClauseOps);
   distributeOp.setComposite(composite);
   auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion());
   rewriter.setInsertionPointToStart(distributeBlock);
@@ -334,12 +334,12 @@
 static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
                         const mlir::omp::LoopNestOperands &clauseOps,
                         bool composite) {
-  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  auto wsloopOp = mlir::omp::WsloopOp::create(rewriter, doLoop.getLoc());
   wsloopOp.setComposite(composite);
   rewriter.createBlock(&wsloopOp.getRegion());
 
   auto loopNestOp =
-      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+      mlir::omp::LoopNestOp::create(rewriter, doLoop.getLoc(), clauseOps);
 
   // Clone the loop's body inside the loop nest construct using the
   // mapped values.
@@ -351,7 +351,7 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
   // Erase fir.result op of do loop and create yield op.
if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) { rewriter.setInsertionPoint(terminatorOp); - rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc()); + mlir::omp::YieldOp::create(rewriter, doLoop->getLoc()); terminatorOp->erase(); } } @@ -494,15 +494,15 @@ static SmallVector<Value> convertFlatToMultiDim(OpBuilder &builder, // Convert flat index to multi-dimensional indices SmallVector<Value> indices(rank); Value temp = flatIdx; - auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1); + auto c1 = arith::ConstantIndexOp::create(builder, loc, 1); // Work backwards through dimensions (row-major order) for (int i = rank - 1; i >= 0; --i) { - Value zeroBasedIdx = builder.create<arith::RemSIOp>(loc, temp, extents[i]); + Value zeroBasedIdx = arith::RemSIOp::create(builder, loc, temp, extents[i]); // Convert to one-based index - indices[i] = builder.create<arith::AddIOp>(loc, zeroBasedIdx, c1); + indices[i] = arith::AddIOp::create(builder, loc, zeroBasedIdx, c1); if (i > 0) { - temp = builder.create<arith::DivSIOp>(loc, temp, extents[i]); + temp = arith::DivSIOp::create(builder, loc, temp, extents[i]); } } @@ -525,7 +525,7 @@ static Value CalculateTotalElements(OpBuilder &builder, Location loc, if (i == 0) { totalElems = extent; } else { - totalElems = builder.create<arith::MulIOp>(loc, totalElems, extent); + totalElems = arith::MulIOp::create(builder, loc, totalElems, extent); } } return totalElems; @@ -562,14 +562,14 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc, // Load destination array box (if it's a reference) Value arrayBox = destBox; if (isa<fir::ReferenceType>(destBox.getType())) - arrayBox = builder.create<fir::LoadOp>(loc, destBox); + arrayBox = fir::LoadOp::create(builder, loc, destBox); - auto scalarValue = builder.create<fir::BoxAddrOp>(loc, srcBox); - Value scalar = builder.create<fir::LoadOp>(loc, scalarValue); + auto scalarValue = fir::BoxAddrOp::create(builder, loc, srcBox); + Value scalar = 
fir::LoadOp::create(builder, loc, scalarValue); // Calculate total number of elements (flattened) - auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0); - auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1); + auto c0 = arith::ConstantIndexOp::create(builder, loc, 0); + auto c1 = arith::ConstantIndexOp::create(builder, loc, 1); Value totalElems = CalculateTotalElements(builder, loc, arrayBox); auto *workdistributeBlock = &workdistribute.getRegion().front(); @@ -587,7 +587,7 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc, builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox, nullptr, nullptr, ValueRange{indices}, ValueRange{}); - builder.create<fir::StoreOp>(loc, scalar, elemPtr); + fir::StoreOp::create(builder, loc, scalar, elemPtr); } /// workdistributeRuntimeCallLower method finds the runtime calls @@ -719,10 +719,9 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, SmallVector<Value> outerMapInfos; // Create new mapinfo ops for the inner target region for (auto mapInfo : mapInfos) { - auto originalMapType = - (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType()); + mlir::omp::ClauseMapFlags originalMapType = mapInfo.getMapType(); auto originalCaptureType = mapInfo.getMapCaptureType(); - llvm::omp::OpenMPOffloadMappingFlags newMapType; + mlir::omp::ClauseMapFlags newMapType; mlir::omp::VariableCaptureKind newCaptureType; // For bycopy, we keep the same map type and capture type // For byref, we change the map type to none and keep the capture type @@ -730,7 +729,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, newMapType = originalMapType; newCaptureType = originalCaptureType; } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) { - newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + newMapType = mlir::omp::ClauseMapFlags::storage; newCaptureType = originalCaptureType; outerMapInfos.push_back(mapInfo); } else { @@ -738,11 +737,8 @@ 
FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, return failure(); } auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo)); - innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr( - rewriter.getIntegerType(64, false), - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - newMapType))); + innerMapInfo.setMapTypeAttr( + rewriter.getAttr<omp::ClauseMapFlagsAttr>(newMapType)); innerMapInfo.setMapCaptureType(newCaptureType); innerMapInfos.push_back(innerMapInfo.getResult()); } @@ -753,14 +749,15 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, auto deviceAddrVars = targetOp.getHasDeviceAddrVars(); auto devicePtrVars = targetOp.getIsDevicePtrVars(); // Create the target data op - auto targetDataOp = rewriter.create<omp::TargetDataOp>( - loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars); + auto targetDataOp = + omp::TargetDataOp::create(rewriter, loc, device, ifExpr, outerMapInfos, + deviceAddrVars, devicePtrVars); auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion()); - rewriter.create<mlir::omp::TerminatorOp>(loc); + mlir::omp::TerminatorOp::create(rewriter, loc); rewriter.setInsertionPointToStart(taregtDataBlock); // Create the inner target op - auto newTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + auto newTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), @@ -769,7 +766,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), - targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); + 
targetOp.getThreadLimitVars(), targetOp.getPrivateMapsAttr()); rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(), newTargetOp.getRegion().begin()); rewriter.replaceOp(targetOp, targetDataOp); @@ -825,20 +822,20 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty, // Get the appropriate type for allocation if (isPtr(ty)) { Type intTy = rewriter.getI32Type(); - auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1); + auto one = LLVM::ConstantOp::create(rewriter, loc, intTy, 1); allocType = llvmPtrTy; - alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one); + alloc = LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, allocType, one); allocType = intTy; } else { allocType = ty; - alloc = rewriter.create<fir::AllocaOp>(loc, allocType); + alloc = fir::AllocaOp::create(rewriter, loc, allocType); } // Lambda to create mapinfo ops - auto getMapInfo = [&](uint64_t mappingFlags, const char *name) { - return rewriter.create<omp::MapInfoOp>( - loc, alloc.getType(), alloc, TypeAttr::get(allocType), - rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false), - mappingFlags), + auto getMapInfo = [&](mlir::omp::ClauseMapFlags mappingFlags, + const char *name) { + return omp::MapInfoOp::create( + rewriter, loc, alloc.getType(), alloc, TypeAttr::get(allocType), + rewriter.getAttr<omp::ClauseMapFlagsAttr>(mappingFlags), rewriter.getAttr<omp::VariableCaptureKindAttr>( omp::VariableCaptureKind::ByRef), /*varPtrPtr=*/Value{}, @@ -849,14 +846,10 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty, /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false)); }; // Create mapinfo ops. 
- uint64_t mapFrom = - static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); - uint64_t mapTo = - static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); - auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from"); - auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to"); + auto mapInfoFrom = getMapInfo(mlir::omp::ClauseMapFlags::from, + "__flang_workdistribute_from"); + auto mapInfoTo = + getMapInfo(mlir::omp::ClauseMapFlags::to, "__flang_workdistribute_to"); return TempOmpVar{mapInfoFrom, mapInfoTo}; } @@ -987,12 +980,12 @@ static void reloadCacheAndRecompute( // If the original value is a pointer or reference, load and convert if // necessary. if (isPtr(original.getType())) { - restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg); + restored = LLVM::LoadOp::create(rewriter, loc, llvmPtrTy, newArg); if (!isa<LLVM::LLVMPointerType>(original.getType())) restored = - rewriter.create<fir::ConvertOp>(loc, original.getType(), restored); + fir::ConvertOp::create(rewriter, loc, original.getType(), restored); } else { - restored = rewriter.create<fir::LoadOp>(loc, newArg); + restored = fir::LoadOp::create(rewriter, loc, newArg); } irMapping.map(original, restored); } @@ -1061,7 +1054,7 @@ static mlir::LLVM::ConstantOp genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { mlir::Type i32Ty = rewriter.getI32Type(); mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value); - return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr); + return mlir::LLVM::ConstantOp::create(rewriter, loc, i32Ty, attr); } /// Given a box descriptor, extract the base address of the data it describes. 
@@ -1238,8 +1231,8 @@ static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder, genOmpGetMappedPtrIfPresent(builder, loc, destBase, device, module); Value srcPtr = genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module); - Value zero = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), - builder.getI64IntegerAttr(0)); + Value zero = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), + builder.getI64IntegerAttr(0)); // Generate the call to omp_target_memcpy to perform the data copy on the // device. @@ -1356,23 +1349,24 @@ static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, for (Operation *op : opsToReplace) { if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) { rewriter.setInsertionPoint(allocOp); - auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>( - allocOp.getLoc(), rewriter.getI64Type(), device, + auto ompAllocmemOp = omp::TargetAllocMemOp::create( + rewriter, allocOp.getLoc(), rewriter.getI64Type(), device, allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(), allocOp.getBindcNameAttr(), allocOp.getTypeparams(), allocOp.getShape()); - auto firConvertOp = rewriter.create<fir::ConvertOp>( - allocOp.getLoc(), allocOp.getResult().getType(), - ompAllocmemOp.getResult()); + auto firConvertOp = fir::ConvertOp::create(rewriter, allocOp.getLoc(), + allocOp.getResult().getType(), + ompAllocmemOp.getResult()); rewriter.replaceOp(allocOp, firConvertOp.getResult()); } // Replace fir.freemem with omp.target_freemem. 
else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) { rewriter.setInsertionPoint(freeOp); - auto firConvertOp = rewriter.create<fir::ConvertOp>( - freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref()); - rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device, - firConvertOp.getResult()); + auto firConvertOp = + fir::ConvertOp::create(rewriter, freeOp.getLoc(), + rewriter.getI64Type(), freeOp.getHeapref()); + omp::TargetFreeMemOp::create(rewriter, freeOp.getLoc(), device, + firConvertOp.getResult()); rewriter.eraseOp(freeOp); } // fir.declare changes its type when hoisting it out of omp.target to @@ -1384,8 +1378,9 @@ static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, dyn_cast<fir::ReferenceType>(clonedInType); Type clonedEleTy = clonedRefType.getElementType(); rewriter.setInsertionPoint(op); - Value loadedValue = rewriter.create<fir::LoadOp>( - clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref()); + Value loadedValue = + fir::LoadOp::create(rewriter, clonedDeclareOp.getLoc(), clonedEleTy, + clonedDeclareOp.getMemref()); clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue); } // Replace runtime calls with omp versions. 
@@ -1481,8 +1476,8 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands, auto *targetBlock = &targetOp.getRegion().front(); SmallVector<Value> preHostEvalVars{targetOp.getHostEvalVars()}; // update the hostEvalVars of preTargetOp - omp::TargetOp preTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + omp::TargetOp preTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars, @@ -1490,7 +1485,7 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), - targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(), targetOp.getPrivateMapsAttr()); auto *preTargetBlock = rewriter.createBlock( &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {}); @@ -1521,13 +1516,13 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands, // Create the store operation. 
if (isPtr(originalResult.getType())) { if (!isa<LLVM::LLVMPointerType>(toStore.getType())) - toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore); - rewriter.create<LLVM::StoreOp>(loc, toStore, newArg); + toStore = fir::ConvertOp::create(rewriter, loc, llvmPtrTy, toStore); + LLVM::StoreOp::create(rewriter, loc, toStore, newArg); } else { - rewriter.create<fir::StoreOp>(loc, toStore, newArg); + fir::StoreOp::create(rewriter, loc, toStore, newArg); } } - rewriter.create<omp::TerminatorOp>(loc); + omp::TerminatorOp::create(rewriter, loc); // Update hostEvalVars with the mapped values for the loop bounds if we have // a loopNestOp and we are not generating code for the target device. @@ -1571,8 +1566,8 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, hostEvalVars.steps.end()); } // Create the isolated target op - omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + omp::TargetOp isolatedTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), @@ -1580,7 +1575,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), - targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(), targetOp.getPrivateMapsAttr()); auto *isolatedTargetBlock = rewriter.createBlock(&isolatedTargetOp.getRegion(), @@ -1598,7 +1593,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, // Clone the original operations. 
rewriter.clone(*splitBeforeOp, isolatedMapping); - rewriter.create<omp::TerminatorOp>(loc); + omp::TerminatorOp::create(rewriter, loc); // update the loop bounds in the isolatedTargetOp if we have host_eval vars // and we are not generating code for the target device. @@ -1651,8 +1646,8 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp, auto *targetBlock = &targetOp.getRegion().front(); SmallVector<Value> postHostEvalVars{targetOp.getHostEvalVars()}; // Create the post target op - omp::TargetOp postTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + omp::TargetOp postTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars, @@ -1660,7 +1655,7 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp, targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), - targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(), targetOp.getPrivateMapsAttr()); // Create the block for postTargetOp auto *postTargetBlock = rewriter.createBlock( diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 2bbd803..a60960e 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -43,7 +43,6 @@ #include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringSet.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Support/raw_ostream.h" #include <algorithm> #include <cstddef> @@ -348,10 +347,10 @@ class 
MapInfoFinalizationPass /// base address (BoxOffsetOp) and a MapInfoOp for it. The most /// important thing to note is that we normally move the bounds from /// the descriptor map onto the base address map. - mlir::omp::MapInfoOp genBaseAddrMap(mlir::Value descriptor, - mlir::OperandRange bounds, - int64_t mapType, - fir::FirOpBuilder &builder) { + mlir::omp::MapInfoOp + genBaseAddrMap(mlir::Value descriptor, mlir::OperandRange bounds, + mlir::omp::ClauseMapFlags mapType, fir::FirOpBuilder &builder, + mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr()) { mlir::Location loc = descriptor.getLoc(); mlir::Value baseAddrAddr = fir::BoxOffsetOp::create( builder, loc, descriptor, fir::BoxFieldAttr::base_addr); @@ -368,12 +367,12 @@ class MapInfoFinalizationPass return mlir::omp::MapInfoOp::create( builder, loc, baseAddrAddr.getType(), descriptor, mlir::TypeAttr::get(underlyingVarType), - builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(mapType), builder.getAttr<mlir::omp::VariableCaptureKindAttr>( mlir::omp::VariableCaptureKind::ByRef), baseAddrAddr, /*members=*/mlir::SmallVector<mlir::Value>{}, /*membersIndex=*/mlir::ArrayAttr{}, bounds, - /*mapperId*/ mlir::FlatSymbolRefAttr(), + /*mapperId=*/mapperId, /*name=*/builder.getStringAttr(""), /*partial_map=*/builder.getBoolAttr(false)); } @@ -428,22 +427,36 @@ class MapInfoFinalizationPass /// allowing `to` mappings, and `target update` not allowing both `to` and /// `from` simultaneously. We currently try to maintain the `implicit` flag /// where necessary, although it does not seem strictly required. 
- unsigned long getDescriptorMapType(unsigned long mapTypeFlag, - mlir::Operation *target) { - using mapFlags = llvm::omp::OpenMPOffloadMappingFlags; + mlir::omp::ClauseMapFlags + getDescriptorMapType(mlir::omp::ClauseMapFlags mapTypeFlag, + mlir::Operation *target) { + using mapFlags = mlir::omp::ClauseMapFlags; if (llvm::isa_and_nonnull<mlir::omp::TargetExitDataOp, mlir::omp::TargetUpdateOp>(target)) return mapTypeFlag; - mapFlags flags = mapFlags::OMP_MAP_TO | - (mapFlags(mapTypeFlag) & - (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS)); + mapFlags flags = + mapFlags::to | (mapTypeFlag & (mapFlags::implicit | mapFlags::always)); + + // Descriptors for objects will always be copied. This is because the + // descriptor can be rematerialized by the compiler, and so the address + // of the descriptor for a given object at one place in the code may + // differ from that address in another place. The contents of the + // descriptor (the base address in particular) will remain unchanged + // though. + // TODO/FIXME: We currently cannot have MAP_CLOSE and MAP_ALWAYS on + // the descriptor at once, these are mutually exclusive and when + // both are applied the runtime will fail to map. + flags |= ((mapFlags(mapTypeFlag) & mapFlags::close) == mapFlags::close) + ? mapFlags::close + : mapFlags::always; + // For unified_shared_memory, we additionally add `CLOSE` on the descriptor // to ensure device-local placement where required by tests relying on USM + // close semantics. 
if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>())) - flags |= mapFlags::OMP_MAP_CLOSE; - return llvm::to_underlying(flags); + flags |= mapFlags::close; + return flags; } /// Check if the mapOp is present in the HasDeviceAddr clause on @@ -478,62 +491,6 @@ class MapInfoFinalizationPass return false; } - mlir::omp::MapInfoOp genBoxcharMemberMap(mlir::omp::MapInfoOp op, - fir::FirOpBuilder &builder) { - if (!op.getMembers().empty()) - return op; - mlir::Location loc = op.getVarPtr().getLoc(); - mlir::Value boxChar = op.getVarPtr(); - - if (mlir::isa<fir::ReferenceType>(op.getVarPtr().getType())) - boxChar = fir::LoadOp::create(builder, loc, op.getVarPtr()); - - fir::BoxCharType boxCharType = - mlir::dyn_cast<fir::BoxCharType>(boxChar.getType()); - mlir::Value boxAddr = fir::BoxOffsetOp::create( - builder, loc, op.getVarPtr(), fir::BoxFieldAttr::base_addr); - - uint64_t mapTypeToImplicit = static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT); - - mlir::ArrayAttr newMembersAttr; - llvm::SmallVector<llvm::SmallVector<int64_t>> memberIdx = {{0}}; - newMembersAttr = builder.create2DI64ArrayAttr(memberIdx); - - mlir::Value varPtr = op.getVarPtr(); - mlir::omp::MapInfoOp memberMapInfoOp = mlir::omp::MapInfoOp::create( - builder, op.getLoc(), varPtr.getType(), varPtr, - mlir::TypeAttr::get(boxCharType.getEleTy()), - builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false), - mapTypeToImplicit), - builder.getAttr<mlir::omp::VariableCaptureKindAttr>( - mlir::omp::VariableCaptureKind::ByRef), - /*varPtrPtr=*/boxAddr, - /*members=*/llvm::SmallVector<mlir::Value>{}, - /*member_index=*/mlir::ArrayAttr{}, - /*bounds=*/op.getBounds(), - /*mapperId=*/mlir::FlatSymbolRefAttr(), /*name=*/op.getNameAttr(), - builder.getBoolAttr(false)); - - mlir::omp::MapInfoOp newMapInfoOp = mlir::omp::MapInfoOp::create( - builder, 
op.getLoc(), op.getResult().getType(), varPtr, - mlir::TypeAttr::get( - llvm::cast<mlir::omp::PointerLikeType>(varPtr.getType()) - .getElementType()), - op.getMapTypeAttr(), op.getMapCaptureTypeAttr(), - /*varPtrPtr=*/mlir::Value{}, - /*members=*/llvm::SmallVector<mlir::Value>{memberMapInfoOp}, - /*member_index=*/newMembersAttr, - /*bounds=*/llvm::SmallVector<mlir::Value>{}, - /*mapperId=*/mlir::FlatSymbolRefAttr(), op.getNameAttr(), - /*partial_map=*/builder.getBoolAttr(false)); - op.replaceAllUsesWith(newMapInfoOp.getResult()); - op->erase(); - return newMapInfoOp; - } - // Expand mappings of type(C_PTR) to map their `__address` field explicitly // as a single pointer-sized member (USM-gated at callsite). This helps in // USM scenarios to ensure the pointer-sized mapping is used. @@ -568,12 +525,9 @@ class MapInfoFinalizationPass mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr(memberIdx); // Force CLOSE in USM paths so the pointer gets device-local placement // when required by tests relying on USM + close semantics. - uint64_t mapTypeVal = - op.getMapType() | - llvm::to_underlying( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); - mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr( - builder.getIntegerType(64, /*isSigned=*/false), mapTypeVal); + mlir::omp::ClauseMapFlagsAttr mapTypeAttr = + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + op.getMapType() | mlir::omp::ClauseMapFlags::close); mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create( builder, loc, coord.getType(), coord, @@ -638,6 +592,7 @@ class MapInfoFinalizationPass // from the descriptor to be used verbatim, i.e. without additional // remapping. To avoid this remapping, simply don't generate any map // information for the descriptor members. + mlir::FlatSymbolRefAttr mapperId = op.getMapperIdAttr(); if (!mapMemberUsers.empty()) { // Currently, there should only be one user per map when this pass // is executed. 
Either a parent map, holding the current map in its @@ -648,8 +603,8 @@ class MapInfoFinalizationPass assert(mapMemberUsers.size() == 1 && "OMPMapInfoFinalization currently only supports single users of a " "MapInfoOp"); - auto baseAddr = - genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder); + auto baseAddr = genBaseAddrMap(descriptor, op.getBounds(), + op.getMapType(), builder, mapperId); ParentAndPlacement mapUser = mapMemberUsers[0]; adjustMemberIndices(memberIndices, mapUser.index); llvm::SmallVector<mlir::Value> newMemberOps; @@ -662,8 +617,8 @@ class MapInfoFinalizationPass mapUser.parent.setMembersIndexAttr( builder.create2DI64ArrayAttr(memberIndices)); } else if (!isHasDeviceAddrFlag) { - auto baseAddr = - genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder); + auto baseAddr = genBaseAddrMap(descriptor, op.getBounds(), + op.getMapType(), builder, mapperId); newMembers.push_back(baseAddr); if (!op.getMembers().empty()) { for (auto &indices : memberIndices) @@ -683,20 +638,19 @@ class MapInfoFinalizationPass // one place in the code may differ from that address in another place. // The contents of the descriptor (the base address in particular) will // remain unchanged though. 
- uint64_t mapType = op.getMapType(); + mlir::omp::ClauseMapFlags mapType = op.getMapType(); if (isHasDeviceAddrFlag) { - mapType |= llvm::to_underlying( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); + mapType |= mlir::omp::ClauseMapFlags::always; } mlir::omp::MapInfoOp newDescParentMapOp = mlir::omp::MapInfoOp::create( builder, op->getLoc(), op.getResult().getType(), descriptor, mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())), - builder.getIntegerAttr(builder.getIntegerType(64, false), - getDescriptorMapType(mapType, target)), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + getDescriptorMapType(mapType, target)), op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers, newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{}, - /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), + /*mapperId=*/mlir::FlatSymbolRefAttr(), op.getNameAttr(), /*partial_map=*/builder.getBoolAttr(false)); op.replaceAllUsesWith(newDescParentMapOp.getResult()); op->erase(); @@ -892,20 +846,16 @@ class MapInfoFinalizationPass if (explicitMappingPresent(op, targetDataOp)) return; - mlir::omp::MapInfoOp newDescParentMapOp = - builder.create<mlir::omp::MapInfoOp>( - op->getLoc(), op.getResult().getType(), op.getVarPtr(), - op.getVarTypeAttr(), - builder.getIntegerAttr( - builder.getIntegerType(64, false), - llvm::to_underlying( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)), - op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, - mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{}, - /*bounds=*/mlir::SmallVector<mlir::Value>{}, - /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), - /*partial_map=*/builder.getBoolAttr(false)); + mlir::omp::MapInfoOp newDescParentMapOp = mlir::omp::MapInfoOp::create( + builder, op->getLoc(), op.getResult().getType(), op.getVarPtr(), + op.getVarTypeAttr(), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + mlir::omp::ClauseMapFlags::to | 
mlir::omp::ClauseMapFlags::always), + op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, + mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{}, + /*bounds=*/mlir::SmallVector<mlir::Value>{}, + /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), + /*partial_map=*/builder.getBoolAttr(false)); targetDataOp.getMapVarsMutable().append({newDescParentMapOp}); } @@ -957,19 +907,26 @@ class MapInfoFinalizationPass // need to see how well this alteration works. auto loadBaseAddr = builder.loadIfRef(op->getLoc(), baseAddr.getVarPtrPtr()); - mlir::omp::MapInfoOp newBaseAddrMapOp = - builder.create<mlir::omp::MapInfoOp>( - op->getLoc(), loadBaseAddr.getType(), loadBaseAddr, - baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(), - baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members, - membersAttr, baseAddr.getBounds(), - /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), - /*partial_map=*/builder.getBoolAttr(false)); + mlir::omp::MapInfoOp newBaseAddrMapOp = mlir::omp::MapInfoOp::create( + builder, op->getLoc(), loadBaseAddr.getType(), loadBaseAddr, + baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(), + baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members, membersAttr, + baseAddr.getBounds(), + /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), + /*partial_map=*/builder.getBoolAttr(false)); op.replaceAllUsesWith(newBaseAddrMapOp.getResult()); op->erase(); baseAddr.erase(); } + static bool hasADescriptor(mlir::Operation *varOp, mlir::Type varType) { + if (fir::isTypeWithDescriptor(varType) || + mlir::isa<fir::BoxCharType>(varType) || + mlir::isa_and_present<fir::BoxAddrOp>(varOp)) + return true; + return false; + } + // This pass executes on omp::MapInfoOp's containing descriptor based types // (allocatables, pointers, assumed shape etc.) 
and expanding them into // multiple omp::MapInfoOp's for each pointer member contained within the @@ -1001,38 +958,6 @@ class MapInfoFinalizationPass localBoxAllocas.clear(); deferrableDesc.clear(); - // First, walk `omp.map.info` ops to see if any of them have varPtrs - // with an underlying type of fir.char<k, ?>, i.e a character - // with dynamic length. If so, check if they need bounds added. - func->walk([&](mlir::omp::MapInfoOp op) { - if (!op.getBounds().empty()) - return; - - mlir::Value varPtr = op.getVarPtr(); - mlir::Type underlyingVarType = fir::unwrapRefType(varPtr.getType()); - - if (!fir::characterWithDynamicLen(underlyingVarType)) - return; - - fir::factory::AddrAndBoundsInfo info = - fir::factory::getDataOperandBaseAddr( - builder, varPtr, /*isOptional=*/false, varPtr.getLoc()); - - fir::ExtendedValue extendedValue = - hlfir::translateToExtendedValue(varPtr.getLoc(), builder, - hlfir::Entity{info.addr}, - /*continguousHint=*/true) - .first; - builder.setInsertionPoint(op); - llvm::SmallVector<mlir::Value> boundsOps = - fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, - mlir::omp::MapBoundsType>( - builder, info, extendedValue, - /*dataExvIsAssumedSize=*/false, varPtr.getLoc()); - - op.getBoundsMutable().append(boundsOps); - }); - // Next, walk `omp.map.info` ops to see if any record members should be // implicitly mapped. func->walk([&](mlir::omp::MapInfoOp op) { @@ -1218,42 +1143,12 @@ class MapInfoFinalizationPass newMemberIndices.emplace_back(path); op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices)); - op.setPartialMap(true); + // Set to partial map only if there is no user-defined mapper. 
+ op.setPartialMap(op.getMapperIdAttr() == nullptr); return mlir::WalkResult::advance(); }); - func->walk([&](mlir::omp::MapInfoOp op) { - if (!op.getMembers().empty()) - return; - - if (!mlir::isa<fir::BoxCharType>(fir::unwrapRefType(op.getVarType()))) - return; - - // POSSIBLE_HACK_ALERT: If the boxchar has been implicitly mapped then - // it is likely that the underlying pointer to the data - // (!fir.ref<fir.char<k,?>>) has already been mapped. So, skip such - // boxchars. We are primarily interested in boxchars that were mapped - // by passes such as MapsForPrivatizedSymbols that map boxchars that - // are privatized. At present, such boxchar maps are not marked - // implicit. Should they be? I don't know. If they should be then - // we need to change this check for early return OR live with - // over-mapping. - bool hasImplicitMap = - (llvm::omp::OpenMPOffloadMappingFlags(op.getMapType()) & - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT) == - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; - if (hasImplicitMap) - return; - - assert(llvm::hasSingleElement(op->getUsers()) && - "OMPMapInfoFinalization currently only supports single users " - "of a MapInfoOp"); - - builder.setInsertionPoint(op); - genBoxcharMemberMap(op, builder); - }); - // Expand type(C_PTR) only when unified_shared_memory is required, // to ensure device-visible pointer size/behavior in USM scenarios // without changing default expectations elsewhere. 
@@ -1281,9 +1176,8 @@ class MapInfoFinalizationPass "OMPMapInfoFinalization currently only supports single users " "of a MapInfoOp"); - if (fir::isTypeWithDescriptor(op.getVarType()) || - mlir::isa_and_present<fir::BoxAddrOp>( - op.getVarPtr().getDefiningOp())) { + if (hasADescriptor(op.getVarPtr().getDefiningOp(), + fir::unwrapRefType(op.getVarType()))) { builder.setInsertionPoint(op); mlir::Operation *targetUser = getFirstTargetUser(op); assert(targetUser && "expected user of map operation was not found"); diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp index 3032857..6404e18 100644 --- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp +++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp @@ -35,7 +35,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Support/Debug.h" #include <type_traits> @@ -70,9 +69,6 @@ class MapsForPrivatizedSymbolsPass return size <= ptrSize && align <= ptrAlign; }; - uint64_t mapTypeTo = static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); Operation *definingOp = var.getDefiningOp(); Value varPtr = var; @@ -108,22 +104,31 @@ class MapsForPrivatizedSymbolsPass llvm::SmallVector<mlir::Value> boundsOps; if (needsBoundsOps(varPtr)) genBoundsOps(builder, varPtr, boundsOps); + mlir::Type varType = varPtr.getType(); mlir::omp::VariableCaptureKind captureKind = mlir::omp::VariableCaptureKind::ByRef; - if (fir::isa_trivial(fir::unwrapRefType(varPtr.getType())) || - fir::isa_char(fir::unwrapRefType(varPtr.getType()))) { - if (canPassByValue(fir::unwrapRefType(varPtr.getType()))) { + if (fir::isa_trivial(fir::unwrapRefType(varType)) || + fir::isa_char(fir::unwrapRefType(varType))) { + if (canPassByValue(fir::unwrapRefType(varType))) { captureKind = 
mlir::omp::VariableCaptureKind::ByCopy; } } + // Use tofrom if what we are mapping is not a trivial type. In all + // likelihood, it is a descriptor + mlir::omp::ClauseMapFlags mapFlag; + if (fir::isa_trivial(fir::unwrapRefType(varType)) || + fir::isa_char(fir::unwrapRefType(varType))) + mapFlag = mlir::omp::ClauseMapFlags::to; + else + mapFlag = mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::from; + return omp::MapInfoOp::create( - builder, loc, varPtr.getType(), varPtr, - TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType()) - .getElementType()), - builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false), - mapTypeTo), + builder, loc, varType, varPtr, + TypeAttr::get( + llvm::cast<omp::PointerLikeType>(varType).getElementType()), + builder.getAttr<omp::ClauseMapFlagsAttr>(mapFlag), builder.getAttr<omp::VariableCaptureKindAttr>(captureKind), /*varPtrPtr=*/Value{}, /*members=*/SmallVector<Value>{}, diff --git a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp index 0b0e6bd..5fa77fb 100644 --- a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp +++ b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp @@ -21,6 +21,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/TypeSwitch.h" namespace flangomp { #define GEN_PASS_DEF_MARKDECLARETARGETPASS @@ -31,9 +32,93 @@ namespace { class MarkDeclareTargetPass : public flangomp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> { - void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy, - mlir::omp::DeclareTargetCaptureClause parentCapClause, - bool parentAutomap, mlir::Operation *currOp, + struct ParentInfo { + mlir::omp::DeclareTargetDeviceType devTy; + mlir::omp::DeclareTargetCaptureClause capClause; + bool automap; + }; + + void processSymbolRef(mlir::SymbolRefAttr symRef, ParentInfo parentInfo, + llvm::SmallPtrSet<mlir::Operation *, 16> visited) { + if (auto 
currFOp = + getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) { + auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( + currFOp.getOperation()); + + if (current.isDeclareTarget()) { + auto currentDt = current.getDeclareTargetDeviceType(); + + // Found the same function twice, with different device_types, + // mark as Any as it belongs to both + if (currentDt != parentInfo.devTy && + currentDt != mlir::omp::DeclareTargetDeviceType::any) { + current.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any, + current.getDeclareTargetCaptureClause(), + current.getDeclareTargetAutomap()); + } + } else { + current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause, + parentInfo.automap); + } + + markNestedFuncs(parentInfo, currFOp, visited); + } + } + + void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs, + ParentInfo parentInfo, + llvm::SmallPtrSet<mlir::Operation *, 16> visited) { + if (!symRefs) + return; + + for (auto symRef : symRefs->getAsRange<mlir::SymbolRefAttr>()) { + if (auto declareReductionOp = + getOperation().lookupSymbol<mlir::omp::DeclareReductionOp>( + symRef)) { + markNestedFuncs(parentInfo, declareReductionOp, visited); + } + } + } + + void + processReductionClauses(mlir::Operation *op, ParentInfo parentInfo, + llvm::SmallPtrSet<mlir::Operation *, 16> visited) { + llvm::TypeSwitch<mlir::Operation &>(*op) + .Case([&](mlir::omp::LoopOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::ParallelOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::SectionsOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::SimdOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::TargetOp op) { + processReductionRefs(op.getInReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::TaskgroupOp op) { + 
processReductionRefs(op.getTaskReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::TaskloopOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + processReductionRefs(op.getInReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::TaskOp op) { + processReductionRefs(op.getInReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::TeamsOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::WsloopOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Default([](mlir::Operation &) {}); + } + + void markNestedFuncs(ParentInfo parentInfo, mlir::Operation *currOp, llvm::SmallPtrSet<mlir::Operation *, 16> visited) { if (visited.contains(currOp)) return; @@ -43,33 +128,10 @@ class MarkDeclareTargetPass if (auto callOp = llvm::dyn_cast<mlir::CallOpInterface>(op)) { if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>( callOp.getCallableForCallee())) { - if (auto currFOp = - getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) { - auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( - currFOp.getOperation()); - - if (current.isDeclareTarget()) { - auto currentDt = current.getDeclareTargetDeviceType(); - - // Found the same function twice, with different device_types, - // mark as Any as it belongs to both - if (currentDt != parentDevTy && - currentDt != mlir::omp::DeclareTargetDeviceType::any) { - current.setDeclareTarget( - mlir::omp::DeclareTargetDeviceType::any, - current.getDeclareTargetCaptureClause(), - current.getDeclareTargetAutomap()); - } - } else { - current.setDeclareTarget(parentDevTy, parentCapClause, - parentAutomap); - } - - markNestedFuncs(parentDevTy, parentCapClause, parentAutomap, - currFOp, visited); - } + processSymbolRef(symRef, parentInfo, visited); } } + processReductionClauses(op, parentInfo, visited); }); } @@ -82,10 +144,10 @@ class MarkDeclareTargetPass functionOp.getOperation()); 
if (declareTargetOp.isDeclareTarget()) { llvm::SmallPtrSet<mlir::Operation *, 16> visited; - markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(), - declareTargetOp.getDeclareTargetCaptureClause(), - declareTargetOp.getDeclareTargetAutomap(), functionOp, - visited); + ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(), + declareTargetOp.getDeclareTargetCaptureClause(), + declareTargetOp.getDeclareTargetAutomap()}; + markNestedFuncs(parentInfo, functionOp, visited); } } @@ -96,12 +158,13 @@ class MarkDeclareTargetPass // the contents of the device clause getOperation()->walk([&](mlir::omp::TargetOp tarOp) { llvm::SmallPtrSet<mlir::Operation *, 16> visited; - markNestedFuncs( - /*parentDevTy=*/mlir::omp::DeclareTargetDeviceType::nohost, - /*parentCapClause=*/mlir::omp::DeclareTargetCaptureClause::to, - /*parentAutomap=*/false, tarOp, visited); + ParentInfo parentInfo = { + /*devTy=*/mlir::omp::DeclareTargetDeviceType::nohost, + /*capClause=*/mlir::omp::DeclareTargetCaptureClause::to, + /*automap=*/false, + }; + markNestedFuncs(parentInfo, tarOp, visited); }); } }; - } // namespace diff --git a/flang/lib/Optimizer/OpenMP/Support/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/Support/CMakeLists.txt index dee35e4..004753d 100644 --- a/flang/lib/Optimizer/OpenMP/Support/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenMP/Support/CMakeLists.txt @@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) add_flang_library(FIROpenMPSupport FIROpenMPAttributes.cpp + FIROpenMPOpsInterfaces.cpp RegisterOpenMPExtensions.cpp DEPENDS diff --git a/flang/lib/Optimizer/OpenMP/Support/FIROpenMPOpsInterfaces.cpp b/flang/lib/Optimizer/OpenMP/Support/FIROpenMPOpsInterfaces.cpp new file mode 100644 index 0000000..a396ef0 --- /dev/null +++ b/flang/lib/Optimizer/OpenMP/Support/FIROpenMPOpsInterfaces.cpp @@ -0,0 +1,102 @@ +//===-- FIROpenMPOpsInterfaces.cpp ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache 
License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file implements FIR operation interfaces, which may be attached +/// to OpenMP dialect operations. +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/FIROperationMoveOpInterface.h" +#include "flang/Optimizer/OpenMP/Support/RegisterOpenMPExtensions.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" + +namespace { +/// Helper template that must be specialized for each operation. +/// The methods are declared just for documentation. +template <typename OP, typename Enable = void> +struct OperationMoveModel { + // Returns true if it is allowed to move the given 'candidate' + // operation from the 'descendant' operation into operation 'op'. + // If 'candidate' is nullptr, then the caller is querying whether + // any operation from any descendant can be moved into 'op' operation. + bool canMoveFromDescendant(mlir::Operation *op, mlir::Operation *descendant, + mlir::Operation *candidate) const; + + // Returns true if it is allowed to move the given 'candidate' + // operation out of operation 'op'. If 'candidate' is nullptr, + // then the caller is querying whether any operation can be moved + // out of 'op' operation. + bool canMoveOutOf(mlir::Operation *op, mlir::Operation *candidate) const; +}; + +// Helpers to check if T is one of Ts. +template <typename T, typename... Ts> +struct is_any_type : std::disjunction<std::is_same<T, Ts>...> {}; + +template <typename T, typename... Ts> +struct is_any_omp_op + : std::integral_constant< + bool, is_any_type<typename std::remove_cv<T>::type, Ts...>::value> {}; + +template <typename T, typename... 
Ts> +constexpr bool is_any_omp_op_v = is_any_omp_op<T, Ts...>::value; + +/// OperationMoveModel specialization for OMP_LOOP_WRAPPER_OPS. +template <typename OP> +struct OperationMoveModel< + OP, + typename std::enable_if<is_any_omp_op_v<OP, OMP_LOOP_WRAPPER_OPS>>::type> + : public fir::OperationMoveOpInterface::ExternalModel< + OperationMoveModel<OP>, OP> { + bool canMoveFromDescendant(mlir::Operation *op, mlir::Operation *descendant, + mlir::Operation *candidate) const { + // Operations cannot be moved from descendants of LoopWrapperInterface + // operation into the LoopWrapperInterface operation. + return false; + } + bool canMoveOutOf(mlir::Operation *op, mlir::Operation *candidate) const { + // The LoopWrapperInterface operations are only supposed to contain + // a loop operation, and it is probably okay to move operations + // from the descendant loop operation out of the LoopWrapperInterface + // operation. For now, return false to be conservative. + return false; + } +}; + +/// OperationMoveModel specialization for OMP_OUTLINEABLE_OPS. +template <typename OP> +struct OperationMoveModel< + OP, typename std::enable_if<is_any_omp_op_v<OP, OMP_OUTLINEABLE_OPS>>::type> + : public fir::OperationMoveOpInterface::ExternalModel< + OperationMoveModel<OP>, OP> { + bool canMoveFromDescendant(mlir::Operation *op, mlir::Operation *descendant, + mlir::Operation *candidate) const { + // Operations can be moved from descendants of OutlineableOpenMPOpInterface + // operation into the OutlineableOpenMPOpInterface operation. + return true; + } + bool canMoveOutOf(mlir::Operation *op, mlir::Operation *candidate) const { + // Operations cannot be moved out of OutlineableOpenMPOpInterface operation. + return false; + } +}; + +// Helper to call attachInterface<OperationMoveModel> for all Ts +// (types of operations). +template <typename... 
Ts> +void attachInterfaces(mlir::MLIRContext *ctx) { + (Ts::template attachInterface<OperationMoveModel<Ts>>(*ctx), ...); +} +} // anonymous namespace + +void fir::omp::registerOpInterfacesExtensions(mlir::DialectRegistry ®istry) { + registry.addExtension( + +[](mlir::MLIRContext *ctx, mlir::omp::OpenMPDialect *dialect) { + attachInterfaces<OMP_LOOP_WRAPPER_OPS>(ctx); + attachInterfaces<OMP_OUTLINEABLE_OPS>(ctx); + }); +} diff --git a/flang/lib/Optimizer/OpenMP/Support/RegisterOpenMPExtensions.cpp b/flang/lib/Optimizer/OpenMP/Support/RegisterOpenMPExtensions.cpp index 2495d54..de4906e 100644 --- a/flang/lib/Optimizer/OpenMP/Support/RegisterOpenMPExtensions.cpp +++ b/flang/lib/Optimizer/OpenMP/Support/RegisterOpenMPExtensions.cpp @@ -15,6 +15,7 @@ namespace fir::omp { void registerOpenMPExtensions(mlir::DialectRegistry ®istry) { registerAttrsExtensions(registry); + registerOpInterfacesExtensions(registry); } } // namespace fir::omp diff --git a/flang/lib/Optimizer/Passes/CommandLineOpts.cpp b/flang/lib/Optimizer/Passes/CommandLineOpts.cpp index 0142375..75e818d 100644 --- a/flang/lib/Optimizer/Passes/CommandLineOpts.cpp +++ b/flang/lib/Optimizer/Passes/CommandLineOpts.cpp @@ -61,6 +61,7 @@ cl::opt<bool> useOldAliasTags( cl::desc("Use a single TBAA tree for all functions and do not use " "the FIR alias tags pass"), cl::init(false), cl::Hidden); +EnableOption(FirLICM, "fir-licm", "FIR loop invariant code motion"); /// CodeGen Passes DisableOption(CodeGenRewrite, "codegen-rewrite", "rewrite FIR for codegen"); diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 6dae39b..18ad22f 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -20,8 +20,9 @@ namespace fir { template <typename F> void addNestedPassToAllTopLevelOperations(mlir::PassManager &pm, F ctor) { - addNestedPassToOps<F, mlir::func::FuncOp, mlir::omp::DeclareReductionOp, - mlir::omp::PrivateClauseOp, 
fir::GlobalOp>(pm, ctor); + addNestedPassToOps<F, mlir::func::FuncOp, mlir::omp::DeclareMapperOp, + mlir::omp::DeclareReductionOp, mlir::omp::PrivateClauseOp, + fir::GlobalOp>(pm, ctor); } template <typename F> @@ -107,8 +108,8 @@ void addDebugInfoPass(mlir::PassManager &pm, [&]() { return fir::createAddDebugInfoPass(options); }); } -void addFIRToLLVMPass(mlir::PassManager &pm, - const MLIRToLLVMPassPipelineConfig &config) { +fir::FIRToLLVMPassOptions +getFIRToLLVMPassOptions(const MLIRToLLVMPassPipelineConfig &config) { fir::FIRToLLVMPassOptions options; options.ignoreMissingTypeDescriptors = ignoreMissingTypeDescriptors; options.skipExternalRttiDefinition = skipExternalRttiDefinition; @@ -117,6 +118,12 @@ void addFIRToLLVMPass(mlir::PassManager &pm, options.typeDescriptorsRenamedForAssembly = !disableCompilerGeneratedNamesConversion; options.ComplexRange = config.ComplexRange; + return options; +} + +void addFIRToLLVMPass(mlir::PassManager &pm, + const MLIRToLLVMPassPipelineConfig &config) { + fir::FIRToLLVMPassOptions options = getFIRToLLVMPassOptions(config); addPassConditionally(pm, disableFirToLlvmIr, [&]() { return fir::createFIRToLLVMPass(options); }); // The dialect conversion framework may leave dead unrealized_conversion_cast @@ -206,6 +213,10 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm, pm.addPass(fir::createSimplifyRegionLite()); pm.addPass(mlir::createCSEPass()); + // Run LICM after CSE, which may reduce the number of operations to hoist. 
+ if (enableFirLICM && pc.OptLevel.isOptimizingForSpeed()) + pm.addPass(fir::createLoopInvariantCodeMotion()); + // Polymorphic types pm.addPass(fir::createPolymorphicOpConversion()); pm.addPass(fir::createAssumedRankOpConversion()); @@ -279,7 +290,8 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, pm, hlfir::createInlineHLFIRCopyIn); } } - pm.addPass(hlfir::createLowerHLFIROrderedAssignments()); + pm.addPass(hlfir::createLowerHLFIROrderedAssignments( + {/*tryFusingAssignments=*/optLevel.isOptimizingForSpeed()})); pm.addPass(hlfir::createLowerHLFIRIntrinsics()); hlfir::BufferizeHLFIROptions bufferizeOptions; @@ -372,7 +384,7 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm, fir::addCompilerGeneratedNamesConversionPass(pm); if (config.VScaleMin != 0) - pm.addPass(fir::createVScaleAttr({{config.VScaleMin, config.VScaleMax}})); + pm.addPass(fir::createVScaleAttr({config.VScaleMin, config.VScaleMax})); // Add function attributes mlir::LLVM::framePointerKind::FramePointerKind framePointerKind; @@ -383,6 +395,9 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm, framePointerKind = mlir::LLVM::framePointerKind::FramePointerKind::All; else if (config.FramePointerKind == llvm::FramePointerKind::Reserved) framePointerKind = mlir::LLVM::framePointerKind::FramePointerKind::Reserved; + else if (config.FramePointerKind == llvm::FramePointerKind::NonLeafNoReserve) + framePointerKind = + mlir::LLVM::framePointerKind::FramePointerKind::NonLeafNoReserve; else framePointerKind = mlir::LLVM::framePointerKind::FramePointerKind::None; @@ -426,6 +441,12 @@ void createMLIRToLLVMPassPipeline(mlir::PassManager &pm, // Add codegen pass pipeline. fir::createDefaultFIRCodeGenPassPipeline(pm, config, inputFilename); + + // Run a pass to prepare for translation of delayed privatization in the + // context of deferred target tasks. 
+ addPassConditionally(pm, disableFirToLlvmIr, [&]() { + return mlir::omp::createPrepareForOMPOffloadPrivatizationPass(); + }); } } // namespace fir diff --git a/flang/lib/Optimizer/Support/CMakeLists.txt b/flang/lib/Optimizer/Support/CMakeLists.txt index 38038e1..6f3652b 100644 --- a/flang/lib/Optimizer/Support/CMakeLists.txt +++ b/flang/lib/Optimizer/Support/CMakeLists.txt @@ -7,9 +7,11 @@ add_flang_library(FIRSupport DEPENDS FIROpsIncGen HLFIROpsIncGen + MIFOpsIncGen LINK_LIBS FIRDialect + MIFDialect LINK_COMPONENTS TargetParser diff --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp index 92390e4a..2f33d89 100644 --- a/flang/lib/Optimizer/Support/Utils.cpp +++ b/flang/lib/Optimizer/Support/Utils.cpp @@ -66,7 +66,7 @@ fir::genConstantIndex(mlir::Location loc, mlir::Type ity, mlir::ConversionPatternRewriter &rewriter, std::int64_t offset) { auto cattr = rewriter.getI64IntegerAttr(offset); - return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr); + return mlir::LLVM::ConstantOp::create(rewriter, loc, ity, cattr); } mlir::Value @@ -125,9 +125,9 @@ mlir::Value fir::integerCast(const fir::LLVMTypeConverter &converter, return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val); } else { if (toSize < fromSize) - return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val); + return mlir::LLVM::TruncOp::create(rewriter, loc, ty, val); if (toSize > fromSize) - return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val); + return mlir::LLVM::SExtOp::create(rewriter, loc, ty, val); } return val; } diff --git a/flang/lib/Optimizer/Transforms/AddAliasTags.cpp b/flang/lib/Optimizer/Transforms/AddAliasTags.cpp index 0221c7a..142e4c8 100644 --- a/flang/lib/Optimizer/Transforms/AddAliasTags.cpp +++ b/flang/lib/Optimizer/Transforms/AddAliasTags.cpp @@ -60,6 +60,9 @@ static llvm::cl::opt<unsigned> localAllocsThreshold( llvm::cl::desc("If present, stops generating TBAA tags for accesses of " "local allocations after N accesses in a 
module")); +// Defined in AliasAnalysis.cpp +extern llvm::cl::opt<bool> supportCrayPointers; + namespace { // Return the size and alignment (in bytes) for the given type. @@ -210,10 +213,7 @@ public: void processFunctionScopes(mlir::func::FuncOp func); // For the given fir.declare returns the dominating fir.dummy_scope // operation. - fir::DummyScopeOp getDeclarationScope(fir::DeclareOp declareOp) const; - // For the given fir.declare returns the outermost fir.dummy_scope - // in the current function. - fir::DummyScopeOp getOutermostScope(fir::DeclareOp declareOp) const; + fir::DummyScopeOp getDeclarationScope(fir::DeclareOp declareOp); // Returns true, if the given type of a memref of a FirAliasTagOpInterface // operation is a descriptor or contains a descriptor // (e.g. !fir.ref<!fir.type<Derived{f:!fir.box<!fir.heap<f32>>}>>). @@ -353,8 +353,9 @@ void PassState::processFunctionScopes(mlir::func::FuncOp func) { } } -fir::DummyScopeOp -PassState::getDeclarationScope(fir::DeclareOp declareOp) const { +// For the given fir.declare returns the dominating fir.dummy_scope +// operation. 
+fir::DummyScopeOp PassState::getDeclarationScope(fir::DeclareOp declareOp) { auto func = declareOp->getParentOfType<mlir::func::FuncOp>(); assert(func && "fir.declare does not have parent func.func"); auto &scopeOps = sortedScopeOperations.at(func); @@ -365,15 +366,6 @@ PassState::getDeclarationScope(fir::DeclareOp declareOp) const { return nullptr; } -fir::DummyScopeOp PassState::getOutermostScope(fir::DeclareOp declareOp) const { - auto func = declareOp->getParentOfType<mlir::func::FuncOp>(); - assert(func && "fir.declare does not have parent func.func"); - auto &scopeOps = sortedScopeOperations.at(func); - if (!scopeOps.empty()) - return scopeOps[0]; - return nullptr; -} - bool PassState::typeReferencesDescriptor(mlir::Type type) { type = fir::unwrapAllRefAndSeqType(type); if (mlir::isa<fir::BaseBoxType>(type)) @@ -668,6 +660,7 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op, LLVM_DEBUG(llvm::dbgs() << "Analysing " << op << "\n"); const fir::AliasAnalysis::Source &source = state.getSource(memref); + LLVM_DEBUG(llvm::dbgs() << "Got source " << source << "\n"); // Process the scopes, if not processed yet. state.processFunctionScopes(func); @@ -686,14 +679,22 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op, } mlir::LLVM::TBAATagAttr tag; - // TBAA for dummy arguments - if (enableDummyArgs && - source.kind == fir::AliasAnalysis::SourceKind::Argument) { + // Cray pointer/pointee is a special case. These might alias with any data. 
+ if (supportCrayPointers && source.isCrayPointerOrPointee()) { + LLVM_DEBUG(llvm::dbgs().indent(2) + << "Found reference to Cray pointer/pointee at " << *op << "\n"); + mlir::LLVM::TBAATypeDescriptorAttr anyDataDesc = + state.getFuncTreeWithScope(func, scopeOp).anyDataTypeDesc; + tag = mlir::LLVM::TBAATagAttr::get(anyDataDesc, anyDataDesc, /*offset=*/0); + // TBAA for dummy arguments + } else if (enableDummyArgs && + source.kind == fir::AliasAnalysis::SourceKind::Argument) { LLVM_DEBUG(llvm::dbgs().indent(2) << "Found reference to dummy argument at " << *op << "\n"); std::string name = getFuncArgName(llvm::cast<mlir::Value>(source.origin.u)); - // If it is a TARGET or POINTER, then we do not care about the name, - // because the tag points to the root of the subtree currently. + // POINTERS can alias with any POINTER or TARGET. Assume that TARGET dummy + // arguments might alias with each other (because of the "TARGET" hole for + // dummy arguments). See flang/docs/Aliasing.md. if (source.isTargetOrPointer()) { tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag(); } else if (!name.empty()) { @@ -715,13 +716,10 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op, LLVM_DEBUG(llvm::dbgs().indent(2) << "Found reference to global " << globalName.str() << " at " << *op << "\n"); - if (source.isPointer()) { - tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag(); - } else { - // In general, place the tags under the "global data" root. 
- fir::TBAATree::SubtreeState *subTree = - &state.getMutableFuncTreeWithScope(func, scopeOp).globalDataTree; + // Add a named tag inside the given subtree, disambiguating members of a + // common block + auto addTagUsingStorageDesc = [&](fir::TBAATree::SubtreeState *subTree) { mlir::Operation *instantiationPoint = source.origin.instantiationPoint; auto storageIface = mlir::dyn_cast_or_null<fir::FortranVariableStorageOpInterface>( @@ -766,6 +764,19 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op, LLVM_DEBUG(llvm::dbgs() << "Tagged under '" << globalName << "' root\n"); } + }; + + if (source.isPointer()) { + // Pointers can alias with any pointer or target. + tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag(); + } else if (source.isTarget()) { + // Targets could alias with any pointer but not with each other. + addTagUsingStorageDesc( + &state.getMutableFuncTreeWithScope(func, scopeOp).targetDataTree); + } else { + // In general, place the tags under the "global data" root. + addTagUsingStorageDesc( + &state.getMutableFuncTreeWithScope(func, scopeOp).globalDataTree); } // TBAA for global variables with descriptors @@ -776,9 +787,17 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op, const char *name = glbl.getRootReference().data(); LLVM_DEBUG(llvm::dbgs().indent(2) << "Found reference to direct " << name << " at " << *op << "\n"); + // Pointer can alias with any pointer or target so that gets the root. if (source.isPointer()) tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag(); + // Targets could alias with any pointer but not with each other so they + // get their own node inside of the target data tree. + else if (source.isTarget()) + tag = state.getFuncTreeWithScope(func, scopeOp) + .targetDataTree.getTag(name); else + // Boxes that are not pointers or targets cannot alias with those that + // are. Put them under global data. 
tag = state.getFuncTreeWithScope(func, scopeOp) .directDataTree.getTag(name); } else { @@ -800,22 +819,23 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op, else unknownAllocOp = true; - if (auto declOp = source.origin.instantiationPoint) { - // Use the outermost scope for local allocations, - // because using the innermost scope may result - // in incorrect TBAA, when calls are inlined in MLIR. - auto declareOp = mlir::dyn_cast<fir::DeclareOp>(declOp); - assert(declareOp && "Instantiation point must be fir.declare"); - scopeOp = state.getOutermostScope(declareOp); - } - if (unknownAllocOp) { LLVM_DEBUG(llvm::dbgs().indent(2) << "WARN: unknown defining op for SourceKind::Allocate " << *op << "\n"); } else if (source.isPointer() && state.attachLocalAllocTag()) { LLVM_DEBUG(llvm::dbgs().indent(2) - << "Found reference to allocation at " << *op << "\n"); + << "Found reference to POINTER allocation at " << *op << "\n"); + tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag(); + } else if (name && source.isTarget() && state.attachLocalAllocTag()) { + LLVM_DEBUG(llvm::dbgs().indent(2) + << "Found reference to TARGET allocation at " << *op << "\n"); + tag = state.getFuncTreeWithScope(func, scopeOp) + .targetDataTree.getTag(*name); + } else if (source.isTarget() && state.attachLocalAllocTag()) { + LLVM_DEBUG(llvm::dbgs().indent(2) + << "WARN: couldn't find a name for TARGET allocation " << *op + << "\n"); tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag(); } else if (name && state.attachLocalAllocTag()) { LLVM_DEBUG(llvm::dbgs().indent(2) << "Found reference to allocation " diff --git a/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp b/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp index e006d2e..35d8a2f6 100644 --- a/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp +++ b/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp @@ -53,7 +53,7 @@ class AddDebugInfoPass : public 
fir::impl::AddDebugInfoBase<AddDebugInfoPass> { mlir::LLVM::DIFileAttr fileAttr, mlir::LLVM::DIScopeAttr scopeAttr, fir::DebugTypeGenerator &typeGen, - mlir::SymbolTable *symbolTable); + mlir::SymbolTable *symbolTable, mlir::Value dummyScope); public: AddDebugInfoPass(fir::AddDebugInfoOptions options) : Base(options) {} @@ -84,6 +84,24 @@ private: mlir::LLVM::DICompileUnitAttr cuAttr, fir::DebugTypeGenerator &typeGen, mlir::SymbolTable *symbolTable); + void handleOnlyClause( + fir::UseStmtOp useOp, mlir::LLVM::DISubprogramAttr spAttr, + mlir::LLVM::DIFileAttr fileAttr, mlir::SymbolTable *symbolTable, + llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedModules); + void handleRenamesWithoutOnly( + fir::UseStmtOp useOp, mlir::LLVM::DISubprogramAttr spAttr, + mlir::LLVM::DIModuleAttr modAttr, mlir::LLVM::DIFileAttr fileAttr, + mlir::SymbolTable *symbolTable, + llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedModules); + void handleUseStatements( + mlir::func::FuncOp funcOp, mlir::LLVM::DISubprogramAttr spAttr, + mlir::LLVM::DIFileAttr fileAttr, mlir::LLVM::DICompileUnitAttr cuAttr, + mlir::SymbolTable *symbolTable, + llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedEntities); + std::optional<mlir::LLVM::DIImportedEntityAttr> createImportedDeclForGlobal( + llvm::StringRef symbolName, mlir::LLVM::DISubprogramAttr spAttr, + mlir::LLVM::DIFileAttr fileAttr, mlir::StringAttr localNameAttr, + mlir::SymbolTable *symbolTable); bool createCommonBlockGlobal(fir::cg::XDeclareOp declOp, const std::string &name, mlir::LLVM::DIFileAttr fileAttr, @@ -138,75 +156,122 @@ mlir::StringAttr getTargetFunctionName(mlir::MLIRContext *context, } // namespace +// Check if a global represents a module variable +static bool isModuleVariable(fir::GlobalOp globalOp) { + std::pair result = fir::NameUniquer::deconstruct(globalOp.getSymName()); + return result.first == fir::NameUniquer::NameKind::VARIABLE && + result.second.procs.empty() && !result.second.modules.empty(); +} 
+ +// Look up DIGlobalVariable from a global symbol +static std::optional<mlir::LLVM::DIGlobalVariableAttr> +lookupDIGlobalVariable(llvm::StringRef symbolName, + mlir::SymbolTable *symbolTable) { + if (auto globalOp = symbolTable->lookup<fir::GlobalOp>(symbolName)) { + if (auto fusedLoc = mlir::dyn_cast<mlir::FusedLoc>(globalOp.getLoc())) { + if (auto metadata = fusedLoc.getMetadata()) { + if (auto arrayAttr = mlir::dyn_cast<mlir::ArrayAttr>(metadata)) { + for (auto elem : arrayAttr) { + if (auto gvExpr = + mlir::dyn_cast<mlir::LLVM::DIGlobalVariableExpressionAttr>( + elem)) + return gvExpr.getVar(); + } + } + } + } + } + return std::nullopt; +} + bool AddDebugInfoPass::createCommonBlockGlobal( fir::cg::XDeclareOp declOp, const std::string &name, mlir::LLVM::DIFileAttr fileAttr, mlir::LLVM::DIScopeAttr scopeAttr, fir::DebugTypeGenerator &typeGen, mlir::SymbolTable *symbolTable) { mlir::MLIRContext *context = &getContext(); mlir::OpBuilder builder(context); - std::optional<std::int64_t> optint; - mlir::Operation *op = declOp.getMemref().getDefiningOp(); - - if (auto conOp = mlir::dyn_cast_if_present<fir::ConvertOp>(op)) - op = conOp.getValue().getDefiningOp(); - if (auto cordOp = mlir::dyn_cast_if_present<fir::CoordinateOp>(op)) { - auto coors = cordOp.getCoor(); - if (coors.size() != 1) - return false; - optint = fir::getIntIfConstant(coors[0]); - if (!optint) - return false; - op = cordOp.getRef().getDefiningOp(); - if (auto conOp2 = mlir::dyn_cast_if_present<fir::ConvertOp>(op)) - op = conOp2.getValue().getDefiningOp(); - - if (auto addrOfOp = mlir::dyn_cast_if_present<fir::AddrOfOp>(op)) { - mlir::SymbolRefAttr sym = addrOfOp.getSymbol(); - if (auto global = - symbolTable->lookup<fir::GlobalOp>(sym.getRootReference())) { - - unsigned line = getLineFromLoc(global.getLoc()); - llvm::StringRef commonName(sym.getRootReference()); - // FIXME: We are trying to extract the name of the common block from the - // name of the global. 
As part of mangling, GetCommonBlockObjectName can - // add a trailing _ in the name of that global. The demangle function - // does not seem to handle such cases. So the following hack is used to - // remove the trailing '_'. - if (commonName != Fortran::common::blankCommonObjectName && - commonName.back() == '_') - commonName = commonName.drop_back(); - mlir::LLVM::DICommonBlockAttr commonBlock = - getOrCreateCommonBlockAttr(commonName, fileAttr, scopeAttr, line); - mlir::LLVM::DITypeAttr diType = typeGen.convertType( - fir::unwrapRefType(declOp.getType()), fileAttr, scopeAttr, declOp); - line = getLineFromLoc(declOp.getLoc()); - auto gvAttr = mlir::LLVM::DIGlobalVariableAttr::get( - context, commonBlock, mlir::StringAttr::get(context, name), - declOp.getUniqName(), fileAttr, line, diType, - /*isLocalToUnit*/ false, /*isDefinition*/ true, /* alignInBits*/ 0); - mlir::LLVM::DIExpressionAttr expr; - if (*optint != 0) { - llvm::SmallVector<mlir::LLVM::DIExpressionElemAttr> ops; - ops.push_back(mlir::LLVM::DIExpressionElemAttr::get( - context, llvm::dwarf::DW_OP_plus_uconst, *optint)); - expr = mlir::LLVM::DIExpressionAttr::get(context, ops); - } - auto dbgExpr = mlir::LLVM::DIGlobalVariableExpressionAttr::get( - global.getContext(), gvAttr, expr); - globalToGlobalExprsMap[global].push_back(dbgExpr); - return true; - } - } + std::optional<std::int64_t> offset; + mlir::Value storage = declOp.getStorage(); + if (!storage) + return false; + + // Extract offset from storage_offset attribute + uint64_t storageOffset = declOp.getStorageOffset(); + if (storageOffset != 0) + offset = static_cast<std::int64_t>(storageOffset); + + // Get the GlobalOp from the storage value. + // The storage may be wrapped in ConvertOp, so unwrap it first. 
+ mlir::Operation *storageOp = storage.getDefiningOp(); + if (auto convertOp = mlir::dyn_cast_if_present<fir::ConvertOp>(storageOp)) + storageOp = convertOp.getValue().getDefiningOp(); + + auto addrOfOp = mlir::dyn_cast_if_present<fir::AddrOfOp>(storageOp); + if (!addrOfOp) + return false; + + mlir::SymbolRefAttr sym = addrOfOp.getSymbol(); + fir::GlobalOp global = + symbolTable->lookup<fir::GlobalOp>(sym.getRootReference()); + if (!global) + return false; + + // Check if the global is actually a common block by demangling its name. + // Module EQUIVALENCE variables also use storage operands but are mangled + // as VARIABLE type, so we reject them to avoid treating them as common + // blocks. + llvm::StringRef globalSymbol = sym.getRootReference(); + auto globalResult = fir::NameUniquer::deconstruct(globalSymbol); + if (globalResult.first == fir::NameUniquer::NameKind::VARIABLE) + return false; + + // FIXME: We are trying to extract the name of the common block from the + // name of the global. As part of mangling, GetCommonBlockObjectName can + // add a trailing _ in the name of that global. The demangle function + // does not seem to handle such cases. So the following hack is used to + // remove the trailing '_'. + llvm::StringRef commonName = globalSymbol; + if (commonName != Fortran::common::blankCommonObjectName && + !commonName.empty() && commonName.back() == '_') + commonName = commonName.drop_back(); + + // Create the debug attributes. 
+ unsigned line = getLineFromLoc(global.getLoc()); + mlir::LLVM::DICommonBlockAttr commonBlock = + getOrCreateCommonBlockAttr(commonName, fileAttr, scopeAttr, line); + + mlir::LLVM::DITypeAttr diType = typeGen.convertType( + fir::unwrapRefType(declOp.getType()), fileAttr, scopeAttr, declOp); + + line = getLineFromLoc(declOp.getLoc()); + auto gvAttr = mlir::LLVM::DIGlobalVariableAttr::get( + context, commonBlock, mlir::StringAttr::get(context, name), + declOp.getUniqName(), fileAttr, line, diType, + /*isLocalToUnit*/ false, /*isDefinition*/ true, /* alignInBits*/ 0); + + // Create DIExpression for offset if needed + mlir::LLVM::DIExpressionAttr expr; + if (offset && *offset != 0) { + llvm::SmallVector<mlir::LLVM::DIExpressionElemAttr> ops; + ops.push_back(mlir::LLVM::DIExpressionElemAttr::get( + context, llvm::dwarf::DW_OP_plus_uconst, *offset)); + expr = mlir::LLVM::DIExpressionAttr::get(context, ops); } - return false; + + auto dbgExpr = mlir::LLVM::DIGlobalVariableExpressionAttr::get( + global.getContext(), gvAttr, expr); + globalToGlobalExprsMap[global].push_back(dbgExpr); + + return true; } void AddDebugInfoPass::handleDeclareOp(fir::cg::XDeclareOp declOp, mlir::LLVM::DIFileAttr fileAttr, mlir::LLVM::DIScopeAttr scopeAttr, fir::DebugTypeGenerator &typeGen, - mlir::SymbolTable *symbolTable) { + mlir::SymbolTable *symbolTable, + mlir::Value dummyScope) { mlir::MLIRContext *context = &getContext(); mlir::OpBuilder builder(context); auto result = fir::NameUniquer::deconstruct(declOp.getUniqName()); @@ -228,24 +293,11 @@ void AddDebugInfoPass::handleDeclareOp(fir::cg::XDeclareOp declOp, } } - // FIXME: There may be cases where an argument is processed a bit before - // DeclareOp is generated. In that case, DeclareOp may point to an - // intermediate op and not to BlockArgument. - // Moreover, with MLIR inlining we cannot use the BlockArgument - // position to identify the original number of the dummy argument. 
- // If we want to keep running AddDebugInfoPass late, the dummy argument - // position in the argument list has to be expressed in FIR (e.g. as a - // constant attribute of [hl]fir.declare/fircg.ext_declare operation that has - // a dummy_scope operand). + // Get the dummy argument position from the explicit attribute. unsigned argNo = 0; - if (declOp.getDummyScope()) { - if (auto arg = llvm::dyn_cast<mlir::BlockArgument>(declOp.getMemref())) { - // Check if it is the BlockArgument of the function's entry block. - if (auto funcLikeOp = - declOp->getParentOfType<mlir::FunctionOpInterface>()) - if (arg.getOwner() == &funcLikeOp.front()) - argNo = arg.getArgNumber() + 1; - } + if (dummyScope && declOp.getDummyScope() == dummyScope) { + if (auto argNoOpt = declOp.getDummyArgNo()) + argNo = *argNoOpt; } auto tyAttr = typeGen.convertType(fir::unwrapRefType(declOp.getType()), @@ -520,7 +572,7 @@ void AddDebugInfoPass::handleFuncOp(mlir::func::FuncOp funcOp, CC = llvm::dwarf::getCallingConvention("DW_CC_normal"); mlir::LLVM::DISubroutineTypeAttr spTy = mlir::LLVM::DISubroutineTypeAttr::get(context, CC, types); - if (lineTableOnly) { + if (lineTableOnly || entities.empty()) { auto spAttr = mlir::LLVM::DISubprogramAttr::get( context, id, compilationUnit, Scope, name, name, funcFileAttr, line, line, flags, spTy, /*retainedNodes=*/{}, /*annotations=*/{}); @@ -540,9 +592,9 @@ void AddDebugInfoPass::handleFuncOp(mlir::func::FuncOp funcOp, for (mlir::LLVM::DINodeAttr N : entities) { if (auto entity = mlir::dyn_cast<mlir::LLVM::DIImportedEntityAttr>(N)) { auto importedEntity = mlir::LLVM::DIImportedEntityAttr::get( - context, llvm::dwarf::DW_TAG_imported_module, spAttr, - entity.getEntity(), fileAttr, /*line=*/1, /*name=*/nullptr, - /*elements*/ {}); + context, entity.getTag(), spAttr, entity.getEntity(), + entity.getFile(), entity.getLine(), entity.getName(), + entity.getElements()); opEntities.push_back(importedEntity); } } @@ -567,61 +619,72 @@ void 
AddDebugInfoPass::handleFuncOp(mlir::func::FuncOp funcOp, return; } - mlir::DistinctAttr recId = - mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); - - // The debug attribute in MLIR are readonly once created. But in case of - // imported entities, we have a circular dependency. The - // DIImportedEntityAttr requires scope information (DISubprogramAttr in this - // case) and DISubprogramAttr requires the list of imported entities. The - // MLIR provides a way where a DISubprogramAttr an be created with a certain - // recID and be used in places like DIImportedEntityAttr. After that another - // DISubprogramAttr can be created with same recID but with list of entities - // now available. The MLIR translation code takes care of updating the - // references. Note that references will be updated only in the things that - // are part of DISubprogramAttr (like DIImportedEntityAttr) so we have to - // create the final DISubprogramAttr before we process local variables. - // Look at DIRecursiveTypeAttrInterface for more details. - - auto spAttr = mlir::LLVM::DISubprogramAttr::get( - context, recId, /*isRecSelf=*/true, id, compilationUnit, Scope, funcName, - fullName, funcFileAttr, line, line, subprogramFlags, subTypeAttr, - /*retainedNodes=*/{}, /*annotations=*/{}); - - // There is no direct information in the IR for any 'use' statement in the - // function. We have to extract that information from the DeclareOp. We do - // a pass on the DeclareOp and generate ModuleAttr and corresponding - // DIImportedEntityAttr for that module. - // FIXME: As we are depending on the variables to see which module is being - // 'used' in the function, there are certain limitations. - // For things like 'use mod1, only: v1', whole module will be brought into the - // namespace in the debug info. It is not a problem as such unless there is a - // clash of names. 
- // There is no information about module variable renaming - llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> importedModules; - funcOp.walk([&](fir::cg::XDeclareOp declOp) { - if (&funcOp.front() == declOp->getBlock()) - if (auto global = - symbolTable->lookup<fir::GlobalOp>(declOp.getUniqName())) { - std::optional<mlir::LLVM::DIModuleAttr> modOpt = - getModuleAttrFromGlobalOp(global, fileAttr, cuAttr); - if (modOpt) { - auto importedEntity = mlir::LLVM::DIImportedEntityAttr::get( - context, llvm::dwarf::DW_TAG_imported_module, spAttr, *modOpt, - fileAttr, /*line=*/1, /*name=*/nullptr, /*elements*/ {}); - importedModules.insert(importedEntity); - } - } + // Check if there are any USE statements + bool hasUseStmts = false; + funcOp.walk([&](fir::UseStmtOp useOp) { + hasUseStmts = true; + return mlir::WalkResult::interrupt(); }); - llvm::SmallVector<mlir::LLVM::DINodeAttr> entities(importedModules.begin(), - importedModules.end()); - // We have the imported entities now. Generate the final DISubprogramAttr. - spAttr = mlir::LLVM::DISubprogramAttr::get( - context, recId, /*isRecSelf=*/false, id2, compilationUnit, Scope, - funcName, fullName, funcFileAttr, line, line, subprogramFlags, - subTypeAttr, entities, /*annotations=*/{}); + + mlir::LLVM::DISubprogramAttr spAttr; + llvm::SmallVector<mlir::LLVM::DINodeAttr> retainedNodes; + + if (hasUseStmts) { + mlir::DistinctAttr recId = + mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + // The debug attribute in MLIR are readonly once created. But in case of + // imported entities, we have a circular dependency. The + // DIImportedEntityAttr requires scope information (DISubprogramAttr in this + // case) and DISubprogramAttr requires the list of imported entities. The + // MLIR provides a way where a DISubprogramAttr an be created with a certain + // recID and be used in places like DIImportedEntityAttr. 
After that another + // DISubprogramAttr can be created with same recID but with list of entities + // now available. The MLIR translation code takes care of updating the + // references. Note that references will be updated only in the things that + // are part of DISubprogramAttr (like DIImportedEntityAttr) so we have to + // create the final DISubprogramAttr before we process local variables. + // Look at DIRecursiveTypeAttrInterface for more details. + spAttr = mlir::LLVM::DISubprogramAttr::get( + context, recId, /*isRecSelf=*/true, id, compilationUnit, Scope, + funcName, fullName, funcFileAttr, line, line, subprogramFlags, + subTypeAttr, /*retainedNodes=*/{}, /*annotations=*/{}); + + // Process USE statements (module globals are already processed) + llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> importedEntities; + handleUseStatements(funcOp, spAttr, fileAttr, cuAttr, symbolTable, + importedEntities); + + retainedNodes.append(importedEntities.begin(), importedEntities.end()); + + // Create final DISubprogramAttr with imported entities and same recId + spAttr = mlir::LLVM::DISubprogramAttr::get( + context, recId, /*isRecSelf=*/false, id2, compilationUnit, Scope, + funcName, fullName, funcFileAttr, line, line, subprogramFlags, + subTypeAttr, retainedNodes, /*annotations=*/{}); + } else + // No USE statements - create final DISubprogramAttr directly + spAttr = mlir::LLVM::DISubprogramAttr::get( + context, id, compilationUnit, Scope, funcName, fullName, funcFileAttr, + line, line, subprogramFlags, subTypeAttr, /*retainedNodes=*/{}, + /*annotations=*/{}); + funcOp->setLoc(builder.getFusedLoc({l}, spAttr)); - addTargetOpDISP(/*lineTableOnly=*/false, entities); + addTargetOpDISP(/*lineTableOnly=*/false, retainedNodes); + + // Find the first dummy_scope definition. This is the one of the current + // function. The other ones may come from inlined calls. The variables inside + // those inlined calls should not be identified as arguments of the current + // function. 
+ mlir::Value dummyScope; + funcOp.walk([&](fir::UndefOp undef) -> mlir::WalkResult { + // TODO: delay fir.dummy_scope translation to undefined until + // codegeneration. This is nicer and safer to match. + if (llvm::isa<fir::DummyScopeType>(undef.getType())) { + dummyScope = undef; + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); funcOp.walk([&](fir::cg::XDeclareOp declOp) { mlir::LLVM::DISubprogramAttr spTy = spAttr; @@ -632,7 +695,7 @@ void AddDebugInfoPass::handleFuncOp(mlir::func::FuncOp funcOp, spTy = sp; } } - handleDeclareOp(declOp, fileAttr, spTy, typeGen, symbolTable); + handleDeclareOp(declOp, fileAttr, spTy, typeGen, symbolTable, dummyScope); }); // commonBlockMap ensures that we don't create multiple DICommonBlockAttr of // the same name in one function. But it is ok (rather required) to create @@ -641,6 +704,110 @@ void AddDebugInfoPass::handleFuncOp(mlir::func::FuncOp funcOp, commonBlockMap.clear(); } +// Helper function to create a DIImportedEntityAttr for an imported declaration. +// Looks up the DIGlobalVariable for the given symbol and creates an imported +// declaration with the optional local name (for renames). +// Returns std::nullopt if the symbol's DIGlobalVariable is not found. 
+std::optional<mlir::LLVM::DIImportedEntityAttr> +AddDebugInfoPass::createImportedDeclForGlobal( + llvm::StringRef symbolName, mlir::LLVM::DISubprogramAttr spAttr, + mlir::LLVM::DIFileAttr fileAttr, mlir::StringAttr localNameAttr, + mlir::SymbolTable *symbolTable) { + mlir::MLIRContext *context = &getContext(); + if (auto gvAttr = lookupDIGlobalVariable(symbolName, symbolTable)) { + return mlir::LLVM::DIImportedEntityAttr::get( + context, llvm::dwarf::DW_TAG_imported_declaration, spAttr, *gvAttr, + fileAttr, /*line=*/1, /*name=*/localNameAttr, /*elements*/ {}); + } + return std::nullopt; +} + +// Process USE with ONLY clause +void AddDebugInfoPass::handleOnlyClause( + fir::UseStmtOp useOp, mlir::LLVM::DISubprogramAttr spAttr, + mlir::LLVM::DIFileAttr fileAttr, mlir::SymbolTable *symbolTable, + llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedModules) { + // Process ONLY symbols (without renames) + if (auto onlySymbols = useOp.getOnlySymbols()) { + for (mlir::Attribute attr : *onlySymbols) { + auto symbolRef = mlir::cast<mlir::FlatSymbolRefAttr>(attr); + if (auto importedDecl = createImportedDeclForGlobal( + symbolRef.getValue(), spAttr, fileAttr, mlir::StringAttr(), + symbolTable)) + importedModules.insert(*importedDecl); + } + } + + // Process renames within ONLY clause + if (auto renames = useOp.getRenames()) { + for (auto attr : *renames) { + auto renameAttr = mlir::cast<fir::UseRenameAttr>(attr); + if (auto importedDecl = createImportedDeclForGlobal( + renameAttr.getSymbol().getValue(), spAttr, fileAttr, + renameAttr.getLocalName(), symbolTable)) + importedModules.insert(*importedDecl); + } + } +} + +// Process USE with renames but no ONLY clause +void AddDebugInfoPass::handleRenamesWithoutOnly( + fir::UseStmtOp useOp, mlir::LLVM::DISubprogramAttr spAttr, + mlir::LLVM::DIModuleAttr modAttr, mlir::LLVM::DIFileAttr fileAttr, + mlir::SymbolTable *symbolTable, + llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedModules) { + mlir::MLIRContext 
*context = &getContext(); + llvm::SmallVector<mlir::LLVM::DINodeAttr> childDeclarations; + + if (auto renames = useOp.getRenames()) { + for (auto attr : *renames) { + auto renameAttr = mlir::cast<fir::UseRenameAttr>(attr); + if (auto importedDecl = createImportedDeclForGlobal( + renameAttr.getSymbol().getValue(), spAttr, fileAttr, + renameAttr.getLocalName(), symbolTable)) + childDeclarations.push_back(*importedDecl); + } + } + + // Create module import with renamed declarations as children + auto moduleImport = mlir::LLVM::DIImportedEntityAttr::get( + context, llvm::dwarf::DW_TAG_imported_module, spAttr, modAttr, fileAttr, + /*line=*/1, /*name=*/nullptr, childDeclarations); + importedModules.insert(moduleImport); +} + +// Process all USE statements in a function and collect imported entities +void AddDebugInfoPass::handleUseStatements( + mlir::func::FuncOp funcOp, mlir::LLVM::DISubprogramAttr spAttr, + mlir::LLVM::DIFileAttr fileAttr, mlir::LLVM::DICompileUnitAttr cuAttr, + mlir::SymbolTable *symbolTable, + llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedEntities) { + mlir::MLIRContext *context = &getContext(); + + funcOp.walk([&](fir::UseStmtOp useOp) { + mlir::LLVM::DIModuleAttr modAttr = getOrCreateModuleAttr( + useOp.getModuleName().str(), fileAttr, cuAttr, /*line=*/1, + /*decl=*/true); + + llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> importedModules; + + if (useOp.hasOnlyClause()) + handleOnlyClause(useOp, spAttr, fileAttr, symbolTable, importedModules); + else if (useOp.hasRenames()) + handleRenamesWithoutOnly(useOp, spAttr, modAttr, fileAttr, symbolTable, + importedModules); + else { + // Simple module import + auto importedEntity = mlir::LLVM::DIImportedEntityAttr::get( + context, llvm::dwarf::DW_TAG_imported_module, spAttr, modAttr, + fileAttr, /*line=*/1, /*name=*/nullptr, /*elements*/ {}); + importedModules.insert(importedEntity); + } + + importedEntities.insert(importedModules.begin(), importedModules.end()); + }); +} + void 
AddDebugInfoPass::runOnOperation() { mlir::ModuleOp module = getOperation(); mlir::MLIRContext *context = &getContext(); @@ -704,6 +871,26 @@ void AddDebugInfoPass::runOnOperation() { splitDwarfFile.empty() ? mlir::StringAttr() : mlir::StringAttr::get(context, splitDwarfFile)); + // Process module globals early. + // Walk through all DeclareOps in functions and process globals that are + // module variables. This ensures that when we process USE statements, + // the DIGlobalVariable lookups will succeed. + if (debugLevel == mlir::LLVM::DIEmissionKind::Full) { + module.walk([&](fir::cg::XDeclareOp declOp) { + mlir::Operation *defOp = declOp.getMemref().getDefiningOp(); + if (defOp && llvm::isa<fir::AddrOfOp>(defOp)) { + if (auto globalOp = + symbolTable.lookup<fir::GlobalOp>(declOp.getUniqName())) { + // Only process module variables here, not SAVE variables + if (isModuleVariable(globalOp)) { + handleGlobalOp(globalOp, fileAttr, cuAttr, typeGen, &symbolTable, + declOp); + } + } + } + }); + } + module.walk([&](mlir::func::FuncOp funcOp) { handleFuncOp(funcOp, fileAttr, cuAttr, typeGen, &symbolTable); }); diff --git a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp index ed9a2ae..5bf783db 100644 --- a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp +++ b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp @@ -832,8 +832,8 @@ static mlir::Type getEleTy(mlir::Type ty) { static bool isAssumedSize(llvm::SmallVectorImpl<mlir::Value> &extents) { if (extents.empty()) return false; - auto cstLen = fir::getIntIfConstant(extents.back()); - return cstLen.has_value() && *cstLen == -1; + return llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>( + extents.back().getDefiningOp()); } // Extract extents from the ShapeOp/ShapeShiftOp into the result vector. 
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt index 0388439..5a3059eb 100644 --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -1,22 +1,35 @@ add_flang_library(FIRTransforms AbstractResult.cpp AddAliasTags.cpp - AffinePromotion.cpp + AddDebugInfo.cpp AffineDemotion.cpp + AffinePromotion.cpp + AlgebraicSimplification.cpp AnnotateConstant.cpp + ArrayValueCopy.cpp + ArrayValueCopy.cpp AssumedRankOpConversion.cpp + CUDA/CUFAddConstructor.cpp + CUDA/CUFAllocationConversion.cpp + CUDA/CUFAllocationConversion.cpp + CUDA/CUFComputeSharedMemoryOffsetsAndSize.cpp + CUDA/CUFDeviceFuncTransform.cpp + CUDA/CUFDeviceGlobal.cpp + CUDA/CUFFunctionRewrite.cpp + CUDA/CUFGPUToLLVMConversion.cpp + CUDA/CUFLaunchAttachAttr.cpp + CUDA/CUFOpConversion.cpp + CUDA/CUFOpConversionLate.cpp + CUDA/CUFPredefinedVarToGPU.cpp CharacterConversion.cpp CompilerGeneratedNames.cpp ConstantArgumentGlobalisation.cpp ControlFlowConverter.cpp - CUFAddConstructor.cpp - CUFDeviceGlobal.cpp - CUFOpConversion.cpp - CUFGPUToLLVMConversion.cpp - CUFComputeSharedMemoryOffsetsAndSize.cpp - ArrayValueCopy.cpp + ConvertComplexPow.cpp + DebugTypeGenerator.cpp ExternalNameConversion.cpp FIRToSCF.cpp + FIRToMemRef.cpp MemoryUtils.cpp MemoryAllocation.cpp StackArrays.cpp @@ -26,17 +39,23 @@ add_flang_library(FIRTransforms SimplifyIntrinsics.cpp AddDebugInfo.cpp PolymorphicOpConversion.cpp - LoopVersioning.cpp - StackReclaim.cpp - VScaleAttr.cpp FunctionAttr.cpp - DebugTypeGenerator.cpp - SetRuntimeCallAttributes.cpp GenRuntimeCallsForTest.cpp - SimplifyFIROperations.cpp - OptimizeArrayRepacking.cpp - ConvertComplexPow.cpp + LoopInvariantCodeMotion.cpp + LoopVersioning.cpp MIFOpConversion.cpp + MemRefDataFlowOpt.cpp + MemoryAllocation.cpp + MemoryUtils.cpp + OptimizeArrayRepacking.cpp + PolymorphicOpConversion.cpp + SetRuntimeCallAttributes.cpp + SimplifyFIROperations.cpp + SimplifyIntrinsics.cpp + 
SimplifyRegionLite.cpp + StackArrays.cpp + StackReclaim.cpp + VScaleAttr.cpp DEPENDS CUFAttrs @@ -62,12 +81,14 @@ add_flang_library(FIRTransforms MLIR_LIBS MLIRAffineUtils + MLIRAnalysis MLIRFuncDialect MLIRGPUDialect - MLIRLLVMDialect MLIRLLVMCommonConversion + MLIRLLVMDialect MLIRMathTransforms MLIROpenACCDialect MLIROpenACCToLLVMIRTranslation MLIROpenMPDialect + MLIRTransformUtils ) diff --git a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp index baa8e59..baa8e59 100644 --- a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp +++ b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp new file mode 100644 index 0000000..4e2bcb6 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp @@ -0,0 +1,445 @@ +//===-- CUFAllocationConversion.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h" +#include "flang/Optimizer/Builder/CUFCommon.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h" +#include "flang/Optimizer/Builder/Runtime/RTBuilder.h" +#include "flang/Optimizer/CodeGen/TypeConverter.h" +#include "flang/Optimizer/Dialect/CUF/CUFOps.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/Support/DataLayout.h" +#include "flang/Runtime/CUDA/allocatable.h" +#include "flang/Runtime/CUDA/common.h" +#include "flang/Runtime/CUDA/descriptor.h" +#include "flang/Runtime/CUDA/memory.h" +#include "flang/Runtime/CUDA/pointer.h" +#include "flang/Runtime/allocatable.h" +#include "flang/Runtime/allocator-registry-consts.h" +#include "flang/Support/Fortran.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace fir { +#define GEN_PASS_DEF_CUFALLOCATIONCONVERSION +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace fir; +using namespace mlir; +using namespace Fortran::runtime; +using namespace Fortran::runtime::cuda; + +namespace { + +template <typename OpTy> +static bool isPinned(OpTy op) { + if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned) + return true; + return false; +} + +static inline unsigned getMemType(cuf::DataAttribute attr) { + if (attr == cuf::DataAttribute::Device) + return kMemTypeDevice; + if (attr == cuf::DataAttribute::Managed) + return kMemTypeManaged; + if (attr == cuf::DataAttribute::Pinned) + return kMemTypePinned; + if (attr == 
cuf::DataAttribute::Unified) + return kMemTypeUnified; + llvm_unreachable("unsupported memory type"); +} + +static bool inDeviceContext(mlir::Operation *op) { + if (op->getParentOfType<cuf::KernelOp>()) + return true; + if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>()) + return true; + if (auto funcOp = op->getParentOfType<mlir::gpu::LaunchOp>()) + return true; + if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) { + if (auto cudaProcAttr = + funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>( + cuf::getProcAttrName())) { + return cudaProcAttr.getValue() != cuf::ProcAttribute::Host && + cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice; + } + } + return false; +} + +template <typename OpTy> +static mlir::LogicalResult convertOpToCall(OpTy op, + mlir::PatternRewriter &rewriter, + mlir::func::FuncOp func) { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + auto fTy = func.getFunctionType(); + + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine; + if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) + sourceLine = fir::factory::locationToLineNo( + builder, loc, op.getSource() ? fTy.getInput(7) : fTy.getInput(6)); + else + sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); + + mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true) + : builder.createBool(loc, false); + mlir::Value errmsg; + if (op.getErrmsg()) { + errmsg = op.getErrmsg(); + } else { + mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType()); + errmsg = fir::AbsentOp::create(builder, loc, boxNoneTy).getResult(); + } + llvm::SmallVector<mlir::Value> args; + if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) { + mlir::Value pinned = + op.getPinned() + ? 
op.getPinned() + : builder.createNullConstant( + loc, fir::ReferenceType::get( + mlir::IntegerType::get(op.getContext(), 1))); + if (op.getSource()) { + mlir::Value isDeviceSource = op.getDeviceSource() + ? builder.createBool(loc, true) + : builder.createBool(loc, false); + mlir::Value stream = + op.getStream() ? op.getStream() + : builder.createNullConstant(loc, fTy.getInput(2)); + args = fir::runtime::createArguments( + builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned, + hasStat, errmsg, sourceFile, sourceLine, isDeviceSource); + } else { + mlir::Value stream = + op.getStream() ? op.getStream() + : builder.createNullConstant(loc, fTy.getInput(1)); + mlir::Value deviceInit = + (op.getDataAttrAttr() && + op.getDataAttrAttr().getValue() == cuf::DataAttribute::Device) + ? builder.createBool(loc, true) + : builder.createBool(loc, false); + args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(), + stream, pinned, hasStat, errmsg, + sourceFile, sourceLine, deviceInit); + } + } else { + args = + fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat, + errmsg, sourceFile, sourceLine); + } + auto callOp = fir::CallOp::create(builder, loc, func, args); + rewriter.replaceOp(op, callOp); + return mlir::success(); +} + +struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> { + using OpRewritePattern::OpRewritePattern; + + CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl, + const fir::LLVMTypeConverter *typeConverter) + : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {} + + mlir::LogicalResult + matchAndRewrite(cuf::AllocOp op, + mlir::PatternRewriter &rewriter) const override { + + mlir::Location loc = op.getLoc(); + + if (inDeviceContext(op.getOperation())) { + // In device context just replace the cuf.alloc operation with a fir.alloc + // the cuf.free will be removed. + auto allocaOp = + fir::AllocaOp::create(rewriter, loc, op.getInType(), + op.getUniqName() ? 
*op.getUniqName() : "", + op.getBindcName() ? *op.getBindcName() : "", + op.getTypeparams(), op.getShape()); + allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); + rewriter.replaceOp(op, allocaOp); + return mlir::success(); + } + + auto mod = op->getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + + if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) { + // Convert scalar and known size array allocations. + mlir::Value bytes; + fir::KindMapping kindMap{fir::getKindMapping(mod)}; + if (fir::isa_trivial(op.getInType())) { + int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap); + bytes = + builder.createIntegerConstant(loc, builder.getIndexType(), width); + } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>( + op.getInType())) { + std::size_t size = 0; + if (fir::isa_derived(seqTy.getEleTy())) { + mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy()); + size = dl->getTypeSizeInBits(structTy) / 8; + } else { + size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap); + } + mlir::Value width = + builder.createIntegerConstant(loc, builder.getIndexType(), size); + mlir::Value nbElem; + if (fir::sequenceWithNonConstantShape(seqTy)) { + assert(!op.getShape().empty() && "expect shape with dynamic arrays"); + nbElem = builder.loadIfRef(loc, op.getShape()[0]); + for (unsigned i = 1; i < op.getShape().size(); ++i) { + nbElem = mlir::arith::MulIOp::create( + rewriter, loc, nbElem, + builder.loadIfRef(loc, op.getShape()[i])); + } + } else { + nbElem = builder.createIntegerConstant(loc, builder.getIndexType(), + seqTy.getConstantArraySize()); + } + bytes = mlir::arith::MulIOp::create(rewriter, loc, nbElem, width); + } else if (fir::isa_derived(op.getInType())) { + mlir::Type structTy = typeConverter->convertType(op.getInType()); + std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8; + bytes 
= builder.createIntegerConstant(loc, builder.getIndexType(), + structSize); + } else if (fir::isa_char(op.getInType())) { + mlir::Type charTy = typeConverter->convertType(op.getInType()); + std::size_t charSize = dl->getTypeSizeInBits(charTy) / 8; + bytes = builder.createIntegerConstant(loc, builder.getIndexType(), + charSize); + } else { + mlir::emitError(loc, "unsupported type in cuf.alloc\n"); + } + mlir::func::FuncOp func = + fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder); + auto fTy = func.getFunctionType(); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); + mlir::Value memTy = builder.createIntegerConstant( + loc, builder.getI32Type(), getMemType(op.getDataAttr())); + llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( + builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)}; + auto callOp = fir::CallOp::create(builder, loc, func, args); + callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); + auto convOp = builder.createConvert(loc, op.getResult().getType(), + callOp.getResult(0)); + rewriter.replaceOp(op, convOp); + return mlir::success(); + } + + // Convert descriptor allocations to function call. 
+ auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType()); + mlir::func::FuncOp func = + fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder); + auto fTy = func.getFunctionType(); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); + + mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy); + std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; + mlir::Value sizeInBytes = + builder.createIntegerConstant(loc, builder.getIndexType(), boxSize); + + llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( + builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)}; + auto callOp = fir::CallOp::create(builder, loc, func, args); + callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); + auto convOp = builder.createConvert(loc, op.getResult().getType(), + callOp.getResult(0)); + rewriter.replaceOp(op, convOp); + return mlir::success(); + } + +private: + mlir::DataLayout *dl; + const fir::LLVMTypeConverter *typeConverter; +}; + +struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(cuf::FreeOp op, + mlir::PatternRewriter &rewriter) const override { + if (inDeviceContext(op.getOperation())) { + rewriter.eraseOp(op); + return mlir::success(); + } + + if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType())) + return failure(); + + auto mod = op->getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + + auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType()); + if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) { + mlir::func::FuncOp func = + fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder); + auto fTy = func.getFunctionType(); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, 
loc, fTy.getInput(3)); + mlir::Value memTy = builder.createIntegerConstant( + loc, builder.getI32Type(), getMemType(op.getDataAttr())); + llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( + builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)}; + fir::CallOp::create(builder, loc, func, args); + rewriter.eraseOp(op); + return mlir::success(); + } + + // Convert cuf.free on descriptors. + mlir::func::FuncOp func = + fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder); + auto fTy = func.getFunctionType(); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); + llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( + builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)}; + auto callOp = fir::CallOp::create(builder, loc, func, args); + callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +struct CUFAllocateOpConversion + : public mlir::OpRewritePattern<cuf::AllocateOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(cuf::AllocateOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + bool isPointer = op.getPointer(); + if (op.getHasDoubleDescriptor()) { + // Allocation for module variable are done with custom runtime entry point + // so the descriptors can be synchronized. + mlir::func::FuncOp func; + if (op.getSource()) { + func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey( + CUFPointerAllocateSourceSync)>(loc, builder) + : fir::runtime::getRuntimeFunc<mkRTKey( + CUFAllocatableAllocateSourceSync)>(loc, builder); + } else { + func = + isPointer + ? 
fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>( + loc, builder) + : fir::runtime::getRuntimeFunc<mkRTKey( + CUFAllocatableAllocateSync)>(loc, builder); + } + return convertOpToCall<cuf::AllocateOp>(op, rewriter, func); + } + + mlir::func::FuncOp func; + if (op.getSource()) { + func = + isPointer + ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>( + loc, builder) + : fir::runtime::getRuntimeFunc<mkRTKey( + CUFAllocatableAllocateSource)>(loc, builder); + } else { + func = + isPointer + ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>( + loc, builder) + : fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>( + loc, builder); + } + + return convertOpToCall<cuf::AllocateOp>(op, rewriter, func); + } +}; + +struct CUFDeallocateOpConversion + : public mlir::OpRewritePattern<cuf::DeallocateOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(cuf::DeallocateOp op, + mlir::PatternRewriter &rewriter) const override { + + auto mod = op->getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + if (op.getHasDoubleDescriptor()) { + // Deallocation for module variable are done with custom runtime entry + // point so the descriptors can be synchronized. + mlir::func::FuncOp func = + fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>( + loc, builder); + return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func); + } + + // Deallocation for local descriptor falls back on the standard runtime + // AllocatableDeallocate as the dedicated deallocator is set in the + // descriptor before the call. 
+ mlir::func::FuncOp func = + fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc, + builder); + return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func); + } +}; + +class CUFAllocationConversion + : public fir::impl::CUFAllocationConversionBase<CUFAllocationConversion> { +public: + void runOnOperation() override { + auto *ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); + mlir::ConversionTarget target(*ctx); + + mlir::Operation *op = getOperation(); + mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op); + if (!module) + return signalPassFailure(); + mlir::SymbolTable symtab(module); + + std::optional<mlir::DataLayout> dl = fir::support::getOrSetMLIRDataLayout( + module, /*allowDefaultLayout=*/false); + fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false, + /*forceUnifiedTBAATree=*/false, *dl); + target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, + mlir::gpu::GPUDialect>(); + target.addLegalOp<cuf::StreamCastOp>(); + cuf::populateCUFAllocationConversionPatterns(typeConverter, *dl, symtab, + patterns); + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + mlir::emitError(mlir::UnknownLoc::get(ctx), + "error in CUF allocation conversion\n"); + signalPassFailure(); + } + } +}; + +} // namespace + +void cuf::populateCUFAllocationConversionPatterns( + const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl, + const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) { + patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter); + patterns.insert<CUFFreeOpConversion, CUFAllocateOpConversion, + CUFDeallocateOpConversion>(patterns.getContext()); +} diff --git a/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFComputeSharedMemoryOffsetsAndSize.cpp index 09126e0..87dc27e 100644 --- a/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp +++ 
b/flang/lib/Optimizer/Transforms/CUDA/CUFComputeSharedMemoryOffsetsAndSize.cpp @@ -41,12 +41,41 @@ namespace { static bool isAssumedSize(mlir::ValueRange shape) { if (shape.size() != 1) return false; - std::optional<std::int64_t> val = fir::getIntIfConstant(shape[0]); - if (val && *val == -1) + if (llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>(shape[0].getDefiningOp())) return true; return false; } +static void createSharedMemoryGlobal(fir::FirOpBuilder &builder, + mlir::Location loc, llvm::StringRef prefix, + llvm::StringRef suffix, + mlir::gpu::GPUModuleOp gpuMod, + mlir::Type sharedMemType, unsigned size, + unsigned align, bool isDynamic) { + std::string sharedMemGlobalName = + isDynamic ? (prefix + llvm::Twine(cudaSharedMemSuffix)).str() + : (prefix + llvm::Twine(cudaSharedMemSuffix) + suffix).str(); + + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToEnd(gpuMod.getBody()); + + mlir::StringAttr linkage = isDynamic ? builder.createExternalLinkage() + : builder.createInternalLinkage(); + llvm::SmallVector<mlir::NamedAttribute> attrs; + auto globalOpName = mlir::OperationName(fir::GlobalOp::getOperationName(), + gpuMod.getContext()); + attrs.push_back(mlir::NamedAttribute( + fir::GlobalOp::getDataAttrAttrName(globalOpName), + cuf::DataAttributeAttr::get(gpuMod.getContext(), + cuf::DataAttribute::Shared))); + + mlir::DenseElementsAttr init = {}; + auto sharedMem = + fir::GlobalOp::create(builder, loc, sharedMemGlobalName, false, false, + sharedMemType, init, linkage, attrs); + sharedMem.setAlignment(align); +} + struct CUFComputeSharedMemoryOffsetsAndSize : public fir::impl::CUFComputeSharedMemoryOffsetsAndSizeBase< CUFComputeSharedMemoryOffsetsAndSize> { @@ -109,18 +138,23 @@ struct CUFComputeSharedMemoryOffsetsAndSize crtDynOffset, dynSize); else crtDynOffset = dynSize; - - continue; + } else { + // Static shared memory. 
+ auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash( + loc, sharedOp.getInType(), *dl, kindMap); + createSharedMemoryGlobal( + builder, sharedOp.getLoc(), funcOp.getName(), + *sharedOp.getBindcName(), gpuMod, + fir::SequenceType::get(size, i8Ty), size, + sharedOp.getAlignment() ? *sharedOp.getAlignment() : align, + /*isDynamic=*/false); + mlir::Value zero = builder.createIntegerConstant(loc, i32Ty, 0); + sharedOp.getOffsetMutable().assign(zero); + if (!sharedOp.getAlignment()) + sharedOp.setAlignment(align); + sharedOp.setIsStatic(true); + ++nbStaticSharedVariables; } - auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash( - sharedOp.getLoc(), sharedOp.getInType(), *dl, kindMap); - ++nbStaticSharedVariables; - mlir::Value offset = builder.createIntegerConstant( - loc, i32Ty, llvm::alignTo(sharedMemSize, align)); - sharedOp.getOffsetMutable().assign(offset); - sharedMemSize = - llvm::alignTo(sharedMemSize, align) + llvm::alignTo(size, align); - alignment = std::max(alignment, align); } if (nbDynamicSharedVariables == 0 && nbStaticSharedVariables == 0) @@ -131,35 +165,13 @@ struct CUFComputeSharedMemoryOffsetsAndSize funcOp.getLoc(), "static and dynamic shared variables in a single kernel"); - mlir::DenseElementsAttr init = {}; - if (sharedMemSize > 0) { - auto vecTy = mlir::VectorType::get(sharedMemSize, i8Ty); - mlir::Attribute zero = mlir::IntegerAttr::get(i8Ty, 0); - init = mlir::DenseElementsAttr::get(vecTy, llvm::ArrayRef(zero)); - } + if (nbStaticSharedVariables > 0) + continue; - // Create the shared memory global where each shared variable will point - // to. auto sharedMemType = fir::SequenceType::get(sharedMemSize, i8Ty); - std::string sharedMemGlobalName = - (funcOp.getName() + llvm::Twine(cudaSharedMemSuffix)).str(); - // Dynamic shared memory needs an external linkage while static shared - // memory needs an internal linkage. - mlir::StringAttr linkage = nbDynamicSharedVariables > 0 - ? 
builder.createExternalLinkage() - : builder.createInternalLinkage(); - builder.setInsertionPointToEnd(gpuMod.getBody()); - llvm::SmallVector<mlir::NamedAttribute> attrs; - auto globalOpName = mlir::OperationName(fir::GlobalOp::getOperationName(), - gpuMod.getContext()); - attrs.push_back(mlir::NamedAttribute( - fir::GlobalOp::getDataAttrAttrName(globalOpName), - cuf::DataAttributeAttr::get(gpuMod.getContext(), - cuf::DataAttribute::Shared))); - auto sharedMem = fir::GlobalOp::create( - builder, funcOp.getLoc(), sharedMemGlobalName, false, false, - sharedMemType, init, linkage, attrs); - sharedMem.setAlignment(alignment); + createSharedMemoryGlobal(builder, funcOp.getLoc(), funcOp.getName(), "", + gpuMod, sharedMemType, sharedMemSize, alignment, + /*isDynamic=*/true); } } }; diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFDeviceFuncTransform.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFDeviceFuncTransform.cpp new file mode 100644 index 0000000..4532af9 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/CUDA/CUFDeviceFuncTransform.cpp @@ -0,0 +1,250 @@ +//===-- CUFDeviceFuncTransform.cpp ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/CUFCommon.h" +#include "flang/Optimizer/Builder/Todo.h" +#include "flang/Optimizer/Dialect/CUF/CUFOps.h" +#include "flang/Optimizer/Dialect/FIRAttr.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "flang/Optimizer/Support/InternalNames.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/StringSet.h" + +namespace fir { +#define GEN_PASS_DEF_CUFDEVICEFUNCTRANSFORM +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace mlir; + +namespace { + +class CUFDeviceFuncTransform + : public fir::impl::CUFDeviceFuncTransformBase<CUFDeviceFuncTransform> { + using CUFDeviceFuncTransformBase< + CUFDeviceFuncTransform>::CUFDeviceFuncTransformBase; + + static gpu::GPUFuncOp createGPUFuncOp(mlir::func::FuncOp funcOp, + bool isGlobal, int computeCap) { + mlir::OpBuilder builder(funcOp.getContext()); + + mlir::Region &funcOpBody = funcOp.getBody(); + SetVector<Value> operands; + for (mlir::Value operand : funcOp.getArguments()) + operands.insert(operand); + + llvm::SmallVector<mlir::Type> funcOperandTypes; + llvm::SmallVector<mlir::Type> funcResultTypes; + funcOperandTypes.reserve(funcOp.getArgumentTypes().size()); + funcResultTypes.reserve(funcOp.getResultTypes().size()); + for (mlir::Type opTy : funcOp.getArgumentTypes()) + 
funcOperandTypes.push_back(opTy); + for (mlir::Type resTy : funcOp.getResultTypes()) + funcResultTypes.push_back(resTy); + + mlir::Location loc = funcOp.getLoc(); + + mlir::FunctionType type = mlir::FunctionType::get( + funcOp.getContext(), funcOperandTypes, funcResultTypes); + + auto deviceFuncOp = + gpu::GPUFuncOp::create(builder, loc, funcOp.getName(), type, + mlir::TypeRange{}, mlir::TypeRange{}); + if (isGlobal) + deviceFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), + builder.getUnitAttr()); + + mlir::Region &deviceFuncBody = deviceFuncOp.getBody(); + mlir::Block &entryBlock = deviceFuncBody.front(); + + mlir::IRMapping map; + for (const auto &operand : enumerate(operands)) + map.map(operand.value(), entryBlock.getArgument(operand.index())); + + funcOpBody.cloneInto(&deviceFuncBody, map); + + deviceFuncOp.walk([](func::ReturnOp op) { + mlir::OpBuilder replacer(op); + gpu::ReturnOp gpuReturnOp = gpu::ReturnOp::create(replacer, op.getLoc()); + gpuReturnOp->setOperands(op.getOperands()); + op.erase(); + }); + + mlir::Block &funcOpEntry = funcOp.front(); + mlir::Block *clonedFuncOpEntry = map.lookup(&funcOpEntry); + + entryBlock.getOperations().splice(entryBlock.getOperations().end(), + clonedFuncOpEntry->getOperations()); + clonedFuncOpEntry->erase(); + + auto launchBoundsAttr = + funcOp.getOperation()->getAttrOfType<cuf::LaunchBoundsAttr>( + cuf::getLaunchBoundsAttrName()); + if (launchBoundsAttr) { + auto maxTPB = launchBoundsAttr.getMaxTPB().getInt(); + auto maxntid = + builder.getDenseI32ArrayAttr({static_cast<int32_t>(maxTPB), 1, 1}); + deviceFuncOp->setAttr(NVVM::NVVMDialect::getMaxntidAttrName(), maxntid); + deviceFuncOp->setAttr(NVVM::NVVMDialect::getMinctasmAttrName(), + launchBoundsAttr.getMinBPM()); + if (computeCap >= 90 && launchBoundsAttr.getUpperBoundClusterSize()) + deviceFuncOp->setAttr(NVVM::NVVMDialect::getClusterMaxBlocksAttrName(), + launchBoundsAttr.getUpperBoundClusterSize()); + } + + return deviceFuncOp; + } + + static void 
createHostStub(mlir::func::FuncOp funcOp, + mlir::SymbolTable &symTab, mlir::ModuleOp mod) { + mlir::Location loc = funcOp.getLoc(); + mlir::OpBuilder modBuilder(mod.getBodyRegion()); + modBuilder.setInsertionPointToEnd(mod.getBody()); + auto emptyStub = func::FuncOp::create(modBuilder, loc, funcOp.getName(), + funcOp.getFunctionType()); + emptyStub.setVisibility(funcOp.getVisibility()); + emptyStub->setAttrs(funcOp->getAttrs()); + auto entryBlock = emptyStub.addEntryBlock(); + modBuilder.setInsertionPointToEnd(entryBlock); + func::ReturnOp::create(modBuilder, loc); + + symTab.erase(funcOp); + symTab.insert(emptyStub); + } + + static bool isDeviceFunc(mlir::func::FuncOp funcOp) { + if (auto cudaProcAttr = + funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>( + cuf::getProcAttrName())) + if (cudaProcAttr.getValue() == cuf::ProcAttribute::Device || + cudaProcAttr.getValue() == cuf::ProcAttribute::Global || + cudaProcAttr.getValue() == cuf::ProcAttribute::GridGlobal || + cudaProcAttr.getValue() == cuf::ProcAttribute::HostDevice) + return true; + return false; + } + + void runOnOperation() override { + // Working on Module operation because inserting/removing function from the + // module is not thread-safe. + ModuleOp mod = getOperation(); + mlir::SymbolTable symbolTable(getOperation()); + + auto *ctx = getOperation().getContext(); + mlir::OpBuilder builder(ctx); + + gpu::GPUModuleOp gpuMod = cuf::getOrCreateGPUModule(mod, symbolTable); + mlir::SymbolTable gpuModSymTab(gpuMod); + + llvm::SetVector<mlir::func::FuncOp> funcsToClone; + llvm::SetVector<mlir::func::FuncOp> deviceFuncs; + llvm::SetVector<mlir::func::FuncOp> keepInModule; + llvm::StringSet<> deviceFuncNames; + + // Look for all function to migrate to the GPU module. 
+ mod.walk([&](mlir::func::FuncOp op) { + if (isDeviceFunc(op)) { + deviceFuncs.insert(op); + deviceFuncNames.insert(op.getSymName()); + } + }); + + auto processCallOp = [&](fir::CallOp op) { + if (op.getCallee()) { + auto func = symbolTable.lookup<mlir::func::FuncOp>( + op.getCallee()->getLeafReference()); + if (deviceFuncs.count(func) == 0) + funcsToClone.insert(func); + } + }; + + // Gather all function called by device functions. + for (auto funcOp : deviceFuncs) { + funcOp.walk([&](fir::CallOp op) { processCallOp(op); }); + funcOp.walk([&](fir::DispatchOp op) { + TODO(op.getLoc(), "type-bound procedure call with dynamic dispatch " + "in device procedure"); + }); + } + + // Functions that are referenced in a derived-type binding table must be + // kept in the host module to avoid LLVM dialect verification errors. + for (auto globalOp : mod.getOps<fir::GlobalOp>()) { + if (globalOp.getName().contains(fir::kBindingTableSeparator)) { + globalOp.walk([&](fir::AddrOfOp addrOfOp) { + if (deviceFuncNames.contains(addrOfOp.getSymbol().getLeafReference())) + keepInModule.insert( + *llvm::find_if(deviceFuncs, [&](mlir::func::FuncOp f) { + return f.getSymName() == + addrOfOp.getSymbol().getLeafReference(); + })); + }); + } + } + + // Gather all functions called by CUF kernels. 
+ mod.walk([&](cuf::KernelOp kernelOp) { + kernelOp.walk([&](fir::CallOp op) { processCallOp(op); }); + kernelOp.walk([&](fir::DispatchOp op) { + TODO(op.getLoc(), + "type-bound procedure call with dynamic dispatch in cuf kernel"); + }); + }); + + for (auto funcOp : funcsToClone) + gpuModSymTab.insert(funcOp->clone()); + + for (auto funcOp : deviceFuncs) { + auto cudaProcAttr = + funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>( + cuf::getProcAttrName()); + auto isGlobal = cudaProcAttr.getValue() == cuf::ProcAttribute::Global || + cudaProcAttr.getValue() == cuf::ProcAttribute::GridGlobal; + if (funcOp.isDeclaration()) { + mlir::Operation *clonedFuncOp = funcOp->clone(); + if (isGlobal) { + clonedFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), + builder.getUnitAttr()); + clonedFuncOp->removeAttr(cuf::getProcAttrName()); + if (auto funcOp = mlir::dyn_cast<func::FuncOp>(clonedFuncOp)) + funcOp.setNested(); + } + gpuModSymTab.insert(clonedFuncOp); + } else { + gpu::GPUFuncOp deviceFuncOp = + createGPUFuncOp(funcOp, isGlobal, computeCap); + gpuModSymTab.insert(deviceFuncOp); + + if (cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice) { + // If the function is a global, we need to keep the host side + // declaration for the kernel registration. Currently we just + // erase its body but in the future, the body should be rewritten + // to be able to launch CUDA Fortran kernel from C code. 
+ if (isGlobal || keepInModule.contains(funcOp)) + createHostStub(funcOp, symbolTable, mod); + else + funcOp.erase(); + } + } + } + } +}; + +} // end anonymous namespace diff --git a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFDeviceGlobal.cpp index 35badb6..35badb6 100644 --- a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp +++ b/flang/lib/Optimizer/Transforms/CUDA/CUFDeviceGlobal.cpp diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp new file mode 100644 index 0000000..bcbfb529 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp @@ -0,0 +1,103 @@ +//===-- CUFFunctionRewrite.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/CodeGen/TypeConverter.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Support/DataLayout.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include <string_view> + +#define DEBUG_TYPE "flang-cuf-function-rewrite" + +namespace fir { +#define 
GEN_PASS_DEF_CUFFUNCTIONREWRITE +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace mlir; + +namespace { + +using genFunctionType = + std::function<mlir::Value(mlir::PatternRewriter &, fir::CallOp op)>; + +class CallConversion : public OpRewritePattern<fir::CallOp> { +public: + CallConversion(MLIRContext *context) + : OpRewritePattern<fir::CallOp>(context) {} + + LogicalResult + matchAndRewrite(fir::CallOp op, + mlir::PatternRewriter &rewriter) const override { + auto callee = op.getCallee(); + if (!callee) + return failure(); + auto name = callee->getRootReference().getValue(); + + if (genMappings_.contains(name)) { + auto fct = genMappings_.find(name); + mlir::Value result = fct->second(rewriter, op); + if (result) + rewriter.replaceOp(op, result); + else + rewriter.eraseOp(op); + return success(); + } + return failure(); + } + +private: + static mlir::Value genOnDevice(mlir::PatternRewriter &rewriter, + fir::CallOp op) { + assert(op.getArgs().size() == 0 && "expect 0 arguments"); + mlir::Location loc = op.getLoc(); + unsigned inGPUMod = op->getParentOfType<gpu::GPUModuleOp>() ? 
1 : 0; + mlir::Type i1Ty = rewriter.getIntegerType(1); + mlir::Value t = mlir::arith::ConstantOp::create( + rewriter, loc, i1Ty, rewriter.getIntegerAttr(i1Ty, inGPUMod)); + return fir::ConvertOp::create(rewriter, loc, op.getResult(0).getType(), t); + } + + const llvm::StringMap<genFunctionType> genMappings_ = { + {"on_device", &genOnDevice}}; +}; + +class CUFFunctionRewrite + : public fir::impl::CUFFunctionRewriteBase<CUFFunctionRewrite> { +public: + void runOnOperation() override { + auto *ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); + + patterns.insert<CallConversion>(patterns.getContext()); + + if (mlir::failed( + mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { + mlir::emitError(mlir::UnknownLoc::get(ctx), + "error in CUFFunctionRewrite op conversion\n"); + signalPassFailure(); + } + } +}; + +} // namespace diff --git a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFGPUToLLVMConversion.cpp index 40f180a..d5a8212 100644 --- a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUDA/CUFGPUToLLVMConversion.cpp @@ -249,8 +249,13 @@ struct CUFSharedMemoryOpConversion "cuf.shared_memory must have an offset for code gen"); auto gpuMod = op->getParentOfType<gpu::GPUModuleOp>(); + std::string sharedGlobalName = - (getFuncName(op) + llvm::Twine(cudaSharedMemSuffix)).str(); + op.getIsStatic() + ? 
(getFuncName(op) + llvm::Twine(cudaSharedMemSuffix) + + *op.getBindcName()) + .str() + : (getFuncName(op) + llvm::Twine(cudaSharedMemSuffix)).str(); mlir::Value sharedGlobalAddr = createAddressOfOp(rewriter, loc, gpuMod, sharedGlobalName); diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFLaunchAttachAttr.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFLaunchAttachAttr.cpp new file mode 100644 index 0000000..41a0e5c --- /dev/null +++ b/flang/lib/Optimizer/Transforms/CUDA/CUFLaunchAttachAttr.cpp @@ -0,0 +1,70 @@ +//===-- CUFLaunchAttachAttr.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/CUF/CUFDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace fir { +#define GEN_PASS_DEF_CUFLAUNCHATTACHATTR +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace mlir; + +namespace { + +static constexpr llvm::StringRef cudaKernelInfix = "_cufk_"; + +class CUFGPUAttachAttrPattern + : public OpRewritePattern<mlir::gpu::LaunchFuncOp> { + using OpRewritePattern<mlir::gpu::LaunchFuncOp>::OpRewritePattern; + LogicalResult matchAndRewrite(mlir::gpu::LaunchFuncOp op, + PatternRewriter &rewriter) const override { + op->setAttr(cuf::getProcAttrName(), + cuf::ProcAttributeAttr::get(op.getContext(), + cuf::ProcAttribute::Global)); + return mlir::success(); + } +}; + +struct CUFLaunchAttachAttr + : public fir::impl::CUFLaunchAttachAttrBase<CUFLaunchAttachAttr> { + + void runOnOperation() override { + auto *context = &this->getContext(); + + 
mlir::RewritePatternSet patterns(context); + patterns.add<CUFGPUAttachAttrPattern>(context); + + mlir::ConversionTarget target(*context); + target.addIllegalOp<mlir::gpu::LaunchFuncOp>(); + target.addDynamicallyLegalOp<mlir::gpu::LaunchFuncOp>( + [&](mlir::gpu::LaunchFuncOp op) -> bool { + if (op.getKernelName().getValue().contains(cudaKernelInfix)) { + if (op.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>( + cuf::getProcAttrName())) + return true; + return false; + } + return true; + }); + + if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target, + std::move(patterns)))) { + mlir::emitError(mlir::UnknownLoc::get(context), + "Pattern conversion failed\n"); + this->signalPassFailure(); + } + } +}; + +} // end anonymous namespace diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp index 759e3a65d..ddae324 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp @@ -1,4 +1,4 @@ -//===-- CUFDeviceGlobal.cpp -----------------------------------------------===// +//===-- CUFOpConversion.cpp -----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -16,6 +16,7 @@ #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/Support/DataLayout.h" +#include "flang/Optimizer/Transforms/Passes.h" #include "flang/Runtime/CUDA/allocatable.h" #include "flang/Runtime/CUDA/common.h" #include "flang/Runtime/CUDA/descriptor.h" @@ -27,6 +28,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -44,213 +46,12 @@ using namespace Fortran::runtime::cuda; namespace { -static inline unsigned getMemType(cuf::DataAttribute attr) { - if (attr == cuf::DataAttribute::Device) - return kMemTypeDevice; - if (attr == cuf::DataAttribute::Managed) - return kMemTypeManaged; - if (attr == cuf::DataAttribute::Unified) - return kMemTypeUnified; - if (attr == cuf::DataAttribute::Pinned) - return kMemTypePinned; - llvm::report_fatal_error("unsupported memory type"); -} - -template <typename OpTy> -static bool isPinned(OpTy op) { - if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned) - return true; - return false; -} - -template <typename OpTy> -static bool hasDoubleDescriptors(OpTy op) { - if (auto declareOp = - mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) { - if (mlir::isa_and_nonnull<fir::AddrOfOp>( - declareOp.getMemref().getDefiningOp())) { - if (isPinned(declareOp)) - return false; - return true; - } - } else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>( - op.getBox().getDefiningOp())) { - if (mlir::isa_and_nonnull<fir::AddrOfOp>( - declareOp.getMemref().getDefiningOp())) { - if (isPinned(declareOp)) - return false; - return true; - } - } - return false; -} - -static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, - mlir::Location loc, mlir::Type toTy, - mlir::Value val) { - if 
(val.getType() != toTy) - return fir::ConvertOp::create(rewriter, loc, toTy, val); - return val; -} - -template <typename OpTy> -static mlir::LogicalResult convertOpToCall(OpTy op, - mlir::PatternRewriter &rewriter, - mlir::func::FuncOp func) { - auto mod = op->template getParentOfType<mlir::ModuleOp>(); - fir::FirOpBuilder builder(rewriter, mod); - mlir::Location loc = op.getLoc(); - auto fTy = func.getFunctionType(); - - mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); - mlir::Value sourceLine; - if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) - sourceLine = fir::factory::locationToLineNo( - builder, loc, op.getSource() ? fTy.getInput(7) : fTy.getInput(6)); - else - sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4)); - - mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true) - : builder.createBool(loc, false); - - mlir::Value errmsg; - if (op.getErrmsg()) { - errmsg = op.getErrmsg(); - } else { - mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType()); - errmsg = fir::AbsentOp::create(builder, loc, boxNoneTy).getResult(); - } - llvm::SmallVector<mlir::Value> args; - if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) { - mlir::Value pinned = - op.getPinned() - ? op.getPinned() - : builder.createNullConstant( - loc, fir::ReferenceType::get( - mlir::IntegerType::get(op.getContext(), 1))); - if (op.getSource()) { - mlir::Value stream = - op.getStream() ? op.getStream() - : builder.createNullConstant(loc, fTy.getInput(2)); - args = fir::runtime::createArguments( - builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned, - hasStat, errmsg, sourceFile, sourceLine); - } else { - mlir::Value stream = - op.getStream() ? 
op.getStream() - : builder.createNullConstant(loc, fTy.getInput(1)); - args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(), - stream, pinned, hasStat, errmsg, - sourceFile, sourceLine); - } - } else { - args = - fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat, - errmsg, sourceFile, sourceLine); - } - auto callOp = fir::CallOp::create(builder, loc, func, args); - rewriter.replaceOp(op, callOp); - return mlir::success(); -} - -struct CUFAllocateOpConversion - : public mlir::OpRewritePattern<cuf::AllocateOp> { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(cuf::AllocateOp op, - mlir::PatternRewriter &rewriter) const override { - auto mod = op->getParentOfType<mlir::ModuleOp>(); - fir::FirOpBuilder builder(rewriter, mod); - mlir::Location loc = op.getLoc(); - - bool isPointer = false; - - if (auto declareOp = - mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) - if (declareOp.getFortranAttrs() && - bitEnumContainsAny(*declareOp.getFortranAttrs(), - fir::FortranVariableFlagsEnum::pointer)) - isPointer = true; - - if (hasDoubleDescriptors(op)) { - // Allocation for module variable are done with custom runtime entry point - // so the descriptors can be synchronized. - mlir::func::FuncOp func; - if (op.getSource()) { - func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey( - CUFPointerAllocateSourceSync)>(loc, builder) - : fir::runtime::getRuntimeFunc<mkRTKey( - CUFAllocatableAllocateSourceSync)>(loc, builder); - } else { - func = - isPointer - ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>( - loc, builder) - : fir::runtime::getRuntimeFunc<mkRTKey( - CUFAllocatableAllocateSync)>(loc, builder); - } - return convertOpToCall<cuf::AllocateOp>(op, rewriter, func); - } - - mlir::func::FuncOp func; - if (op.getSource()) { - func = - isPointer - ? 
fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>( - loc, builder) - : fir::runtime::getRuntimeFunc<mkRTKey( - CUFAllocatableAllocateSource)>(loc, builder); - } else { - func = - isPointer - ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>( - loc, builder) - : fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>( - loc, builder); - } - - return convertOpToCall<cuf::AllocateOp>(op, rewriter, func); - } -}; - -struct CUFDeallocateOpConversion - : public mlir::OpRewritePattern<cuf::DeallocateOp> { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(cuf::DeallocateOp op, - mlir::PatternRewriter &rewriter) const override { - - auto mod = op->getParentOfType<mlir::ModuleOp>(); - fir::FirOpBuilder builder(rewriter, mod); - mlir::Location loc = op.getLoc(); - - if (hasDoubleDescriptors(op)) { - // Deallocation for module variable are done with custom runtime entry - // point so the descriptors can be synchronized. - mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>( - loc, builder); - return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func); - } - - // Deallocation for local descriptor falls back on the standard runtime - // AllocatableDeallocate as the dedicated deallocator is set in the - // descriptor before the call. 
- mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc, - builder); - return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func); - } -}; - static bool inDeviceContext(mlir::Operation *op) { if (op->getParentOfType<cuf::KernelOp>()) return true; - if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>()) + if (op->getParentOfType<mlir::acc::OffloadRegionOpInterface>()) return true; - if (auto funcOp = op->getParentOfType<mlir::gpu::LaunchOp>()) + if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>()) return true; if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) { if (auto cudaProcAttr = @@ -263,187 +64,14 @@ static bool inDeviceContext(mlir::Operation *op) { return false; } -static int computeWidth(mlir::Location loc, mlir::Type type, - fir::KindMapping &kindMap) { - auto eleTy = fir::unwrapSequenceType(type); - if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) - return t.getWidth() / 8; - if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) - return t.getWidth() / 8; - if (eleTy.isInteger(1)) - return 1; - if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) - return kindMap.getLogicalBitsize(t.getFKind()) / 8; - if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) { - int elemSize = - mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8; - return 2 * elemSize; - } - if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)}) - return kindMap.getCharacterBitsize(t.getFKind()) / 8; - mlir::emitError(loc, "unsupported type"); - return 0; +static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::Type toTy, + mlir::Value val) { + if (val.getType() != toTy) + return fir::ConvertOp::create(rewriter, loc, toTy, val); + return val; } -struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> { - using OpRewritePattern::OpRewritePattern; - - CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl, - const 
fir::LLVMTypeConverter *typeConverter) - : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {} - - mlir::LogicalResult - matchAndRewrite(cuf::AllocOp op, - mlir::PatternRewriter &rewriter) const override { - - mlir::Location loc = op.getLoc(); - - if (inDeviceContext(op.getOperation())) { - // In device context just replace the cuf.alloc operation with a fir.alloc - // the cuf.free will be removed. - auto allocaOp = - fir::AllocaOp::create(rewriter, loc, op.getInType(), - op.getUniqName() ? *op.getUniqName() : "", - op.getBindcName() ? *op.getBindcName() : "", - op.getTypeparams(), op.getShape()); - allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); - rewriter.replaceOp(op, allocaOp); - return mlir::success(); - } - - auto mod = op->getParentOfType<mlir::ModuleOp>(); - fir::FirOpBuilder builder(rewriter, mod); - mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); - - if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) { - // Convert scalar and known size array allocations. 
- mlir::Value bytes; - fir::KindMapping kindMap{fir::getKindMapping(mod)}; - if (fir::isa_trivial(op.getInType())) { - int width = computeWidth(loc, op.getInType(), kindMap); - bytes = - builder.createIntegerConstant(loc, builder.getIndexType(), width); - } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>( - op.getInType())) { - std::size_t size = 0; - if (fir::isa_derived(seqTy.getEleTy())) { - mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy()); - size = dl->getTypeSizeInBits(structTy) / 8; - } else { - size = computeWidth(loc, seqTy.getEleTy(), kindMap); - } - mlir::Value width = - builder.createIntegerConstant(loc, builder.getIndexType(), size); - mlir::Value nbElem; - if (fir::sequenceWithNonConstantShape(seqTy)) { - assert(!op.getShape().empty() && "expect shape with dynamic arrays"); - nbElem = builder.loadIfRef(loc, op.getShape()[0]); - for (unsigned i = 1; i < op.getShape().size(); ++i) { - nbElem = mlir::arith::MulIOp::create( - rewriter, loc, nbElem, - builder.loadIfRef(loc, op.getShape()[i])); - } - } else { - nbElem = builder.createIntegerConstant(loc, builder.getIndexType(), - seqTy.getConstantArraySize()); - } - bytes = mlir::arith::MulIOp::create(rewriter, loc, nbElem, width); - } else if (fir::isa_derived(op.getInType())) { - mlir::Type structTy = typeConverter->convertType(op.getInType()); - std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8; - bytes = builder.createIntegerConstant(loc, builder.getIndexType(), - structSize); - } else { - mlir::emitError(loc, "unsupported type in cuf.alloc\n"); - } - mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder); - auto fTy = func.getFunctionType(); - mlir::Value sourceLine = - fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); - mlir::Value memTy = builder.createIntegerConstant( - loc, builder.getI32Type(), getMemType(op.getDataAttr())); - llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( - builder, 
loc, fTy, bytes, memTy, sourceFile, sourceLine)}; - auto callOp = fir::CallOp::create(builder, loc, func, args); - callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); - auto convOp = builder.createConvert(loc, op.getResult().getType(), - callOp.getResult(0)); - rewriter.replaceOp(op, convOp); - return mlir::success(); - } - - // Convert descriptor allocations to function call. - auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType()); - mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder); - auto fTy = func.getFunctionType(); - mlir::Value sourceLine = - fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); - - mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy); - std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; - mlir::Value sizeInBytes = - builder.createIntegerConstant(loc, builder.getIndexType(), boxSize); - - llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( - builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)}; - auto callOp = fir::CallOp::create(builder, loc, func, args); - callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); - auto convOp = builder.createConvert(loc, op.getResult().getType(), - callOp.getResult(0)); - rewriter.replaceOp(op, convOp); - return mlir::success(); - } - -private: - mlir::DataLayout *dl; - const fir::LLVMTypeConverter *typeConverter; -}; - -struct CUFDeviceAddressOpConversion - : public mlir::OpRewritePattern<cuf::DeviceAddressOp> { - using OpRewritePattern::OpRewritePattern; - - CUFDeviceAddressOpConversion(mlir::MLIRContext *context, - const mlir::SymbolTable &symtab) - : OpRewritePattern(context), symTab{symtab} {} - - mlir::LogicalResult - matchAndRewrite(cuf::DeviceAddressOp op, - mlir::PatternRewriter &rewriter) const override { - if (auto global = symTab.lookup<fir::GlobalOp>( - op.getHostSymbol().getRootReference().getValue())) { - auto mod = op->getParentOfType<mlir::ModuleOp>(); - 
mlir::Location loc = op.getLoc(); - auto hostAddr = fir::AddrOfOp::create( - rewriter, loc, fir::ReferenceType::get(global.getType()), - op.getHostSymbol()); - fir::FirOpBuilder builder(rewriter, mod); - mlir::func::FuncOp callee = - fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, - builder); - auto fTy = callee.getFunctionType(); - mlir::Value conv = - createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr); - mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); - mlir::Value sourceLine = - fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); - llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( - builder, loc, fTy, conv, sourceFile, sourceLine)}; - auto call = fir::CallOp::create(rewriter, loc, callee, args); - mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(), - call->getResult(0)); - rewriter.replaceOp(op, addr.getDefiningOp()); - return success(); - } - return failure(); - } - -private: - const mlir::SymbolTable &symTab; -}; - struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> { using OpRewritePattern::OpRewritePattern; @@ -454,7 +82,12 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> { mlir::LogicalResult matchAndRewrite(fir::DeclareOp op, mlir::PatternRewriter &rewriter) const override { + if (op.getResult().getUsers().empty()) + return success(); if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) { + if (inDeviceContext(addrOfOp)) { + return failure(); + } if (auto global = symTab.lookup<fir::GlobalOp>( addrOfOp.getSymbol().getRootReference().getValue())) { if (cuf::isRegisteredDeviceGlobal(global)) { @@ -475,56 +108,6 @@ private: const mlir::SymbolTable &symTab; }; -struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(cuf::FreeOp op, - mlir::PatternRewriter &rewriter) const override { - if 
(inDeviceContext(op.getOperation())) { - rewriter.eraseOp(op); - return mlir::success(); - } - - if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType())) - return failure(); - - auto mod = op->getParentOfType<mlir::ModuleOp>(); - fir::FirOpBuilder builder(rewriter, mod); - mlir::Location loc = op.getLoc(); - mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); - - auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType()); - if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) { - mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder); - auto fTy = func.getFunctionType(); - mlir::Value sourceLine = - fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); - mlir::Value memTy = builder.createIntegerConstant( - loc, builder.getI32Type(), getMemType(op.getDataAttr())); - llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( - builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)}; - fir::CallOp::create(builder, loc, func, args); - rewriter.eraseOp(op); - return mlir::success(); - } - - // Convert cuf.free on descriptors. 
- mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder); - auto fTy = func.getFunctionType(); - mlir::Value sourceLine = - fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); - llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( - builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)}; - auto callOp = fir::CallOp::create(builder, loc, func, args); - callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr()); - rewriter.eraseOp(op); - return mlir::success(); - } -}; - static bool isDstGlobal(cuf::DataTransferOp op) { if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>()) if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>()) @@ -671,38 +254,15 @@ struct CUFDataTransferOpConversion } mlir::Type i64Ty = builder.getI64Type(); - mlir::Value nbElement; - if (op.getShape()) { - llvm::SmallVector<mlir::Value> extents; - if (auto shapeOp = - mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) { - extents = shapeOp.getExtents(); - } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>( - op.getShape().getDefiningOp())) { - for (auto i : llvm::enumerate(shapeShiftOp.getPairs())) - if (i.index() & 1) - extents.push_back(i.value()); - } - - nbElement = fir::ConvertOp::create(rewriter, loc, i64Ty, extents[0]); - for (unsigned i = 1; i < extents.size(); ++i) { - auto operand = - fir::ConvertOp::create(rewriter, loc, i64Ty, extents[i]); - nbElement = - mlir::arith::MulIOp::create(rewriter, loc, nbElement, operand); - } - } else { - if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy)) - nbElement = builder.createIntegerConstant( - loc, i64Ty, seqTy.getConstantArraySize()); - } + mlir::Value nbElement = + cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty); unsigned width = 0; if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) { mlir::Type structTy = typeConverter->convertType(fir::unwrapSequenceType(dstTy)); width = 
dl->getTypeSizeInBits(structTy) / 8; } else { - width = computeWidth(loc, dstTy, kindMap); + width = cuf::computeElementByteSize(loc, dstTy, kindMap); } mlir::Value widthValue = mlir::arith::ConstantOp::create( rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width)); @@ -934,6 +494,8 @@ struct CUFSyncDescriptorOpConversion }; class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> { + using CUFOpConversionBase::CUFOpConversionBase; + public: void runOnOperation() override { auto *ctx = &getContext(); @@ -953,6 +515,7 @@ public: target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, mlir::gpu::GPUDialect>(); target.addLegalOp<cuf::StreamCastOp>(); + target.addLegalOp<cuf::DeviceAddressOp>(); cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab, patterns); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, @@ -963,6 +526,8 @@ public: } target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) { + if (op.getResult().getUsers().empty()) + return true; if (inDeviceContext(op)) return true; if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) { @@ -992,18 +557,13 @@ public: void cuf::populateCUFToFIRConversionPatterns( const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl, const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) { - patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter); - patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion, - CUFFreeOpConversion, CUFSyncDescriptorOpConversion>( - patterns.getContext()); + patterns.insert<CUFSyncDescriptorOpConversion>(patterns.getContext()); patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab, &dl, &converter); - patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>( - patterns.getContext(), symtab); + patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab); } void cuf::populateFIRCUFConversionPatterns(const 
mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) { - patterns.insert<DeclareOpConversion, CUFDeviceAddressOpConversion>( - patterns.getContext(), symtab); + patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab); } diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp new file mode 100644 index 0000000..fe45971 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp @@ -0,0 +1,120 @@ +//===-- CUFOpConversionLate.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/CUFCommon.h" +#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h" +#include "flang/Optimizer/Builder/Runtime/RTBuilder.h" +#include "flang/Optimizer/CodeGen/TypeConverter.h" +#include "flang/Optimizer/Dialect/CUF/CUFOps.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "flang/Runtime/CUDA/common.h" +#include "flang/Runtime/CUDA/descriptor.h" +#include "flang/Runtime/allocatable.h" +#include "flang/Runtime/allocator-registry-consts.h" +#include "flang/Support/Fortran.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace fir { +#define GEN_PASS_DEF_CUFOPCONVERSIONLATE +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using 
namespace fir; +using namespace mlir; +using namespace Fortran::runtime; +using namespace Fortran::runtime::cuda; + +namespace { + +static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::Type toTy, + mlir::Value val) { + if (val.getType() != toTy) + return fir::ConvertOp::create(rewriter, loc, toTy, val); + return val; +} + +struct CUFDeviceAddressOpConversion + : public mlir::OpRewritePattern<cuf::DeviceAddressOp> { + using OpRewritePattern::OpRewritePattern; + + CUFDeviceAddressOpConversion(mlir::MLIRContext *context, + const mlir::SymbolTable &symtab) + : OpRewritePattern(context), symTab{symtab} {} + + mlir::LogicalResult + matchAndRewrite(cuf::DeviceAddressOp op, + mlir::PatternRewriter &rewriter) const override { + if (auto global = symTab.lookup<fir::GlobalOp>( + op.getHostSymbol().getRootReference().getValue())) { + auto mod = op->getParentOfType<mlir::ModuleOp>(); + mlir::Location loc = op.getLoc(); + auto hostAddr = fir::AddrOfOp::create( + rewriter, loc, fir::ReferenceType::get(global.getType()), + op.getHostSymbol()); + fir::FirOpBuilder builder(rewriter, mod); + mlir::func::FuncOp callee = + fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, + builder); + auto fTy = callee.getFunctionType(); + mlir::Value conv = + createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr); + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); + llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( + builder, loc, fTy, conv, sourceFile, sourceLine)}; + auto call = fir::CallOp::create(rewriter, loc, callee, args); + mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(), + call->getResult(0)); + rewriter.replaceOp(op, addr.getDefiningOp()); + return success(); + } + return failure(); + } + +private: + const mlir::SymbolTable &symTab; +}; + +class CUFOpConversionLate + : public 
fir::impl::CUFOpConversionLateBase<CUFOpConversionLate> { + using CUFOpConversionLateBase::CUFOpConversionLateBase; + +public: + void runOnOperation() override { + auto *ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); + mlir::ConversionTarget target(*ctx); + mlir::Operation *op = getOperation(); + mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op); + if (!module) + return signalPassFailure(); + mlir::SymbolTable symtab(module); + target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, + mlir::gpu::GPUDialect>(); + patterns.insert<CUFDeviceAddressOpConversion>(patterns.getContext(), + symtab); + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + mlir::emitError(mlir::UnknownLoc::get(ctx), + "error in CUF op conversion\n"); + signalPassFailure(); + } + } +}; +} // namespace diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp new file mode 100644 index 0000000..3eb6559 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp @@ -0,0 +1,153 @@ +//===-- CUFPredefinedVarToGPU.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Pass/Pass.h" + +namespace fir { +#define GEN_PASS_DEF_CUFPREDEFINEDVARTOGPU +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace mlir; + +namespace { + +template <typename OpTyX, typename OpTyY, typename OpTyZ> +static void createForAllDimensions(mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value c1, + SmallVectorImpl<mlir::Value> &values, + bool incrementByOne = false) { + if (incrementByOne) { + auto baseX = OpTyX::create(builder, loc, builder.getI32Type()); + values.push_back(mlir::arith::AddIOp::create(builder, loc, baseX, c1)); + auto baseY = OpTyY::create(builder, loc, builder.getI32Type()); + values.push_back(mlir::arith::AddIOp::create(builder, loc, baseY, c1)); + auto baseZ = OpTyZ::create(builder, loc, builder.getI32Type()); + values.push_back(mlir::arith::AddIOp::create(builder, loc, baseZ, c1)); + } else { + values.push_back(OpTyX::create(builder, loc, builder.getI32Type())); + values.push_back(OpTyY::create(builder, loc, builder.getI32Type())); + values.push_back(OpTyZ::create(builder, loc, builder.getI32Type())); + } +} + +static constexpr llvm::StringRef builtinsModuleName = "__fortran_builtins"; +static constexpr llvm::StringRef builtinVarPrefix = "__builtin_"; +static constexpr llvm::StringRef threadidx = "threadidx"; +static constexpr llvm::StringRef blockidx = "blockidx"; +static constexpr llvm::StringRef blockdim = "blockdim"; +static constexpr llvm::StringRef griddim = "griddim"; + +static constexpr unsigned field_x = 0; +static constexpr unsigned field_y = 1; +static constexpr unsigned field_z = 2; + +std::string mangleBuiltin(llvm::StringRef varName) { + return "_QM" + builtinsModuleName.str() + 
"E" + builtinVarPrefix.str() + + varName.str(); +} + +static void processCoordinateOp(mlir::OpBuilder &builder, mlir::Location loc, + fir::CoordinateOp coordOp, unsigned fieldIdx, + mlir::Value &gpuValue) { + std::optional<llvm::ArrayRef<int32_t>> fieldIndices = + coordOp.getFieldIndices(); + assert(fieldIndices && fieldIndices->size() == 1 && + "expect only one coordinate"); + if (static_cast<unsigned>((*fieldIndices)[0]) == fieldIdx) { + llvm::SmallVector<fir::LoadOp> opToErase; + for (mlir::OpOperand &coordUse : coordOp.getResult().getUses()) { + assert(mlir::isa<fir::LoadOp>(coordUse.getOwner()) && + "only expect load op"); + auto loadOp = mlir::dyn_cast<fir::LoadOp>(coordUse.getOwner()); + loadOp.getResult().replaceAllUsesWith(gpuValue); + opToErase.push_back(loadOp); + } + for (auto op : opToErase) + op.erase(); + } +} + +static void +processDeclareOp(mlir::OpBuilder &builder, mlir::Location loc, + fir::DeclareOp declareOp, llvm::StringRef builtinVar, + llvm::SmallVectorImpl<mlir::Value> &gpuValues, + llvm::SmallVectorImpl<mlir::Operation *> &opsToDelete) { + if (declareOp.getUniqName().str().compare(builtinVar) == 0) { + for (mlir::OpOperand &use : declareOp.getResult().getUses()) { + fir::CoordinateOp coordOp = + mlir::dyn_cast<fir::CoordinateOp>(use.getOwner()); + processCoordinateOp(builder, loc, coordOp, field_x, gpuValues[0]); + processCoordinateOp(builder, loc, coordOp, field_y, gpuValues[1]); + processCoordinateOp(builder, loc, coordOp, field_z, gpuValues[2]); + opsToDelete.push_back(coordOp); + } + opsToDelete.push_back(declareOp.getOperation()); + if (declareOp.getMemref().getDefiningOp()) + opsToDelete.push_back(declareOp.getMemref().getDefiningOp()); + } +} + +struct CUFPredefinedVarToGPU + : public fir::impl::CUFPredefinedVarToGPUBase<CUFPredefinedVarToGPU> { + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + if (funcOp.getBody().empty()) + return; + + if (auto cudaProcAttr = + 
funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>( + cuf::getProcAttrName())) { + if (cudaProcAttr.getValue() == cuf::ProcAttribute::Device || + cudaProcAttr.getValue() == cuf::ProcAttribute::Global || + cudaProcAttr.getValue() == cuf::ProcAttribute::GridGlobal || + cudaProcAttr.getValue() == cuf::ProcAttribute::HostDevice) { + mlir::Location loc = funcOp.getLoc(); + mlir::OpBuilder builder(funcOp.getContext()); + builder.setInsertionPointToStart(&funcOp.getBody().front()); + auto c1 = mlir::arith::ConstantOp::create( + builder, loc, builder.getI32Type(), builder.getI32IntegerAttr(1)); + llvm::SmallVector<mlir::Value, 3> threadids, blockids, blockdims, + griddims; + createForAllDimensions<mlir::NVVM::ThreadIdXOp, mlir::NVVM::ThreadIdYOp, + mlir::NVVM::ThreadIdZOp>( + builder, loc, c1, threadids, /*incrementByOne=*/true); + createForAllDimensions<mlir::NVVM::BlockIdXOp, mlir::NVVM::BlockIdYOp, + mlir::NVVM::BlockIdZOp>( + builder, loc, c1, blockids, /*incrementByOne=*/true); + createForAllDimensions<mlir::NVVM::GridDimXOp, mlir::NVVM::GridDimYOp, + mlir::NVVM::GridDimZOp>(builder, loc, c1, + griddims); + createForAllDimensions<mlir::NVVM::BlockDimXOp, mlir::NVVM::BlockDimYOp, + mlir::NVVM::BlockDimZOp>(builder, loc, c1, + blockdims); + + llvm::SmallVector<mlir::Operation *> opsToDelete; + for (auto declareOp : funcOp.getOps<fir::DeclareOp>()) { + processDeclareOp(builder, loc, declareOp, mangleBuiltin(threadidx), + threadids, opsToDelete); + processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockidx), + blockids, opsToDelete); + processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockdim), + blockdims, opsToDelete); + processDeclareOp(builder, loc, declareOp, mangleBuiltin(griddim), + griddims, opsToDelete); + } + + for (auto op : opsToDelete) + op->erase(); + } + } + } +}; + +} // end anonymous namespace diff --git a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp index 
00fdb5a..7c8ee09 100644 --- a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp +++ b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp @@ -438,6 +438,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType( context, llvm::dwarf::DW_TAG_member, mlir::StringAttr::get(context, fieldName), elemTy, byteSize * 8, byteAlign * 8, offset * 8, /*optional<address space>=*/std::nullopt, + /*flags=*/mlir::LLVM::DIFlags::Zero, /*extra data=*/nullptr); elements.push_back(tyAttr); offset += llvm::alignTo(byteSize, byteAlign); @@ -480,6 +481,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertTupleType( context, llvm::dwarf::DW_TAG_member, mlir::StringAttr::get(context, ""), elemTy, byteSize * 8, byteAlign * 8, offset * 8, /*optional<address space>=*/std::nullopt, + /*flags=*/mlir::LLVM::DIFlags::Zero, /*extra data=*/nullptr); elements.push_back(tyAttr); offset += llvm::alignTo(byteSize, byteAlign); @@ -528,13 +530,10 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertSequenceType( if (dim == seqTy.getUnknownExtent()) { // This path is taken for both assumed size array or when the size of the // array is variable. In the case of variable size, we create a variable - // to use as countAttr. Note that fir has a constant size of -1 for - // assumed size array. So !optint check makes sure we don't generate - // variable in that case. + // to use as countAttr. 
if (declOp && declOp.getShape().size() > index) { - std::optional<std::int64_t> optint = - getIntIfConstant(declOp.getShape()[index]); - if (!optint) + if (!llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>( + declOp.getShape()[index].getDefiningOp())) countAttr = generateArtificialVariable( context, declOp.getShape()[index], fileAttr, scope, declOp); } @@ -676,7 +675,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertPointerLikeType( context, llvm::dwarf::DW_TAG_pointer_type, mlir::StringAttr::get(context, ""), elTyAttr, /*sizeInBits=*/ptrSize * 8, /*alignInBits=*/0, /*offset=*/0, - /*optional<address space>=*/std::nullopt, /*extra data=*/nullptr); + /*optional<address space>=*/std::nullopt, + /*flags=*/mlir::LLVM::DIFlags::Zero, /*extra data=*/nullptr); } static mlir::StringAttr getBasicTypeName(mlir::MLIRContext *context, @@ -721,6 +721,32 @@ DebugTypeGenerator::convertType(mlir::Type Ty, mlir::LLVM::DIFileAttr fileAttr, return convertRecordType(recTy, fileAttr, scope, declOp); } else if (auto tupleTy = mlir::dyn_cast_if_present<mlir::TupleType>(Ty)) { return convertTupleType(tupleTy, fileAttr, scope, declOp); + } else if (mlir::isa<mlir::FunctionType>(Ty)) { + // Handle function types - these represent procedure pointers after the + // BoxedProcedure pass has run and unwrapped the fir.boxproc type, as well + // as dummy procedures (which are represented as function types in FIR) + llvm::SmallVector<mlir::LLVM::DITypeAttr> types; + + auto funcTy = mlir::cast<mlir::FunctionType>(Ty); + // Add return type (or void if no return type) + if (funcTy.getNumResults() == 0) + types.push_back(mlir::LLVM::DINullTypeAttr::get(context)); + else + types.push_back( + convertType(funcTy.getResult(0), fileAttr, scope, declOp)); + + for (mlir::Type paramTy : funcTy.getInputs()) + types.push_back(convertType(paramTy, fileAttr, scope, declOp)); + + auto subroutineTy = mlir::LLVM::DISubroutineTypeAttr::get( + context, /*callingConvention=*/0, types); + + return 
mlir::LLVM::DIDerivedTypeAttr::get( + context, llvm::dwarf::DW_TAG_pointer_type, + mlir::StringAttr::get(context, ""), subroutineTy, + /*sizeInBits=*/ptrSize * 8, /*alignInBits=*/0, /*offset=*/0, + /*optional<address space>=*/std::nullopt, + /*flags=*/mlir::LLVM::DIFlags::Zero, /*extra data=*/nullptr); } else if (auto refTy = mlir::dyn_cast_if_present<fir::ReferenceType>(Ty)) { auto elTy = refTy.getEleTy(); return convertPointerLikeType(elTy, fileAttr, scope, declOp, diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp new file mode 100644 index 0000000..bf125eb --- /dev/null +++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp @@ -0,0 +1,1061 @@ +//===-- FIRToMemRef.cpp - Convert FIR loads and stores to MemRef ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass lowers FIR dialect memory operations to the MemRef dialect. +// In particular it: +// +// - Rewrites `fir.alloca` to `memref.alloca`. +// +// - Rewrites `fir.load` / `fir.store` to `memref.load` / `memref.store`. +// +// - Allows FIR and MemRef to coexist by introducing `fir.convert` at +// memory-use sites. Memory operations (`memref.load`, `memref.store`, +// `memref.reinterpret_cast`, etc.) see MemRef-typed values, while the +// original FIR-typed values remain available for non-memory uses. For +// example: +// +// %fir_ref = ... : !fir.ref<!fir.array<...>> +// %memref = fir.convert %fir_ref +// : !fir.ref<!fir.array<...>> -> memref<...> +// %val = memref.load %memref[...] 
: memref<...> +// fir.call @callee(%fir_ref) : (!fir.ref<!fir.array<...>>) -> () +// +// Here the MemRef-typed value is used for `memref.load`, while the +// original FIR-typed value is preserved for `fir.call`. +// +// - Computes shapes, strides, and indices as needed for slices and shifts +// and emits `memref.reinterpret_cast` when dynamic layout is required +// (TODO: use memref.cast instead). +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/CUFCommon.h" +#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Dialect/Support/FIRContext.h" +#include "flang/Optimizer/Dialect/Support/KindMapping.h" +#include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" 
+#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "fir-to-memref" + +using namespace mlir; + +namespace fir { + +#define GEN_PASS_DEF_FIRTOMEMREF +#include "flang/Optimizer/Transforms/Passes.h.inc" + +static bool isMarshalLike(Operation *op) { + auto convert = dyn_cast_if_present<fir::ConvertOp>(op); + if (!convert) + return false; + + bool resIsMemRef = isa<MemRefType>(convert.getType()); + bool argIsMemRef = isa<MemRefType>(convert.getValue().getType()); + + assert(!(resIsMemRef && argIsMemRef) && + "unexpected fir.convert memref -> memref in isMarshalLike"); + + return resIsMemRef || argIsMemRef; +} + +using MemRefInfo = FailureOr<std::pair<Value, SmallVector<Value>>>; + +static llvm::cl::opt<bool> enableFIRConvertOptimizations( + "enable-fir-convert-opts", + llvm::cl::desc("enable emilinating redundant fir.convert in FIR-to-MemRef"), + llvm::cl::init(false), llvm::cl::Hidden); + +class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> { +public: + void runOnOperation() override; + +private: + llvm::SmallSetVector<Operation *, 32> eraseOps; + + DominanceInfo *domInfo = nullptr; + + void rewriteAlloca(fir::AllocaOp, PatternRewriter &, + FIRToMemRefTypeConverter &); + + void rewriteLoadOp(fir::LoadOp, PatternRewriter &, + FIRToMemRefTypeConverter &); + + void rewriteStoreOp(fir::StoreOp, PatternRewriter &, + FIRToMemRefTypeConverter &); + + MemRefInfo getMemRefInfo(Value, PatternRewriter &, FIRToMemRefTypeConverter &, + Operation *); + + MemRefInfo convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp, + PatternRewriter &, FIRToMemRefTypeConverter &); + + void replaceFIRMemrefs(Value, Value, PatternRewriter &) const; + + FailureOr<Value> getFIRConvert(Operation *memOp, Operation *memref, + PatternRewriter &, FIRToMemRefTypeConverter &); + + FailureOr<SmallVector<Value>> getMemrefIndices(fir::ArrayCoorOp, Operation *, + PatternRewriter &, Value, + Value) const; + + bool memrefIsOptional(Operation *) const; + + Value canonicalizeIndex(Value, 
PatternRewriter &) const; + + template <typename OpTy> + void getShapeFrom(OpTy op, SmallVector<Value> &shapeVec, + SmallVector<Value> &shiftVec, + SmallVector<Value> &sliceVec) const; + + void populateShapeAndShift(SmallVectorImpl<Value> &shapeVec, + SmallVectorImpl<Value> &shiftVec, + fir::ShapeShiftOp shift) const; + + void populateShift(SmallVectorImpl<Value> &vec, fir::ShiftOp shift) const; + + void populateShape(SmallVectorImpl<Value> &vec, fir::ShapeOp shape) const; + + unsigned getRankFromEmbox(fir::EmboxOp embox) const { + auto memrefType = embox.getMemref().getType(); + Type unwrappedType = fir::unwrapRefType(memrefType); + if (auto seqType = dyn_cast<fir::SequenceType>(unwrappedType)) + return seqType.getDimension(); + return 0; + } + + bool isCompilerGeneratedAlloca(Operation *op) const; + + void copyAttribute(Operation *from, Operation *to, + llvm::StringRef name) const; + + Type getBaseType(Type type, bool complexBaseTypes = false) const; + + bool memrefIsDeviceData(Operation *memref) const; + + mlir::Attribute findCudaDataAttr(Value val) const; +}; + +void FIRToMemRef::populateShapeAndShift(SmallVectorImpl<Value> &shapeVec, + SmallVectorImpl<Value> &shiftVec, + fir::ShapeShiftOp shift) const { + for (mlir::OperandRange::iterator i = shift.getPairs().begin(), + endIter = shift.getPairs().end(); + i != endIter;) { + shiftVec.push_back(*i++); + shapeVec.push_back(*i++); + } +} + +bool FIRToMemRef::isCompilerGeneratedAlloca(Operation *op) const { + if (!isa<fir::AllocaOp, memref::AllocaOp>(op)) + llvm_unreachable("expected alloca op"); + + return !op->getAttr("bindc_name") && !op->getAttr("uniq_name"); +} + +void FIRToMemRef::copyAttribute(Operation *from, Operation *to, + llvm::StringRef name) const { + if (Attribute value = from->getAttr(name)) + to->setAttr(name, value); +} + +Type FIRToMemRef::getBaseType(Type type, bool complexBaseTypes) const { + if (fir::isa_fir_type(type)) { + type = fir::getFortranElementType(type); + } else if (auto memrefTy = 
dyn_cast<MemRefType>(type)) { + type = memrefTy.getElementType(); + } + + if (!complexBaseTypes) + if (auto complexTy = dyn_cast<ComplexType>(type)) + type = complexTy.getElementType(); + return type; +} + +bool FIRToMemRef::memrefIsDeviceData(Operation *memref) const { + if (isa<ACC_DATA_ENTRY_OPS>(memref)) + return true; + + return cuf::hasDeviceDataAttr(memref); +} + +mlir::Attribute FIRToMemRef::findCudaDataAttr(Value val) const { + Value currentVal = val; + llvm::SmallPtrSet<Operation *, 8> visited; + + while (currentVal) { + Operation *defOp = currentVal.getDefiningOp(); + if (!defOp || !visited.insert(defOp).second) + break; + + if (cuf::DataAttributeAttr cudaAttr = cuf::getDataAttr(defOp)) + return cudaAttr; + + // TODO: This is a best-effort backward walk; it is easy to miss attributes + // as FIR evolves. Long term, it would be preferable if the necessary + // information was carried in the type system (or otherwise made available + // without relying on a walk-back through defining ops). 
    // Walk backwards through common address-forwarding ops until an op
    // carrying a CUDA data attribute (or nothing more to follow) is found.
    if (auto reboxOp = dyn_cast<fir::ReboxOp>(defOp)) {
      currentVal = reboxOp.getBox();
    } else if (auto convertOp = dyn_cast<fir::ConvertOp>(defOp)) {
      currentVal = convertOp->getOperand(0);
    } else if (auto emboxOp = dyn_cast<fir::EmboxOp>(defOp)) {
      currentVal = emboxOp.getMemref();
    } else if (auto boxAddrOp = dyn_cast<fir::BoxAddrOp>(defOp)) {
      currentVal = boxAddrOp.getVal();
    } else if (auto declareOp = dyn_cast<fir::DeclareOp>(defOp)) {
      currentVal = declareOp.getMemref();
    } else {
      break;
    }
  }
  return nullptr;
}

void FIRToMemRef::populateShift(SmallVectorImpl<Value> &vec,
                                fir::ShiftOp shift) const {
  vec.append(shift.getOrigins().begin(), shift.getOrigins().end());
}

void FIRToMemRef::populateShape(SmallVectorImpl<Value> &vec,
                                fir::ShapeOp shape) const {
  vec.append(shape.getExtents().begin(), shape.getExtents().end());
}

// Collect shape, shift, and slice-triple values from the shape/slice
// operands of an array_coor, rebox, or embox op into the given vectors.
template <typename OpTy>
void FIRToMemRef::getShapeFrom(OpTy op, SmallVector<Value> &shapeVec,
                               SmallVector<Value> &shiftVec,
                               SmallVector<Value> &sliceVec) const {
  if constexpr (std::is_same_v<OpTy, fir::ArrayCoorOp> ||
                std::is_same_v<OpTy, fir::ReboxOp> ||
                std::is_same_v<OpTy, fir::EmboxOp>) {
    Value shapeVal = op.getShape();

    if (shapeVal) {
      Operation *shapeValOp = shapeVal.getDefiningOp();

      if (auto shapeOp = dyn_cast<fir::ShapeOp>(shapeValOp)) {
        populateShape(shapeVec, shapeOp);
      } else if (auto shapeShiftOp = dyn_cast<fir::ShapeShiftOp>(shapeValOp)) {
        populateShapeAndShift(shapeVec, shiftVec, shapeShiftOp);
      } else if (auto shiftOp = dyn_cast<fir::ShiftOp>(shapeValOp)) {
        populateShift(shiftVec, shiftOp);
      }
    }

    Value sliceVal = op.getSlice();
    if (sliceVal) {
      if (auto sliceOp = sliceVal.getDefiningOp<fir::SliceOp>()) {
        auto triples = sliceOp.getTriples();
        sliceVec.append(triples.begin(), triples.end());
      }
    }
  }
}

// Rewrite a fir.alloca into memref.alloca, keeping the FIR-typed value
// available through a fir.convert for non-memory uses.
void FIRToMemRef::rewriteAlloca(fir::AllocaOp firAlloca,
                                PatternRewriter &rewriter,
                                FIRToMemRefTypeConverter &typeConverter) {
  if (!typeConverter.convertibleType(firAlloca.getInType()))
    return;

  if (typeConverter.isEmptyArray(firAlloca.getType()))
    return;

  rewriter.setInsertionPointAfter(firAlloca);

  Type type = firAlloca.getType();
  MemRefType memrefTy = typeConverter.convertMemrefType(type);

  Location loc = firAlloca.getLoc();

  // Dynamic extent operands are reversed to match the memref convention.
  SmallVector<Value> sizes = firAlloca.getOperands();
  std::reverse(sizes.begin(), sizes.end());

  auto alloca = memref::AllocaOp::create(rewriter, loc, memrefTy, sizes);
  // Carry naming and CUDA data attributes over to the new alloca.
  copyAttribute(firAlloca, alloca, firAlloca.getBindcNameAttrName());
  copyAttribute(firAlloca, alloca, firAlloca.getUniqNameAttrName());
  copyAttribute(firAlloca, alloca, cuf::getDataAttrName());

  auto convert = fir::ConvertOp::create(rewriter, loc, type, alloca);

  rewriter.replaceOp(firAlloca, convert);

  // Declares of compiler temporaries are redundant after the rewrite; route
  // their uses to the convert and schedule them for deletion.
  if (isCompilerGeneratedAlloca(alloca)) {
    for (Operation *userOp : convert->getUsers()) {
      if (auto declareOp = dyn_cast<fir::DeclareOp>(userOp)) {
        LLVM_DEBUG(llvm::dbgs()
                       << "FIRToMemRef: removing declare for compiler temp:\n";
                   declareOp->dump());
        declareOp->replaceAllUsesWith(convert);
        eraseOps.insert(userOp);
      }
    }
  }
}

// Conservatively detect whether the given memory reference may come from an
// OPTIONAL dummy argument (declare flag, fir.absent producer, or a
// fir.is_present user).
bool FIRToMemRef::memrefIsOptional(Operation *op) const {
  if (auto declare = dyn_cast<fir::DeclareOp>(op)) {
    if (fir::FortranVariableOpInterface(declare).isOptional())
      return true;

    Value operand = declare.getMemref();
    Operation *operandOp = operand.getDefiningOp();
    if (operandOp && isa<fir::AbsentOp>(operandOp))
      return true;
  }

  for (mlir::Value result : op->getResults())
    for (mlir::Operation *userOp : result.getUsers())
      if (isa<fir::IsPresentOp>(userOp))
        return true;

  // TODO: If `op` is not a `fir.declare`, OPTIONAL information may still be
  // present on a related `fir.declare` reached by tracing the address/box
  // through common forwarding ops (e.g. `fir.convert`, `fir.rebox`,
  // `fir.embox`, `fir.box_addr`), then checking `declare.isOptional()`. Add the
  // search after FIR improves on it.
  return false;
}

// Cast an integer-typed value to index type, if it is not already an index.
static Value castTypeToIndexType(Value originalValue,
                                 PatternRewriter &rewriter) {
  if (originalValue.getType().isIndex())
    return originalValue;

  Type indexType = rewriter.getIndexType();
  return arith::IndexCastOp::create(rewriter, originalValue.getLoc(), indexType,
                                    originalValue);
}

// Compute the memref.load/store indices corresponding to an array_coor,
// accounting for shifts (lower bounds) and slice triples.
FailureOr<SmallVector<Value>>
FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref,
                              PatternRewriter &rewriter, Value converted,
                              Value one) const {
  IndexType indexTy = rewriter.getIndexType();
  SmallVector<Value> indices;
  Location loc = arrayCoorOp->getLoc();
  SmallVector<Value> shiftVec, shapeVec, sliceVec;
  int rank = arrayCoorOp.getIndices().size();
  getShapeFrom<fir::ArrayCoorOp>(arrayCoorOp, shapeVec, shiftVec, sliceVec);

  // When the base comes from an embox, its shape/slice take precedence and
  // determine the rank.
  if (auto embox = dyn_cast_or_null<fir::EmboxOp>(memref)) {
    getShapeFrom<fir::EmboxOp>(embox, shapeVec, shiftVec, sliceVec);
    rank = getRankFromEmbox(embox);
  }

  // Slice triples are (lb, ub, stride); only lb and stride are needed here.
  SmallVector<Value> sliceLbs, sliceStrides;
  for (size_t i = 0; i < sliceVec.size(); i += 3) {
    sliceLbs.push_back(castTypeToIndexType(sliceVec[i], rewriter));
    sliceStrides.push_back(castTypeToIndexType(sliceVec[i + 2], rewriter));
  }

  const bool isShifted = !shiftVec.empty();
  const bool isSliced = !sliceVec.empty();

  ValueRange idxs = arrayCoorOp.getIndices();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

  // A dimension whose slice stride is fir.undef has no explicit index
  // operand; its position is filled directly from the slice lower bound.
  SmallVector<bool> filledPositions(rank, false);
  for (int i = 0; i < rank; ++i) {
    Value step = isSliced ? sliceStrides[i] : one;
    Operation *stepOp = step.getDefiningOp();
    if (stepOp && mlir::isa_and_nonnull<fir::UndefOp>(stepOp)) {
      Value shift = isShifted ? shiftVec[i] : one;
      Value sliceLb = isSliced ?
sliceLbs[i] : shift; + Value offset = arith::SubIOp::create(rewriter, loc, sliceLb, shift); + indices.push_back(offset); + filledPositions[i] = true; + } else { + indices.push_back(zero); + } + } + + int arrayCoorIdx = 0; + for (int i = 0; i < rank; ++i) { + if (filledPositions[i]) + continue; + + assert((unsigned int)arrayCoorIdx < idxs.size() && + "empty dimension should be eliminated\n"); + Value index = canonicalizeIndex(idxs[arrayCoorIdx], rewriter); + Type cTy = index.getType(); + if (!llvm::isa<IndexType>(cTy)) { + assert(cTy.isSignlessInteger() && "expected signless integer type"); + index = arith::IndexCastOp::create(rewriter, loc, indexTy, index); + } + + Value shift = isShifted ? shiftVec[i] : one; + Value stride = isSliced ? sliceStrides[i] : one; + Value sliceLb = isSliced ? sliceLbs[i] : shift; + + Value oneIdx = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value indexAdjustment = isSliced ? oneIdx : sliceLb; + Value delta = arith::SubIOp::create(rewriter, loc, index, indexAdjustment); + + Value scaled = arith::MulIOp::create(rewriter, loc, delta, stride); + + Value offset = arith::SubIOp::create(rewriter, loc, sliceLb, shift); + + Value finalIndex = arith::AddIOp::create(rewriter, loc, scaled, offset); + + indices[i] = finalIndex; + arrayCoorIdx++; + } + + std::reverse(indices.begin(), indices.end()); + + return indices; +} + +MemRefInfo +FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp, + PatternRewriter &rewriter, + FIRToMemRefTypeConverter &typeConverter) { + IndexType indexTy = rewriter.getIndexType(); + Value firMemref = arrayCoorOp.getMemref(); + if (!typeConverter.convertibleMemrefType(firMemref.getType())) + return failure(); + + if (typeConverter.isEmptyArray(firMemref.getType())) + return failure(); + + if (auto blockArg = dyn_cast<BlockArgument>(firMemref)) { + Value elemRef = arrayCoorOp.getResult(); + rewriter.setInsertionPointAfter(arrayCoorOp); + Location loc = arrayCoorOp->getLoc(); + Type 
elemMemrefTy = typeConverter.convertMemrefType(elemRef.getType()); + Value converted = + fir::ConvertOp::create(rewriter, loc, elemMemrefTy, elemRef); + SmallVector<Value> indices; + return std::pair{converted, indices}; + } + + Operation *memref = firMemref.getDefiningOp(); + + FailureOr<Value> converted; + if (enableFIRConvertOptimizations && isMarshalLike(memref) && + !fir::isa_fir_type(firMemref.getType())) { + converted = firMemref; + rewriter.setInsertionPoint(arrayCoorOp); + } else { + Operation *arrayCoorOperation = arrayCoorOp.getOperation(); + rewriter.setInsertionPoint(arrayCoorOp); + if (memrefIsOptional(memref)) { + auto ifOp = arrayCoorOperation->getParentOfType<scf::IfOp>(); + if (ifOp) { + Operation *condition = ifOp.getCondition().getDefiningOp(); + if (condition && isa<fir::IsPresentOp>(condition)) + if (condition->getOperand(0) == firMemref) { + if (arrayCoorOperation->getParentRegion() == &ifOp.getThenRegion()) + rewriter.setInsertionPointToStart( + &(ifOp.getThenRegion().front())); + else if (arrayCoorOperation->getParentRegion() == + &ifOp.getElseRegion()) + rewriter.setInsertionPointToStart( + &(ifOp.getElseRegion().front())); + } + } + } + + converted = getFIRConvert(memOp, memref, rewriter, typeConverter); + if (failed(converted)) + return failure(); + + rewriter.setInsertionPointAfter(arrayCoorOp); + } + + Location loc = arrayCoorOp->getLoc(); + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); + FailureOr<SmallVector<Value>> failureOrIndices = + getMemrefIndices(arrayCoorOp, memref, rewriter, *converted, one); + if (failed(failureOrIndices)) + return failure(); + SmallVector<Value> indices = *failureOrIndices; + + if (converted == firMemref) + return std::pair{*converted, indices}; + + Value convertedVal = *converted; + MemRefType memRefTy = dyn_cast<MemRefType>(convertedVal.getType()); + + bool isRebox = firMemref.getDefiningOp<fir::ReboxOp>() != nullptr; + + if (memRefTy.hasStaticShape() && !isRebox) + return 
std::pair{*converted, indices}; + + unsigned rank = arrayCoorOp.getIndices().size(); + + if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>()) + rank = getRankFromEmbox(embox); + + SmallVector<Value> sizes; + sizes.reserve(rank); + SmallVector<Value> strides; + strides.reserve(rank); + + SmallVector<Value> shapeVec, shiftVec, sliceVec; + getShapeFrom<fir::ArrayCoorOp>(arrayCoorOp, shapeVec, shiftVec, sliceVec); + + Value box = firMemref; + if (!isa<BlockArgument>(firMemref)) { + if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>()) + getShapeFrom<fir::EmboxOp>(embox, shapeVec, shiftVec, sliceVec); + else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>()) + getShapeFrom<fir::ReboxOp>(rebox, shapeVec, shiftVec, sliceVec); + } + + if (shapeVec.empty()) { + auto boxElementSize = + fir::BoxEleSizeOp::create(rewriter, loc, indexTy, box); + + for (unsigned i = 0; i < rank; ++i) { + Value dim = arith::ConstantIndexOp::create(rewriter, loc, rank - i - 1); + auto boxDims = fir::BoxDimsOp::create(rewriter, loc, indexTy, indexTy, + indexTy, box, dim); + + Value extent = boxDims->getResult(1); + sizes.push_back(castTypeToIndexType(extent, rewriter)); + + Value byteStride = boxDims->getResult(2); + Value div = + arith::DivSIOp::create(rewriter, loc, byteStride, boxElementSize); + strides.push_back(castTypeToIndexType(div, rewriter)); + } + + } else { + Value oneIdx = + arith::ConstantIndexOp::create(rewriter, arrayCoorOp->getLoc(), 1); + for (unsigned i = rank - 1; i > 0; --i) { + Value size = shapeVec[i]; + sizes.push_back(castTypeToIndexType(size, rewriter)); + + Value stride = shapeVec[0]; + for (unsigned j = 1; j <= i - 1; ++j) + stride = arith::MulIOp::create(rewriter, loc, shapeVec[j], stride); + strides.push_back(castTypeToIndexType(stride, rewriter)); + } + + sizes.push_back(castTypeToIndexType(shapeVec[0], rewriter)); + strides.push_back(oneIdx); + } + + assert(strides.size() == sizes.size() && sizes.size() == rank); + + int64_t dynamicOffset = 
ShapedType::kDynamic; + SmallVector<int64_t> dynamicStrides(rank, ShapedType::kDynamic); + auto stridedLayout = StridedLayoutAttr::get(convertedVal.getContext(), + dynamicOffset, dynamicStrides); + + SmallVector<int64_t> dynamicShape(rank, ShapedType::kDynamic); + memRefTy = + MemRefType::get(dynamicShape, memRefTy.getElementType(), stridedLayout); + + Value offset = arith::ConstantIndexOp::create(rewriter, loc, 0); + + auto reinterpret = memref::ReinterpretCastOp::create( + rewriter, loc, memRefTy, *converted, offset, sizes, strides); + + Value result = reinterpret->getResult(0); + return std::pair{result, indices}; +} + +FailureOr<Value> +FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op, + PatternRewriter &rewriter, + FIRToMemRefTypeConverter &typeConverter) { + if (enableFIRConvertOptimizations && !op->hasOneUse() && + !memrefIsOptional(op)) { + for (Operation *userOp : op->getUsers()) { + if (auto convertOp = dyn_cast<fir::ConvertOp>(userOp)) { + Value converted = convertOp.getResult(); + if (!isa<MemRefType>(converted.getType())) + continue; + + if (userOp->getParentOp() == memOp->getParentOp() && + domInfo->dominates(userOp, memOp)) + return converted; + } + } + } + + assert(op->getNumResults() == 1 && "expecting one result"); + + Value basePtr = op->getResult(0); + + MemRefType memrefTy = typeConverter.convertMemrefType(basePtr.getType()); + Type baseTy = memrefTy.getElementType(); + + if (fir::isa_std_type(baseTy) && memrefTy.getRank() == 0) { + if (auto convertOp = basePtr.getDefiningOp<fir::ConvertOp>()) { + Value input = convertOp.getOperand(); + if (auto alloca = input.getDefiningOp<memref::AllocaOp>()) { + assert(alloca.getType() == memrefTy && "expected same types"); + if (isCompilerGeneratedAlloca(alloca)) + return alloca.getResult(); + } + } + } + + const Location loc = op->getLoc(); + + if (isa<fir::BoxType>(basePtr.getType())) { + Operation *baseOp = basePtr.getDefiningOp(); + auto boxAddrOp = fir::BoxAddrOp::create(rewriter, loc, 
basePtr); + + if (auto cudaAttr = findCudaDataAttr(basePtr)) + boxAddrOp->setAttr(cuf::getDataAttrName(), cudaAttr); + + basePtr = boxAddrOp; + memrefTy = typeConverter.convertMemrefType(basePtr.getType()); + + if (baseOp) { + auto sameBaseBoxTypes = [&](Type baseType, Type memrefType) -> bool { + Type emboxBaseTy = getBaseType(baseType, true); + Type emboxMemrefTy = getBaseType(memrefType, true); + return emboxBaseTy == emboxMemrefTy; + }; + + if (auto embox = dyn_cast_or_null<fir::EmboxOp>(baseOp)) { + if (!sameBaseBoxTypes(embox.getType(), embox.getMemref().getType())) { + LLVM_DEBUG(llvm::dbgs() + << "FIRToMemRef: embox base type and memref type are not " + "the same, bailing out of conversion\n"); + return failure(); + } + if (embox.getSlice() && + embox.getSlice().getDefiningOp<fir::SliceOp>()) { + Type originalType = embox.getMemref().getType(); + basePtr = embox.getMemref(); + + if (typeConverter.convertibleMemrefType(originalType)) { + auto convertedMemrefTy = + typeConverter.convertMemrefType(originalType); + memrefTy = convertedMemrefTy; + } else { + return failure(); + } + } + } + + if (auto rebox = dyn_cast<fir::ReboxOp>(baseOp)) { + if (!sameBaseBoxTypes(rebox.getType(), rebox.getBox().getType())) { + LLVM_DEBUG(llvm::dbgs() + << "FIRToMemRef: rebox base type and box type are not the " + "same, bailing out of conversion\n"); + return failure(); + } + Type originalType = rebox.getBox().getType(); + if (auto boxTy = dyn_cast<fir::BoxType>(originalType)) + originalType = boxTy.getElementType(); + if (!typeConverter.convertibleMemrefType(originalType)) { + return failure(); + } else { + auto convertedMemrefTy = + typeConverter.convertMemrefType(originalType); + memrefTy = convertedMemrefTy; + } + } + } + } + + auto convert = fir::ConvertOp::create(rewriter, loc, memrefTy, basePtr); + return convert->getResult(0); +} + +Value FIRToMemRef::canonicalizeIndex(Value index, + PatternRewriter &rewriter) const { + if (auto blockArg = 
dyn_cast<BlockArgument>(index)) + return index; + + Operation *op = index.getDefiningOp(); + + if (auto constant = dyn_cast<arith::ConstantIntOp>(op)) { + if (!constant.getType().isIndex()) { + Value v = arith::ConstantIndexOp::create(rewriter, op->getLoc(), + constant.value()); + return v; + } + return constant; + } + + if (auto extsi = dyn_cast<arith::ExtSIOp>(op)) { + Value operand = extsi.getOperand(); + if (auto indexCast = operand.getDefiningOp<arith::IndexCastOp>()) { + Value v = indexCast.getOperand(); + return v; + } + return canonicalizeIndex(operand, rewriter); + } + + if (auto add = dyn_cast<arith::AddIOp>(op)) { + Value lhs = canonicalizeIndex(add.getLhs(), rewriter); + Value rhs = canonicalizeIndex(add.getRhs(), rewriter); + if (lhs.getType() == rhs.getType()) + return arith::AddIOp::create(rewriter, op->getLoc(), lhs, rhs); + } + return index; +} + +MemRefInfo FIRToMemRef::getMemRefInfo(Value firMemref, + PatternRewriter &rewriter, + FIRToMemRefTypeConverter &typeConverter, + Operation *memOp) { + Operation *memrefOp = firMemref.getDefiningOp(); + if (!memrefOp) { + if (auto blockArg = dyn_cast<BlockArgument>(firMemref)) { + rewriter.setInsertionPoint(memOp); + Type memrefTy = typeConverter.convertMemrefType(blockArg.getType()); + if (auto mt = dyn_cast<MemRefType>(memrefTy)) + if (auto inner = llvm::dyn_cast<MemRefType>(mt.getElementType())) + memrefTy = inner; + Value converted = fir::ConvertOp::create(rewriter, blockArg.getLoc(), + memrefTy, blockArg); + SmallVector<Value> indices; + return std::pair{converted, indices}; + } + llvm_unreachable( + "FIRToMemRef: expected defining op or block argument for FIR memref"); + } + + if (auto arrayCoorOp = dyn_cast<fir::ArrayCoorOp>(memrefOp)) { + MemRefInfo memrefInfo = + convertArrayCoorOp(memOp, arrayCoorOp, rewriter, typeConverter); + if (succeeded(memrefInfo)) { + for (auto user : memrefOp->getUsers()) { + if (!isa<fir::LoadOp, fir::StoreOp>(user)) { + LLVM_DEBUG( + llvm::dbgs() + << "FIRToMemRef: 
array memref used by unsupported op:\n"; + firMemref.dump(); user->dump()); + return memrefInfo; + } + } + eraseOps.insert(memrefOp); + } + return memrefInfo; + } + + rewriter.setInsertionPoint(memOp); + + if (isMarshalLike(memrefOp)) { + FailureOr<Value> converted = + getFIRConvert(memOp, memrefOp, rewriter, typeConverter); + if (failed(converted)) { + LLVM_DEBUG(llvm::dbgs() + << "FIRToMemRef: expected FIR memref in convert, bailing " + "out:\n"; + firMemref.dump()); + return failure(); + } + SmallVector<Value> indices; + return std::pair{*converted, indices}; + } + + if (auto declareOp = dyn_cast<fir::DeclareOp>(memrefOp)) { + FailureOr<Value> converted = + getFIRConvert(memOp, declareOp, rewriter, typeConverter); + if (failed(converted)) { + LLVM_DEBUG(llvm::dbgs() + << "FIRToMemRef: unable to create convert for scalar " + "memref:\n"; + firMemref.dump()); + return failure(); + } + SmallVector<Value> indices; + return std::pair{*converted, indices}; + } + + if (auto coordinateOp = dyn_cast<fir::CoordinateOp>(memrefOp)) { + FailureOr<Value> converted = + getFIRConvert(memOp, coordinateOp, rewriter, typeConverter); + if (failed(converted)) { + LLVM_DEBUG( + llvm::dbgs() + << "FIRToMemRef: unable to create convert for derived-type " + "memref:\n"; + firMemref.dump()); + return failure(); + } + SmallVector<Value> indices; + return std::pair{*converted, indices}; + } + + if (auto convertOp = dyn_cast<fir::ConvertOp>(memrefOp)) { + Type fromTy = convertOp->getOperand(0).getType(); + Type toTy = firMemref.getType(); + if (isa<fir::ReferenceType>(fromTy) && isa<fir::ReferenceType>(toTy)) { + FailureOr<Value> converted = + getFIRConvert(memOp, convertOp, rewriter, typeConverter); + if (failed(converted)) { + LLVM_DEBUG( + llvm::dbgs() + << "FIRToMemRef: unable to create convert for conversion " + "op:\n"; + firMemref.dump()); + return failure(); + } + SmallVector<Value> indices; + return std::pair{*converted, indices}; + } + } + + if (auto boxAddrOp = 
dyn_cast<fir::BoxAddrOp>(memrefOp)) { + FailureOr<Value> converted = + getFIRConvert(memOp, boxAddrOp, rewriter, typeConverter); + if (failed(converted)) { + LLVM_DEBUG(llvm::dbgs() + << "FIRToMemRef: unable to create convert for box_addr " + "op:\n"; + firMemref.dump()); + return failure(); + } + SmallVector<Value> indices; + return std::pair{*converted, indices}; + } + + if (memrefIsDeviceData(memrefOp)) { + FailureOr<Value> converted = + getFIRConvert(memOp, memrefOp, rewriter, typeConverter); + if (failed(converted)) + return failure(); + SmallVector<Value> indices; + return std::pair{*converted, indices}; + } + + LLVM_DEBUG(llvm::dbgs() + << "FIRToMemRef: unable to create convert for memref value:\n"; + firMemref.dump()); + + return failure(); +} + +void FIRToMemRef::replaceFIRMemrefs(Value firMemref, Value converted, + PatternRewriter &rewriter) const { + Operation *op = firMemref.getDefiningOp(); + if (op && (isa<fir::ArrayCoorOp>(op) || isMarshalLike(op))) + return; + + SmallPtrSet<Operation *, 4> worklist; + for (auto user : firMemref.getUsers()) { + if (isMarshalLike(user) || isa<fir::LoadOp, fir::StoreOp>(user)) + continue; + if (!domInfo->dominates(converted, user)) + continue; + if (!(isa<omp::AtomicCaptureOp>(user->getParentOp()) || + isa<acc::AtomicCaptureOp>(user->getParentOp()))) + worklist.insert(user); + } + + Type ty = firMemref.getType(); + + for (auto op : worklist) { + rewriter.setInsertionPoint(op); + Location loc = op->getLoc(); + Value replaceConvert = fir::ConvertOp::create(rewriter, loc, ty, converted); + op->replaceUsesOfWith(firMemref, replaceConvert); + } + + worklist.clear(); + + for (auto user : firMemref.getUsers()) { + if (isMarshalLike(user) || isa<fir::LoadOp, fir::StoreOp>(user)) + continue; + if (isa<omp::AtomicCaptureOp>(user->getParentOp()) || + isa<acc::AtomicCaptureOp>(user->getParentOp())) + if (domInfo->dominates(converted, user)) + worklist.insert(user); + } + + if (worklist.empty()) + return; + + while 
(!worklist.empty()) { + Operation *parentOp = (*worklist.begin())->getParentOp(); + + Value replaceConvert; + SmallVector<Operation *> erase; + for (auto op : worklist) { + if (op->getParentOp() != parentOp) + continue; + if (!replaceConvert) { + rewriter.setInsertionPoint(parentOp); + replaceConvert = + fir::ConvertOp::create(rewriter, op->getLoc(), ty, converted); + } + op->replaceUsesOfWith(firMemref, replaceConvert); + erase.push_back(op); + } + + for (auto op : erase) + worklist.erase(op); + } +} + +void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter, + FIRToMemRefTypeConverter &typeConverter) { + Value firMemref = load.getMemref(); + if (!typeConverter.convertibleType(firMemref.getType())) + return; + + LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: attempting to convert FIR load:\n"; + load.dump(); firMemref.dump()); + + MemRefInfo memrefInfo = + getMemRefInfo(firMemref, rewriter, typeConverter, load.getOperation()); + if (failed(memrefInfo)) + return; + + Type originalType = load.getResult().getType(); + Value converted = memrefInfo->first; + SmallVector<Value> indices = memrefInfo->second; + + LLVM_DEBUG(llvm::dbgs() + << "FIRToMemRef: convert for FIR load created successfully:\n"; + converted.dump()); + + rewriter.setInsertionPointAfter(load); + + Attribute attr = (load.getOperation())->getAttr("tbaa"); + memref::LoadOp loadOp = + rewriter.replaceOpWithNewOp<memref::LoadOp>(load, converted, indices); + if (attr) + loadOp.getOperation()->setAttr("tbaa", attr); + + LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.load op:\n"; + loadOp.dump(); assert(succeeded(verify(loadOp)))); + + if (isa<fir::LogicalType>(originalType)) { + Value logicalVal = + fir::ConvertOp::create(rewriter, loadOp.getLoc(), originalType, loadOp); + loadOp.getResult().replaceAllUsesExcept(logicalVal, + logicalVal.getDefiningOp()); + } + + if (!isa<fir::LogicalType>(originalType)) + replaceFIRMemrefs(firMemref, converted, rewriter); +} + +void 
FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter, + FIRToMemRefTypeConverter &typeConverter) { + Value firMemref = store.getMemref(); + + if (!typeConverter.convertibleType(firMemref.getType())) + return; + + LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: attempting to convert FIR store:\n"; + store.dump(); firMemref.dump()); + + MemRefInfo memrefInfo = + getMemRefInfo(firMemref, rewriter, typeConverter, store.getOperation()); + if (failed(memrefInfo)) + return; + + Value converted = memrefInfo->first; + SmallVector<Value> indices = memrefInfo->second; + LLVM_DEBUG( + llvm::dbgs() + << "FIRToMemRef: convert for FIR store created successfully:\n"; + converted.dump()); + + Value value = store.getValue(); + rewriter.setInsertionPointAfter(store); + + if (isa<fir::LogicalType>(value.getType())) { + Type convertedType = typeConverter.convertType(value.getType()); + value = + fir::ConvertOp::create(rewriter, store.getLoc(), convertedType, value); + } + + Attribute attr = (store.getOperation())->getAttr("tbaa"); + memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>( + store, value, converted, indices); + if (attr) + storeOp.getOperation()->setAttr("tbaa", attr); + + LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.store op:\n"; + storeOp.dump(); assert(succeeded(verify(storeOp)))); + + bool isLogicalRef = false; + if (fir::ReferenceType refTy = + llvm::dyn_cast<fir::ReferenceType>(firMemref.getType())) + isLogicalRef = llvm::isa<fir::LogicalType>(refTy.getEleTy()); + if (!isLogicalRef) + replaceFIRMemrefs(firMemref, converted, rewriter); +} + +void FIRToMemRef::runOnOperation() { + LLVM_DEBUG(llvm::dbgs() << "Enter FIRToMemRef()\n"); + + func::FuncOp op = getOperation(); + MLIRContext *context = op.getContext(); + ModuleOp mod = op->getParentOfType<ModuleOp>(); + FIRToMemRefTypeConverter typeConverter(mod); + + typeConverter.setConvertComplexTypes(true); + + PatternRewriter rewriter(context); + domInfo = new DominanceInfo(op); + + 
op.walk([&](fir::AllocaOp alloca) { + rewriteAlloca(alloca, rewriter, typeConverter); + }); + + op.walk([&](Operation *op) { + if (fir::LoadOp loadOp = dyn_cast<fir::LoadOp>(op)) + rewriteLoadOp(loadOp, rewriter, typeConverter); + else if (fir::StoreOp storeOp = dyn_cast<fir::StoreOp>(op)) + rewriteStoreOp(storeOp, rewriter, typeConverter); + }); + + for (auto eraseOp : eraseOps) + rewriter.eraseOp(eraseOp); + eraseOps.clear(); + + if (domInfo) + delete domInfo; + + LLVM_DEBUG(llvm::dbgs() << "After FIRToMemRef()\n"; op.dump(); + llvm::dbgs() << "Exit FIRToMemRef()\n";); +} + +} // namespace fir diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp index 70d6ebb..04ba053 100644 --- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp +++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp @@ -18,6 +18,8 @@ namespace fir { namespace { class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> { + using FIRToSCFPassBase::FIRToSCFPassBase; + public: void runOnOperation() override; }; @@ -25,11 +27,18 @@ public: struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern; + DoLoopConversion(mlir::MLIRContext *context, + bool parallelUnorderedLoop = false, + mlir::PatternBenefit benefit = 1) + : OpRewritePattern<fir::DoLoopOp>(context, benefit), + parallelUnorderedLoop(parallelUnorderedLoop) {} + mlir::LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp, mlir::PatternRewriter &rewriter) const override { mlir::Location loc = doLoopOp.getLoc(); bool hasFinalValue = doLoopOp.getFinalValue().has_value(); + bool isUnordered = doLoopOp.getUnordered().has_value(); // Get loop values from the DoLoopOp mlir::Value low = doLoopOp.getLowerBound(); @@ -53,39 +62,78 @@ struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { mlir::arith::DivSIOp::create(rewriter, loc, distance, step); auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0); 
auto one = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1); - auto scfForOp = - mlir::scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs); + // Create the scf.for or scf.parallel operation + mlir::Operation *scfLoopOp = nullptr; + if (isUnordered && parallelUnorderedLoop) { + scfLoopOp = mlir::scf::ParallelOp::create(rewriter, loc, {zero}, + {tripCount}, {one}, iterArgs); + } else { + scfLoopOp = mlir::scf::ForOp::create(rewriter, loc, zero, tripCount, one, + iterArgs); + } + + // Move the body of the fir.do_loop to the scf.for or scf.parallel auto &loopOps = doLoopOp.getBody()->getOperations(); auto resultOp = mlir::cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator()); auto results = resultOp.getOperands(); - mlir::Block *loweredBody = scfForOp.getBody(); + auto scfLoopLikeOp = mlir::cast<mlir::LoopLikeOpInterface>(scfLoopOp); + mlir::Block &scfLoopBody = scfLoopLikeOp.getLoopRegions().front()->front(); - loweredBody->getOperations().splice(loweredBody->begin(), loopOps, - loopOps.begin(), - std::prev(loopOps.end())); + scfLoopBody.getOperations().splice(scfLoopBody.begin(), loopOps, + loopOps.begin(), + std::prev(loopOps.end())); - rewriter.setInsertionPointToStart(loweredBody); + rewriter.setInsertionPointToStart(&scfLoopBody); mlir::Value iv = mlir::arith::MulIOp::create( - rewriter, loc, scfForOp.getInductionVar(), step); + rewriter, loc, scfLoopLikeOp.getSingleInductionVar().value(), step); iv = mlir::arith::AddIOp::create(rewriter, loc, low, iv); + mlir::Value firIV = doLoopOp.getInductionVar(); + firIV.replaceAllUsesWith(iv); + + mlir::Value finalValue; + if (hasFinalValue) { + // Prefer re-using an existing `arith.addi` in the moved loop body if it + // already computes the next `iv + step`. 
+ if (!results.empty()) { + if (auto addOp = results.front().getDefiningOp<mlir::arith::AddIOp>()) { + mlir::Value lhs = addOp.getLhs(); + mlir::Value rhs = addOp.getRhs(); + if ((lhs == iv && rhs == step) || (lhs == step && rhs == iv)) + finalValue = results.front(); + } + } + if (!finalValue) + finalValue = mlir::arith::AddIOp::create(rewriter, loc, iv, step); + } - if (!results.empty()) { - rewriter.setInsertionPointToEnd(loweredBody); - mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), results); + if (hasFinalValue || !results.empty()) { + rewriter.setInsertionPointToEnd(&scfLoopBody); + llvm::SmallVector<mlir::Value> yieldOperands; + if (hasFinalValue) { + yieldOperands.push_back(finalValue); + llvm::append_range(yieldOperands, results.drop_front()); + } else { + llvm::append_range(yieldOperands, results); + } + mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), yieldOperands); } - doLoopOp.getInductionVar().replaceAllUsesWith(iv); - rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(), - hasFinalValue - ? scfForOp.getRegionIterArgs().drop_front() - : scfForOp.getRegionIterArgs()); - - // Copy all the attributes from the old to new op. - scfForOp->setAttrs(doLoopOp->getAttrs()); - rewriter.replaceOp(doLoopOp, scfForOp); + rewriter.replaceAllUsesWith( + doLoopOp.getRegionIterArgs(), + hasFinalValue ? scfLoopLikeOp.getRegionIterArgs().drop_front() + : scfLoopLikeOp.getRegionIterArgs()); + + // Copy loop annotations from the fir.do_loop to scf loop op. 
+ if (auto ann = doLoopOp.getLoopAnnotation()) + scfLoopOp->setAttr("loop_annotation", *ann); + + rewriter.replaceOp(doLoopOp, scfLoopOp); return mlir::success(); } + +private: + bool parallelUnorderedLoop; }; struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> { @@ -102,6 +150,7 @@ struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> { mlir::Value okInit = iterWhileOp.getIterateIn(); mlir::ValueRange iterArgs = iterWhileOp.getInitArgs(); + bool hasFinalValue = iterWhileOp.getFinalValue().has_value(); mlir::SmallVector<mlir::Value> initVals; initVals.push_back(lowerBound); @@ -128,10 +177,23 @@ struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> { rewriter.setInsertionPointToStart(&beforeBlock); - mlir::Value inductionCmp = mlir::arith::CmpIOp::create( + // The comparison depends on the sign of the step value. We fully expect + // this expression to be folded by the optimizer or LLVM. This expression + // is written this way so that `step == 0` always returns `false`. 
+ auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0); + auto compl0 = mlir::arith::CmpIOp::create( + rewriter, loc, mlir::arith::CmpIPredicate::slt, zero, step); + auto compl1 = mlir::arith::CmpIOp::create( rewriter, loc, mlir::arith::CmpIPredicate::sle, ivInBefore, upperBound); - mlir::Value cond = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp, - earlyExitInBefore); + auto compl2 = mlir::arith::CmpIOp::create( + rewriter, loc, mlir::arith::CmpIPredicate::slt, step, zero); + auto compl3 = mlir::arith::CmpIOp::create( + rewriter, loc, mlir::arith::CmpIPredicate::sge, ivInBefore, upperBound); + auto cmp0 = mlir::arith::AndIOp::create(rewriter, loc, compl0, compl1); + auto cmp1 = mlir::arith::AndIOp::create(rewriter, loc, compl2, compl3); + auto cmp2 = mlir::arith::OrIOp::create(rewriter, loc, cmp0, cmp1); + mlir::Value cond = + mlir::arith::AndIOp::create(rewriter, loc, earlyExitInBefore, cmp2); mlir::scf::ConditionOp::create(rewriter, loc, cond, argsInBefore); @@ -140,17 +202,22 @@ struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> { auto *afterBody = scfWhileOp.getAfterBody(); auto resultOp = mlir::cast<fir::ResultOp>(afterBody->getTerminator()); - mlir::SmallVector<mlir::Value> results(resultOp->getOperands()); - mlir::Value ivInAfter = scfWhileOp.getAfterArguments()[0]; + mlir::SmallVector<mlir::Value> results; + mlir::Value iv = scfWhileOp.getAfterArguments()[0]; rewriter.setInsertionPointToStart(afterBody); - results[0] = mlir::arith::AddIOp::create(rewriter, loc, ivInAfter, step); + results.push_back(mlir::arith::AddIOp::create(rewriter, loc, iv, step)); + llvm::append_range(results, hasFinalValue + ? 
resultOp->getOperands().drop_front() + : resultOp->getOperands()); rewriter.setInsertionPointToEnd(afterBody); rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(resultOp, results); scfWhileOp->setAttrs(iterWhileOp->getAttrs()); - rewriter.replaceOp(iterWhileOp, scfWhileOp); + rewriter.replaceOp(iterWhileOp, + hasFinalValue ? scfWhileOp->getResults() + : scfWhileOp->getResults().drop_front()); return mlir::success(); } }; @@ -197,13 +264,14 @@ struct IfConversion : public mlir::OpRewritePattern<fir::IfOp> { }; } // namespace +void fir::populateFIRToSCFRewrites(mlir::RewritePatternSet &patterns, + bool parallelUnordered) { + patterns.add<IterWhileConversion, IfConversion>(patterns.getContext()); + patterns.add<DoLoopConversion>(patterns.getContext(), parallelUnordered); +} + void FIRToSCFPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); - patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>( - patterns.getContext()); + fir::populateFIRToSCFRewrites(patterns, parallelUnordered); walkAndApplyPatterns(getOperation(), std::move(patterns)); } - -std::unique_ptr<mlir::Pass> fir::createFIRToSCFPass() { - return std::make_unique<FIRToSCFPass>(); -} diff --git a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp index 9dfe26cb..3879a80 100644 --- a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp +++ b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp @@ -87,10 +87,6 @@ void FunctionAttrPass::runOnOperation() { func->setAttr(mlir::LLVM::LLVMFuncOp::getInstrumentFunctionExitAttrName( llvmFuncOpName), mlir::StringAttr::get(context, instrumentFunctionExit)); - if (noInfsFPMath) - func->setAttr( - mlir::LLVM::LLVMFuncOp::getNoInfsFpMathAttrName(llvmFuncOpName), - mlir::BoolAttr::get(context, true)); if (noNaNsFPMath) func->setAttr( mlir::LLVM::LLVMFuncOp::getNoNansFpMathAttrName(llvmFuncOpName), @@ -99,10 +95,6 @@ void FunctionAttrPass::runOnOperation() { func->setAttr( 
mlir::LLVM::LLVMFuncOp::getNoSignedZerosFpMathAttrName(llvmFuncOpName), mlir::BoolAttr::get(context, true)); - if (unsafeFPMath) - func->setAttr( - mlir::LLVM::LLVMFuncOp::getUnsafeFpMathAttrName(llvmFuncOpName), - mlir::BoolAttr::get(context, true)); if (!reciprocals.empty()) func->setAttr( mlir::LLVM::LLVMFuncOp::getReciprocalEstimatesAttrName(llvmFuncOpName), diff --git a/flang/lib/Optimizer/Transforms/LoopInvariantCodeMotion.cpp b/flang/lib/Optimizer/Transforms/LoopInvariantCodeMotion.cpp new file mode 100644 index 0000000..8ebb898 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/LoopInvariantCodeMotion.cpp @@ -0,0 +1,323 @@ +//===- LoopInvariantCodeMotion.cpp ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// FIR-specific Loop Invariant Code Motion pass. +/// The pass relies on FIR types and interfaces to prove the safety +/// of hoisting invariant operations out of loop-like operations. +/// It may be run on both HLFIR and FIR representations. 
+//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Analysis/AliasAnalysis.h" +#include "flang/Optimizer/Dialect/FIROperationMoveOpInterface.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "flang/Optimizer/Dialect/FortranVariableInterface.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/DebugLog.h" + +namespace fir { +#define GEN_PASS_DEF_LOOPINVARIANTCODEMOTION +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +#define DEBUG_TYPE "flang-licm" + +// Temporary engineering option for triaging LICM. +static llvm::cl::opt<bool> disableFlangLICM( +    "disable-flang-licm", llvm::cl::init(false), llvm::cl::Hidden, +    llvm::cl::desc("Disable Flang's loop invariant code motion")); + +namespace { + +using namespace mlir; + +/// The pass tries to hoist loop invariant operations with only +/// MemoryEffects::Read effects (MemoryEffects::Write support +/// may be added later). +/// The safety of hoisting is proven by: +/// * Proving that the loop runs at least one iteration. +/// * Proving that it is always safe to load from this location +/// (see isSafeToHoistLoad() comments below). +struct LoopInvariantCodeMotion +    : fir::impl::LoopInvariantCodeMotionBase<LoopInvariantCodeMotion> { +  void runOnOperation() override; +}; + +} // namespace + +/// 'location' is a memory reference used by a memory access. +/// The type of 'location' defines the data type of the access +/// (e.g. it is considered to be invalid to access 'i64' +/// data using '!fir.ref<i32>'). +/// For the given location, this function returns true iff +/// the Fortran object being accessed is a scalar that +/// may not be OPTIONAL. 
+/// +/// Note that the '!fir.ref<!fir.box<>>' accesses are considered +/// to be scalar, even if the underlying data is an array. +/// +/// Note that an access of '!fir.ref<scalar>' may access +/// an array object. For example: +/// real :: x(:) +/// do i=... +/// = x(10) +/// 'x(10)' accesses array 'x', and it may be unsafe to hoist +/// it without proving that '10' is a valid index for the array. +/// The fact that 'x' is not OPTIONAL does not allow hoisting +/// on its own. +static bool isNonOptionalScalar(Value location) { + while (true) { + LDBG() << "Checking location:\n" << location; + Type dataType = fir::unwrapRefType(location.getType()); + if (!isa<fir::BaseBoxType>(location.getType()) && + (!dataType || + (!isa<fir::BaseBoxType>(dataType) && !fir::isa_trivial(dataType) && + !fir::isa_derived(dataType)))) { + LDBG() << "Failure: data access is not scalar"; + return false; + } + Operation *defOp = location.getDefiningOp(); + if (!defOp) { + // If this is a function argument + auto blockArg = cast<BlockArgument>(location); + Block *block = blockArg.getOwner(); + if (block && block->isEntryBlock()) + if (auto funcOp = + dyn_cast_if_present<FunctionOpInterface>(block->getParentOp())) + if (!funcOp.getArgAttrOfType<UnitAttr>(blockArg.getArgNumber(), + fir::getOptionalAttrName())) { + LDBG() << "Success: is non optional scalar dummy"; + return true; + } + + LDBG() << "Failure: no defining operation"; + return false; + } + + // Scalars "defined" by fir.alloca and fir.address_of + // are present. + if (isa<fir::AllocaOp, fir::AddrOfOp>(defOp)) { + LDBG() << "Success: is non optional scalar"; + return true; + } + + if (auto varIface = dyn_cast<fir::FortranVariableOpInterface>(defOp)) { + if (varIface.isOptional()) { + // The variable is optional, so do not look further. + // Note that it is possible to deduce that the optional + // is actually present, but we are not doing it now. 
+ LDBG() << "Failure: is optional"; + return false; + } + + // In case of MLIR inlining and ASSOCIATE an [hl]fir.declare + // may declare a scalar variable that is actually a "view" + // of an array element. Originally, such [hl]fir.declare + // would be located inside the loop preventing the hoisting. + // But if we decide to hoist such [hl]fir.declare in future, + // we cannot rely on their attributes/types. + // Use reliable checks based on the variable storage. + + // If the variable has storage specifier (e.g. it is a member + // of COMMON, etc.), we can rely that the storage is present, + // and we can also rely on its FortranVariableOpInterface + // definition type (which is a scalar due to previous checks). + if (auto storageIface = + dyn_cast<fir::FortranVariableStorageOpInterface>(defOp)) + if (Value storage = storageIface.getStorage()) { + LDBG() << "Success: is scalar with existing storage"; + return true; + } + + // TODO: we can probably use FIR AliasAnalysis' getSource() + // method to identify the storage in more cases. + Value memref = llvm::TypeSwitch<Operation *, Value>(defOp) + .Case<fir::DeclareOp, hlfir::DeclareOp>( + [](auto op) { return op.getMemref(); }) + .Default([](auto) { return nullptr; }); + + if (memref) + return isNonOptionalScalar(memref); + + LDBG() << "Failure: cannot reason about variable storage"; + return false; + } + if (auto viewIface = dyn_cast<fir::FortranObjectViewOpInterface>(defOp)) { + location = viewIface.getViewSource(cast<OpResult>(location)); + } else { + LDBG() << "Failure: unknown operation:\n" << *defOp; + return false; + } + } +} + +/// Returns true iff it is safe to hoist the given load-like operation 'op', +/// which access given memory 'locations', out of the operation 'loopLike'. +/// The current safety conditions are: +/// * The loop runs at least one iteration, OR +/// * all the accessed locations are inside scalar non-OPTIONAL +/// Fortran objects (Fortran descriptors are considered to be scalars). 
+static bool isSafeToHoistLoad(Operation *op, ArrayRef<Value> locations, + LoopLikeOpInterface loopLike, + AliasAnalysis &aliasAnalysis) { + for (Value location : locations) + if (aliasAnalysis.getModRef(loopLike.getOperation(), location) + .isModAndRef()) { + LDBG() << "Failure: reads location:\n" + << location << "\nwhich is modified inside the loop"; + return false; + } + + // Check that it is safe to read from all the locations before the loop. + std::optional<llvm::APInt> tripCount = loopLike.getStaticTripCount(); + if (tripCount && !tripCount->isZero()) { + // Loop executes at least one iteration, so it is safe to hoist. + LDBG() << "Success: loop has non-zero iterations"; + return true; + } + + // Check whether the access must always be valid. + return llvm::all_of( + locations, [&](Value location) { return isNonOptionalScalar(location); }); + // TODO: consider hoisting under condition of the loop's trip count + // being non-zero. +} + +/// Returns true iff the given 'op' is a load-like operation, +/// and it can be hoisted out of 'loopLike' operation. 
+static bool canHoistLoad(Operation *op, LoopLikeOpInterface loopLike, + AliasAnalysis &aliasAnalysis) { + LDBG() << "Checking operation:\n" << *op; + if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) { + SmallVector<MemoryEffects::EffectInstance> effects; + effectInterface.getEffects(effects); + if (effects.empty()) { + LDBG() << "Failure: not a load"; + return false; + } + llvm::SetVector<Value> locations; + for (const MemoryEffects::EffectInstance &effect : effects) { + Value location = effect.getValue(); + if (!isa<MemoryEffects::Read>(effect.getEffect())) { + LDBG() << "Failure: has unsupported effects"; + return false; + } else if (!location) { + LDBG() << "Failure: reads from unknown location"; + return false; + } + locations.insert(location); + } + return isSafeToHoistLoad(op, locations.getArrayRef(), loopLike, + aliasAnalysis); + } + LDBG() << "Failure: has unknown effects"; + return false; +} + +void LoopInvariantCodeMotion::runOnOperation() { + if (disableFlangLICM) { + LDBG() << "Skipping [HL]FIR LoopInvariantCodeMotion()"; + return; + } + + LDBG() << "Enter [HL]FIR LoopInvariantCodeMotion()"; + + auto &aliasAnalysis = getAnalysis<AliasAnalysis>(); + aliasAnalysis.addAnalysisImplementation(fir::AliasAnalysis{}); + + std::function<bool(Operation *, LoopLikeOpInterface loopLike)> + shouldMoveOutOfLoop = [&](Operation *op, LoopLikeOpInterface loopLike) { + if (isPure(op)) { + LDBG() << "Pure operation: " << *op; + return true; + } + + // Handle RecursivelySpeculatable operations that have + // RecursiveMemoryEffects by checking if all their + // nested operations can be hoisted. 
+ auto iface = dyn_cast<ConditionallySpeculatable>(op); + if (iface && iface.getSpeculatability() == + Speculation::RecursivelySpeculatable) { + if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) { + LDBG() << "Checking recursive operation:\n" << *op; + llvm::SmallVector<Operation *> nestedOps; + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (Operation &nestedOp : block) + nestedOps.push_back(&nestedOp); + + bool result = llvm::all_of(nestedOps, [&](Operation *nestedOp) { + return shouldMoveOutOfLoop(nestedOp, loopLike); + }); + LDBG() << "Recursive operation can" << (result ? "" : "not") + << " be hoisted"; + + // If nested operations cannot be hoisted, there is nothing + // else to check. Also if the operation itself does not have + // any memory effects, we can return the result now. + // Otherwise, we have to check the operation itself below. + if (!result || !isa<MemoryEffectOpInterface>(op)) + return result; + } + } + return canHoistLoad(op, loopLike, aliasAnalysis); + }; + + getOperation()->walk([&](LoopLikeOpInterface loopLike) { + if (!fir::canMoveOutOf(loopLike, nullptr)) { + LDBG() << "Cannot hoist anything out of loop operation: "; + LDBG_OS([&](llvm::raw_ostream &os) { + loopLike->print(os, OpPrintingFlags().skipRegions()); + }); + return; + } + // We always hoist operations to the parent operation of the loopLike. + // Check that the parent operation allows the hoisting, e.g. + // omp::LoopWrapperInterface operations assume tight nesting + // of the inner maybe loop-like operations, so hoisting + // to such a parent would be invalid. We rely on + // fir::canMoveFromDescendant() to identify whether the hoisting + // is allowed. 
+ Operation *parentOp = loopLike->getParentOp(); + if (!parentOp) { + LDBG() << "Skipping top-level loop-like operation?"; + return; + } else if (!fir::canMoveFromDescendant(parentOp, loopLike, nullptr)) { + LDBG() << "Cannot hoist anything into operation: "; + LDBG_OS([&](llvm::raw_ostream &os) { + parentOp->print(os, OpPrintingFlags().skipRegions()); + }); + return; + } + moveLoopInvariantCode( + loopLike.getLoopRegions(), + /*isDefinedOutsideRegion=*/ + [&](Value value, Region *) { + return loopLike.isDefinedOutsideOfLoop(value); + }, + /*shouldMoveOutOfRegion=*/ + [&](Operation *op, Region *) { + if (!fir::canMoveOutOf(loopLike, op)) { + LDBG() << "Cannot hoist " << *op << " out of the loop"; + return false; + } + if (!fir::canMoveFromDescendant(parentOp, loopLike, op)) { + LDBG() << "Cannot hoist " << *op << " into the parent of the loop"; + return false; + } + return shouldMoveOutOfLoop(op, loopLike); + }, + /*moveOutOfRegion=*/ + [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); + }); + + LDBG() << "Exit [HL]FIR LoopInvariantCodeMotion()"; +} diff --git a/flang/lib/Optimizer/Transforms/MIFOpConversion.cpp b/flang/lib/Optimizer/Transforms/MIFOpConversion.cpp index 206cb9b..fed941c0 100644 --- a/flang/lib/Optimizer/Transforms/MIFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/MIFOpConversion.cpp @@ -16,6 +16,7 @@ #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/Support/DataLayout.h" #include "flang/Optimizer/Support/InternalNames.h" +#include "flang/Runtime/stop.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -67,6 +68,118 @@ genErrmsgPRIF(fir::FirOpBuilder &builder, mlir::Location loc, return {errMsg, errMsgAlloc}; } +static mlir::Value genStatPRIF(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value stat) { + if (!stat) + return fir::AbsentOp::create(builder, loc, getPRIFStatType(builder)); + return stat; +} + +static 
fir::CallOp genPRIFStopErrorStop(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Value stopCode, + bool isError = false) { + mlir::Type stopCharTy = fir::BoxCharType::get(builder.getContext(), 1); + mlir::Type i1Ty = builder.getI1Type(); + mlir::Type i32Ty = builder.getI32Type(); + + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ + {builder.getRefType(i1Ty), builder.getRefType(i32Ty), stopCharTy}, + /*results*/ {}); + mlir::func::FuncOp funcOp = + isError + ? builder.createFunction(loc, getPRIFProcName("error_stop"), ftype) + : builder.createFunction(loc, getPRIFProcName("stop"), ftype); + + // QUIET is managed in flang-rt, so its value is set to TRUE here. + mlir::Value q = builder.createBool(loc, true); + mlir::Value quiet = builder.createTemporary(loc, i1Ty); + fir::StoreOp::create(builder, loc, q, quiet); + + mlir::Value stopCodeInt, stopCodeChar; + if (!stopCode) { + stopCodeChar = fir::AbsentOp::create(builder, loc, stopCharTy); + stopCodeInt = + fir::AbsentOp::create(builder, loc, builder.getRefType(i32Ty)); + } else if (fir::isa_integer(stopCode.getType())) { + stopCodeChar = fir::AbsentOp::create(builder, loc, stopCharTy); + stopCodeInt = builder.createTemporary(loc, i32Ty); + if (stopCode.getType() != i32Ty) + stopCode = fir::ConvertOp::create(builder, loc, i32Ty, stopCode); + fir::StoreOp::create(builder, loc, stopCode, stopCodeInt); + } else { + stopCodeChar = stopCode; + if (!mlir::isa<fir::BoxCharType>(stopCodeChar.getType())) { + auto len = + fir::UndefOp::create(builder, loc, builder.getCharacterLengthType()); + stopCodeChar = + fir::EmboxCharOp::create(builder, loc, stopCharTy, stopCodeChar, len); + } + stopCodeInt = + fir::AbsentOp::create(builder, loc, builder.getRefType(i32Ty)); + } + + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, loc, ftype, quiet, stopCodeInt, stopCodeChar); + return fir::CallOp::create(builder, loc, funcOp, args); +} + +enum class 
TerminationKind { Normal = 0, Error = 1, FailImage = 2 }; +// Generates a wrapper function for the different kind of termination in PRIF. +// This function will be used to register wrappers on PRIF runtime termination +// functions into the Fortran runtime. +mlir::Value genTerminationOperationWrapper(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::ModuleOp module, + TerminationKind termKind) { + std::string funcName; + mlir::FunctionType funcType = + mlir::FunctionType::get(builder.getContext(), {}, {}); + mlir::Type i32Ty = builder.getI32Type(); + if (termKind == TerminationKind::Normal) { + funcName = getPRIFProcName("stop"); + funcType = mlir::FunctionType::get(builder.getContext(), {i32Ty}, {}); + } else if (termKind == TerminationKind::Error) { + funcName = getPRIFProcName("error_stop"); + funcType = mlir::FunctionType::get(builder.getContext(), {i32Ty}, {}); + } else { + funcName = getPRIFProcName("fail_image"); + } + funcName += "_termination_wrapper"; + mlir::func::FuncOp funcWrapperOp = + module.lookupSymbol<mlir::func::FuncOp>(funcName); + + if (!funcWrapperOp) { + funcWrapperOp = builder.createFunction(loc, funcName, funcType); + + // generating the body of the function. 
+ mlir::OpBuilder::InsertPoint saveInsertPoint = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(funcWrapperOp.addEntryBlock()); + + if (termKind == TerminationKind::Normal) { + genPRIFStopErrorStop(builder, loc, funcWrapperOp.getArgument(0), + /*isError*/ false); + } else if (termKind == TerminationKind::Error) { + genPRIFStopErrorStop(builder, loc, funcWrapperOp.getArgument(0), + /*isError*/ true); + } else { + mlir::func::FuncOp fOp = builder.createFunction( + loc, getPRIFProcName("fail_image"), + mlir::FunctionType::get(builder.getContext(), {}, {})); + fir::CallOp::create(builder, loc, fOp); + } + + mlir::func::ReturnOp::create(builder, loc); + builder.restoreInsertionPoint(saveInsertPoint); + } + + mlir::SymbolRefAttr symbolRef = mlir::SymbolRefAttr::get( + builder.getContext(), funcWrapperOp.getSymNameAttr()); + return fir::AddrOfOp::create(builder, loc, funcType, symbolRef); +} + /// Convert mif.init operation to runtime call of 'prif_init' struct MIFInitOpConversion : public mlir::OpRewritePattern<mif::InitOp> { using OpRewritePattern::OpRewritePattern; @@ -80,6 +193,39 @@ struct MIFInitOpConversion : public mlir::OpRewritePattern<mif::InitOp> { mlir::Type i32Ty = builder.getI32Type(); mlir::Value result = builder.createTemporary(loc, i32Ty); + + // Registering PRIF runtime termination to the Fortran runtime + // STOP + mlir::Value funcStopOp = genTerminationOperationWrapper( + builder, loc, mod, TerminationKind::Normal); + mlir::func::FuncOp normalEndFunc = + fir::runtime::getRuntimeFunc<mkRTKey(RegisterImagesNormalEndCallback)>( + loc, builder); + llvm::SmallVector<mlir::Value> args1 = fir::runtime::createArguments( + builder, loc, normalEndFunc.getFunctionType(), funcStopOp); + fir::CallOp::create(builder, loc, normalEndFunc, args1); + + // ERROR STOP + mlir::Value funcErrorStopOp = genTerminationOperationWrapper( + builder, loc, mod, TerminationKind::Error); + mlir::func::FuncOp errorFunc = + 
fir::runtime::getRuntimeFunc<mkRTKey(RegisterImagesErrorCallback)>( + loc, builder); + llvm::SmallVector<mlir::Value> args2 = fir::runtime::createArguments( + builder, loc, errorFunc.getFunctionType(), funcErrorStopOp); + fir::CallOp::create(builder, loc, errorFunc, args2); + + // FAIL IMAGE + mlir::Value failImageOp = genTerminationOperationWrapper( + builder, loc, mod, TerminationKind::FailImage); + mlir::func::FuncOp failImageFunc = + fir::runtime::getRuntimeFunc<mkRTKey(RegisterFailImageCallback)>( + loc, builder); + llvm::SmallVector<mlir::Value> args3 = fir::runtime::createArguments( + builder, loc, errorFunc.getFunctionType(), failImageOp); + fir::CallOp::create(builder, loc, failImageFunc, args3); + + // Intialize the multi-image parallel environment mlir::FunctionType ftype = mlir::FunctionType::get( builder.getContext(), /*inputs*/ {builder.getRefType(i32Ty)}, /*results*/ {}); @@ -210,9 +356,7 @@ struct MIFSyncAllOpConversion : public mlir::OpRewritePattern<mif::SyncAllOp> { auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, op.getErrmsg()); - mlir::Value stat = op.getStat(); - if (!stat) - stat = fir::AbsentOp::create(builder, loc, getPRIFStatType(builder)); + mlir::Value stat = genStatPRIF(builder, loc, op.getStat()); llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( builder, loc, ftype, stat, errmsgArg, errmsgAllocArg); rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args); @@ -261,9 +405,7 @@ struct MIFSyncImagesOpConversion } auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, op.getErrmsg()); - mlir::Value stat = op.getStat(); - if (!stat) - stat = fir::AbsentOp::create(builder, loc, getPRIFStatType(builder)); + mlir::Value stat = genStatPRIF(builder, loc, op.getStat()); llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( builder, loc, ftype, imageSet, stat, errmsgArg, errmsgAllocArg); rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args); @@ -293,9 +435,7 @@ struct 
MIFSyncMemoryOpConversion auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, op.getErrmsg()); - mlir::Value stat = op.getStat(); - if (!stat) - stat = fir::AbsentOp::create(builder, loc, getPRIFStatType(builder)); + mlir::Value stat = genStatPRIF(builder, loc, op.getStat()); llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( builder, loc, ftype, stat, errmsgArg, errmsgAllocArg); rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args); @@ -303,6 +443,37 @@ struct MIFSyncMemoryOpConversion } }; +/// Convert mif.sync_team operation to runtime call of 'prif_sync_team' +struct MIFSyncTeamOpConversion + : public mlir::OpRewritePattern<mif::SyncTeamOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::SyncTeamOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + mlir::Type boxTy = fir::BoxType::get(builder.getNoneType()); + mlir::Type errmsgTy = getPRIFErrmsgType(builder); + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ {boxTy, getPRIFStatType(builder), errmsgTy, errmsgTy}, + /*results*/ {}); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, getPRIFProcName("sync_team"), ftype); + + auto [errmsgArg, errmsgAllocArg] = + genErrmsgPRIF(builder, loc, op.getErrmsg()); + mlir::Value stat = genStatPRIF(builder, loc, op.getStat()); + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, loc, ftype, op.getTeam(), stat, errmsgArg, errmsgAllocArg); + rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args); + return mlir::success(); + } +}; + /// Generate call to collective subroutines except co_reduce /// A must be lowered as a box static fir::CallOp genCollectiveSubroutine(fir::FirOpBuilder &builder, @@ -432,6 +603,208 @@ struct MIFCoSumOpConversion : public 
mlir::OpRewritePattern<mif::CoSumOp> { } }; +/// Convert mif.form_team operation to runtime call of 'prif_form_team' +struct MIFFormTeamOpConversion + : public mlir::OpRewritePattern<mif::FormTeamOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::FormTeamOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + mlir::Type errmsgTy = getPRIFErrmsgType(builder); + mlir::Type boxTy = fir::BoxType::get(builder.getNoneType()); + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ + {builder.getRefType(builder.getI64Type()), boxTy, + builder.getRefType(builder.getI32Type()), getPRIFStatType(builder), + errmsgTy, errmsgTy}, + /*results*/ {}); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, getPRIFProcName("form_team"), ftype); + + mlir::Type i64Ty = builder.getI64Type(); + mlir::Value teamNumber = builder.createTemporary(loc, i64Ty); + mlir::Value t = + (op.getTeamNumber().getType() == i64Ty) + ? op.getTeamNumber() + : fir::ConvertOp::create(builder, loc, i64Ty, op.getTeamNumber()); + fir::StoreOp::create(builder, loc, t, teamNumber); + + mlir::Type i32Ty = builder.getI32Type(); + mlir::Value newIndex; + if (op.getNewIndex()) { + newIndex = builder.createTemporary(loc, i32Ty); + mlir::Value ni = + (op.getNewIndex().getType() == i32Ty) + ? 
op.getNewIndex() + : fir::ConvertOp::create(builder, loc, i32Ty, op.getNewIndex()); + fir::StoreOp::create(builder, loc, ni, newIndex); + } else + newIndex = fir::AbsentOp::create(builder, loc, builder.getRefType(i32Ty)); + + mlir::Value stat = genStatPRIF(builder, loc, op.getStat()); + auto [errmsgArg, errmsgAllocArg] = + genErrmsgPRIF(builder, loc, op.getErrmsg()); + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, loc, ftype, teamNumber, op.getTeamVar(), newIndex, stat, + errmsgArg, errmsgAllocArg); + fir::CallOp callOp = fir::CallOp::create(builder, loc, funcOp, args); + rewriter.replaceOp(op, callOp); + return mlir::success(); + } +}; + +/// Convert mif.change_team operation to runtime call of 'prif_change_team' +struct MIFChangeTeamOpConversion + : public mlir::OpRewritePattern<mif::ChangeTeamOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::ChangeTeamOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + builder.setInsertionPoint(op); + + mlir::Location loc = op.getLoc(); + mlir::Type errmsgTy = getPRIFErrmsgType(builder); + mlir::Type boxTy = fir::BoxType::get(builder.getNoneType()); + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ {boxTy, getPRIFStatType(builder), errmsgTy, errmsgTy}, + /*results*/ {}); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, getPRIFProcName("change_team"), ftype); + + mlir::Value stat = genStatPRIF(builder, loc, op.getStat()); + auto [errmsgArg, errmsgAllocArg] = + genErrmsgPRIF(builder, loc, op.getErrmsg()); + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, loc, ftype, op.getTeam(), stat, errmsgArg, errmsgAllocArg); + fir::CallOp::create(builder, loc, funcOp, args); + + mlir::Operation *changeOp = op.getOperation(); + auto &bodyRegion = op.getRegion(); + 
mlir::Block &bodyBlock = bodyRegion.front(); + + rewriter.inlineBlockBefore(&bodyBlock, changeOp); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +/// Convert mif.end_team operation to runtime call of 'prif_end_team' +struct MIFEndTeamOpConversion : public mlir::OpRewritePattern<mif::EndTeamOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::EndTeamOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + mlir::Type errmsgTy = getPRIFErrmsgType(builder); + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ {getPRIFStatType(builder), errmsgTy, errmsgTy}, + /*results*/ {}); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, getPRIFProcName("end_team"), ftype); + + mlir::Value stat = genStatPRIF(builder, loc, op.getStat()); + auto [errmsgArg, errmsgAllocArg] = + genErrmsgPRIF(builder, loc, op.getErrmsg()); + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, loc, ftype, stat, errmsgArg, errmsgAllocArg); + fir::CallOp callOp = fir::CallOp::create(builder, loc, funcOp, args); + rewriter.replaceOp(op, callOp); + return mlir::success(); + } +}; + +/// Convert mif.get_team operation to runtime call of 'prif_get_team' +struct MIFGetTeamOpConversion : public mlir::OpRewritePattern<mif::GetTeamOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::GetTeamOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + mlir::Type boxTy = fir::BoxType::get(builder.getNoneType()); + mlir::Type lvlTy = builder.getRefType(builder.getI32Type()); + mlir::FunctionType ftype = + mlir::FunctionType::get(builder.getContext(), + 
/*inputs*/ {lvlTy, boxTy}, + /*results*/ {}); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, getPRIFProcName("get_team"), ftype); + + mlir::Value level = op.getLevel(); + if (!level) + level = fir::AbsentOp::create(builder, loc, lvlTy); + else { + mlir::Value cst = op.getLevel(); + mlir::Type i32Ty = builder.getI32Type(); + level = builder.createTemporary(loc, i32Ty); + if (cst.getType() != i32Ty) + cst = builder.createConvert(loc, i32Ty, cst); + fir::StoreOp::create(builder, loc, cst, level); + } + mlir::Type resultType = op.getResult().getType(); + mlir::Type baseTy = fir::unwrapRefType(resultType); + mlir::Value team = builder.createTemporary(loc, baseTy); + fir::EmboxOp box = fir::EmboxOp::create(builder, loc, resultType, team); + + llvm::SmallVector<mlir::Value> args = + fir::runtime::createArguments(builder, loc, ftype, level, box); + fir::CallOp::create(builder, loc, funcOp, args); + + rewriter.replaceOp(op, box); + return mlir::success(); + } +}; + +/// Convert mif.team_number operation to runtime call of 'prif_team_number' +struct MIFTeamNumberOpConversion + : public mlir::OpRewritePattern<mif::TeamNumberOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::TeamNumberOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + mlir::Type i64Ty = builder.getI64Type(); + mlir::Type boxTy = fir::BoxType::get(builder.getNoneType()); + mlir::FunctionType ftype = + mlir::FunctionType::get(builder.getContext(), + /*inputs*/ {boxTy, builder.getRefType(i64Ty)}, + /*results*/ {}); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, getPRIFProcName("team_number"), ftype); + + mlir::Value team = op.getTeam(); + if (!team) + team = fir::AbsentOp::create(builder, loc, boxTy); + + mlir::Value result = builder.createTemporary(loc, i64Ty); + llvm::SmallVector<mlir::Value> 
args = + fir::runtime::createArguments(builder, loc, ftype, team, result); + fir::CallOp::create(builder, loc, funcOp, args); + fir::LoadOp load = fir::LoadOp::create(builder, loc, result); + rewriter.replaceOp(op, load); + return mlir::success(); + } +}; + class MIFOpConversion : public fir::impl::MIFOpConversionBase<MIFOpConversion> { public: void runOnOperation() override { @@ -458,7 +831,10 @@ void mif::populateMIFOpConversionPatterns(mlir::RewritePatternSet &patterns) { patterns.insert<MIFInitOpConversion, MIFThisImageOpConversion, MIFNumImagesOpConversion, MIFSyncAllOpConversion, MIFSyncImagesOpConversion, MIFSyncMemoryOpConversion, - MIFCoBroadcastOpConversion, MIFCoMaxOpConversion, - MIFCoMinOpConversion, MIFCoSumOpConversion>( + MIFSyncTeamOpConversion, MIFCoBroadcastOpConversion, + MIFCoMaxOpConversion, MIFCoMinOpConversion, + MIFCoSumOpConversion, MIFFormTeamOpConversion, + MIFChangeTeamOpConversion, MIFEndTeamOpConversion, + MIFGetTeamOpConversion, MIFTeamNumberOpConversion>( patterns.getContext()); } diff --git a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp index 25a8f7a..c9d52c4 100644 --- a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp @@ -246,7 +246,9 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> { args.append(dispatch.getArgs().begin(), dispatch.getArgs().end()); rewriter.replaceOpWithNewOp<fir::CallOp>( dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(), - dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr()); + dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr(), + /*inline_attr*/ fir::FortranInlineEnumAttr{}, + /*accessGroups*/ mlir::ArrayAttr{}); return mlir::success(); } diff --git a/flang/lib/Optimizer/Transforms/SetRuntimeCallAttributes.cpp b/flang/lib/Optimizer/Transforms/SetRuntimeCallAttributes.cpp index 378037e..4ba2ea5 100644 --- 
a/flang/lib/Optimizer/Transforms/SetRuntimeCallAttributes.cpp +++ b/flang/lib/Optimizer/Transforms/SetRuntimeCallAttributes.cpp @@ -85,7 +85,10 @@ static mlir::LLVM::MemoryEffectsAttr getGenericMemoryAttr(fir::CallOp callOp) { callOp->getContext(), {/*other=*/mlir::LLVM::ModRefInfo::NoModRef, /*argMem=*/mlir::LLVM::ModRefInfo::ModRef, - /*inaccessibleMem=*/mlir::LLVM::ModRefInfo::ModRef}); + /*inaccessibleMem=*/mlir::LLVM::ModRefInfo::ModRef, + /*errnoMem=*/mlir::LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/mlir::LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/mlir::LLVM::ModRefInfo::NoModRef}); } return {}; diff --git a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp index 03f97eb..3c4da62 100644 --- a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp @@ -254,6 +254,10 @@ public: // Collect iteration variable(s) allocations so that we can move them // outside the `fir.do_concurrent` wrapper. + // There actually may be more operations that just allocations + // at the beginning of the wrapper block, e.g. LICM may move + // some operations from the inner fir.do_concurrent.loop into + // this block. llvm::SmallVector<mlir::Operation *> opsToMove; for (mlir::Operation &op : llvm::drop_end(wrapperBlock)) opsToMove.push_back(&op); @@ -262,8 +266,13 @@ public: rewriter, doConcurentOp->getParentOfType<mlir::ModuleOp>()); auto *allocIt = firBuilder.getAllocaBlock(); - for (mlir::Operation *op : llvm::reverse(opsToMove)) - rewriter.moveOpBefore(op, allocIt, allocIt->begin()); + // Move alloca operations into the alloca-block, and all other + // operations - right before fir.do_concurrent. 
+ for (mlir::Operation *op : opsToMove) + if (mlir::isa<fir::AllocaOp>(op)) + rewriter.moveOpBefore(op, allocIt, allocIt->begin()); + else + rewriter.moveOpBefore(op, doConcurentOp); rewriter.setInsertionPointAfter(doConcurentOp); fir::DoLoopOp innermostUnorderdLoop; diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp index 49a085e..49ae189 100644 --- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp @@ -730,7 +730,6 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder, mlir::Value ifCompatElem = fir::ConvertOp::create(builder, loc, ifCompatType, maskElem); - llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType}; fir::IfOp ifOp = fir::IfOp::create(builder, loc, elementType, ifCompatElem, /*withElseRegion=*/true); diff --git a/flang/lib/Optimizer/Transforms/VScaleAttr.cpp b/flang/lib/Optimizer/Transforms/VScaleAttr.cpp index 54a2456..d0e83ef 100644 --- a/flang/lib/Optimizer/Transforms/VScaleAttr.cpp +++ b/flang/lib/Optimizer/Transforms/VScaleAttr.cpp @@ -33,9 +33,11 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include <algorithm> +#include <mlir/IR/Diagnostics.h> namespace fir { #define GEN_PASS_DEF_VSCALEATTR @@ -49,7 +51,8 @@ namespace { class VScaleAttrPass : public fir::impl::VScaleAttrBase<VScaleAttrPass> { public: VScaleAttrPass(const fir::VScaleAttrOptions &options) { - vscaleRange = options.vscaleRange; + vscaleMin = options.vscaleMin; + vscaleMax = options.vscaleMax; } VScaleAttrPass() {} void runOnOperation() override; @@ -63,16 +66,28 @@ void VScaleAttrPass::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "Func-name:" << func.getSymName() << "\n"); + if (!llvm::isPowerOf2_32(vscaleMin)) { + func->emitError( + "VScaleAttr: vscaleMin 
has to be a power-of-two greater than 0\n"); + return signalPassFailure(); + } + + if (vscaleMax != 0 && + (!llvm::isPowerOf2_32(vscaleMax) || (vscaleMin > vscaleMax))) { + func->emitError("VScaleAttr: vscaleMax has to be a power-of-two " + "greater-than-or-equal to vscaleMin or 0 to signify " + "an unbounded maximum\n"); + return signalPassFailure(); + } + auto context = &getContext(); auto intTy = mlir::IntegerType::get(context, 32); - assert(vscaleRange.first && "VScaleRange minimum should be non-zero"); - func->setAttr("vscale_range", mlir::LLVM::VScaleRangeAttr::get( - context, mlir::IntegerAttr::get(intTy, vscaleRange.first), - mlir::IntegerAttr::get(intTy, vscaleRange.second))); + context, mlir::IntegerAttr::get(intTy, vscaleMin), + mlir::IntegerAttr::get(intTy, vscaleMax))); LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n"); } |
