//===-- Lower/Support/Utils.cpp -- utilities --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "flang/Lower/Support/Utils.h"

#include "flang/Common/indirection.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/IterationSpace.h"
#include "flang/Lower/Support/PrivateReductionUtils.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include <cstdint>
#include <optional>
#include <type_traits>

namespace Fortran::lower {
// Fortran::evaluate::Expr are functional values organized like an AST. A
// Fortran::evaluate::Expr is meant to be moved and cloned. Using the front end
// tools can often cause copies and extra wrapper classes to be added to any
// Fortran::evaluate::Expr. These values should not be assumed or relied upon to
// have an *object* identity. They are deeply recursive, irregular structures
// built from a large number of classes which do not use inheritance and
// necessitate a large volume of boilerplate code as a result.
//
// Contrastingly, LLVM data structures make ubiquitous assumptions about an
// object's identity via pointers to the object. An object's location in memory
// is thus very often an identifying relation.

// This class defines a hash computation of a Fortran::evaluate::Expr tree value
// so it can be used with llvm::DenseMap. The Fortran::evaluate::Expr need not
// have the same address.
class HashEvaluateExpr { public: // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an // identity property. static unsigned getHashValue(const Fortran::semantics::Symbol &x) { return static_cast(reinterpret_cast(&x)); } template static unsigned getHashValue(const Fortran::common::Indirection &x) { return getHashValue(x.value()); } template static unsigned getHashValue(const std::optional &x) { if (x.has_value()) return getHashValue(x.value()); return 0u; } static unsigned getHashValue(const Fortran::evaluate::Subscript &x) { return Fortran::common::visit( [&](const auto &v) { return getHashValue(v); }, x.u); } static unsigned getHashValue(const Fortran::evaluate::Triplet &x) { return getHashValue(x.lower()) - getHashValue(x.upper()) * 5u - getHashValue(x.stride()) * 11u; } static unsigned getHashValue(const Fortran::evaluate::Component &x) { return getHashValue(x.base()) * 83u - getHashValue(x.GetLastSymbol()); } static unsigned getHashValue(const Fortran::evaluate::ArrayRef &x) { unsigned subs = 1u; for (const Fortran::evaluate::Subscript &v : x.subscript()) subs -= getHashValue(v); return getHashValue(x.base()) * 89u - subs; } static unsigned getHashValue(const Fortran::evaluate::CoarrayRef &x) { unsigned cosubs = 3u; for (const Fortran::evaluate::Expr &v : x.cosubscript()) cosubs -= getHashValue(v); return getHashValue(x.base()) * 97u - cosubs + getHashValue(x.stat()) + 257u + getHashValue(x.team()); } static unsigned getHashValue(const Fortran::evaluate::NamedEntity &x) { if (x.IsSymbol()) return getHashValue(x.GetFirstSymbol()) * 11u; return getHashValue(x.GetComponent()) * 13u; } static unsigned getHashValue(const Fortran::evaluate::DataRef &x) { return Fortran::common::visit( [&](const auto &v) { return getHashValue(v); }, x.u); } static unsigned getHashValue(const Fortran::evaluate::ComplexPart &x) { return getHashValue(x.complex()) - static_cast(x.part()); } template static unsigned getHashValue( const Fortran::evaluate::Convert, TC2> &x) 
{ return getHashValue(x.left()) - (static_cast(TC1) + 2u) - (static_cast(KIND) + 5u); } template static unsigned getHashValue(const Fortran::evaluate::ComplexComponent &x) { return getHashValue(x.left()) - (static_cast(x.isImaginaryPart) + 1u) * 3u; } template static unsigned getHashValue(const Fortran::evaluate::Parentheses &x) { return getHashValue(x.left()) * 17u; } template static unsigned getHashValue( const Fortran::evaluate::Negate> &x) { return getHashValue(x.left()) - (static_cast(TC) + 5u) - (static_cast(KIND) + 7u); } template static unsigned getHashValue( const Fortran::evaluate::Add> &x) { return (getHashValue(x.left()) + getHashValue(x.right())) * 23u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue( const Fortran::evaluate::Subtract> &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 19u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue( const Fortran::evaluate::Multiply> &x) { return (getHashValue(x.left()) + getHashValue(x.right())) * 29u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue( const Fortran::evaluate::Divide> &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 31u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue( const Fortran::evaluate::Power> &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 37u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue( const Fortran::evaluate::Extremum> &x) { return (getHashValue(x.left()) + getHashValue(x.right())) * 41u + static_cast(TC) + static_cast(KIND) + static_cast(x.ordering) * 7u; } template static unsigned getHashValue( const Fortran::evaluate::RealToIntPower> &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 43u + static_cast(TC) + static_cast(KIND); } template static unsigned getHashValue(const Fortran::evaluate::ComplexConstructor &x) { return (getHashValue(x.left()) - 
getHashValue(x.right())) * 47u + static_cast(KIND); } template static unsigned getHashValue(const Fortran::evaluate::Concat &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 53u + static_cast(KIND); } template static unsigned getHashValue(const Fortran::evaluate::SetLength &x) { return (getHashValue(x.left()) - getHashValue(x.right())) * 59u + static_cast(KIND); } static unsigned getHashValue(const Fortran::semantics::SymbolRef &sym) { return getHashValue(sym.get()); } static unsigned getHashValue(const Fortran::evaluate::Substring &x) { return 61u * Fortran::common::visit( [&](const auto &p) { return getHashValue(p); }, x.parent()) - getHashValue(x.lower()) - (getHashValue(x.lower()) + 1u); } static unsigned getHashValue(const Fortran::evaluate::StaticDataObject::Pointer &x) { return llvm::hash_value(x->name()); } static unsigned getHashValue(const Fortran::evaluate::SpecificIntrinsic &x) { return llvm::hash_value(x.name); } template static unsigned getHashValue(const Fortran::evaluate::Constant &x) { // FIXME: Should hash the content. return 103u; } static unsigned getHashValue(const Fortran::evaluate::ActualArgument &x) { if (const Fortran::evaluate::Symbol *sym = x.GetAssumedTypeDummy()) return getHashValue(*sym); return getHashValue(*x.UnwrapExpr()); } static unsigned getHashValue(const Fortran::evaluate::ProcedureDesignator &x) { return Fortran::common::visit( [&](const auto &v) { return getHashValue(v); }, x.u); } static unsigned getHashValue(const Fortran::evaluate::ProcedureRef &x) { unsigned args = 13u; for (const std::optional &v : x.arguments()) args -= getHashValue(v); return getHashValue(x.proc()) * 101u - args; } template static unsigned getHashValue(const Fortran::evaluate::ArrayConstructor &x) { // FIXME: hash the contents. 
return 127u; } static unsigned getHashValue(const Fortran::evaluate::ImpliedDoIndex &x) { return llvm::hash_value(toStringRef(x.name).str()) * 131u; } static unsigned getHashValue(const Fortran::evaluate::TypeParamInquiry &x) { return getHashValue(x.base()) * 137u - getHashValue(x.parameter()) * 3u; } static unsigned getHashValue(const Fortran::evaluate::DescriptorInquiry &x) { return getHashValue(x.base()) * 139u - static_cast(x.field()) * 13u + static_cast(x.dimension()); } static unsigned getHashValue(const Fortran::evaluate::StructureConstructor &x) { // FIXME: hash the contents. return 149u; } template static unsigned getHashValue(const Fortran::evaluate::Not &x) { return getHashValue(x.left()) * 61u + static_cast(KIND); } template static unsigned getHashValue(const Fortran::evaluate::LogicalOperation &x) { unsigned result = getHashValue(x.left()) + getHashValue(x.right()); return result * 67u + static_cast(x.logicalOperator) * 5u; } template static unsigned getHashValue( const Fortran::evaluate::Relational> &x) { return (getHashValue(x.left()) + getHashValue(x.right())) * 71u + static_cast(TC) + static_cast(KIND) + static_cast(x.opr) * 11u; } template static unsigned getHashValue(const Fortran::evaluate::Expr &x) { return Fortran::common::visit( [&](const auto &v) { return getHashValue(v); }, x.u); } static unsigned getHashValue( const Fortran::evaluate::Relational &x) { return Fortran::common::visit( [&](const auto &v) { return getHashValue(v); }, x.u); } template static unsigned getHashValue(const Fortran::evaluate::Designator &x) { return Fortran::common::visit( [&](const auto &v) { return getHashValue(v); }, x.u); } template static unsigned getHashValue(const Fortran::evaluate::value::Integer &x) { return static_cast(x.ToSInt()); } static unsigned getHashValue(const Fortran::evaluate::NullPointer &x) { return ~179u; } }; // Define the is equals test for using Fortran::evaluate::Expr values with // llvm::DenseMap. 
class IsEqualEvaluateExpr { public: // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an // identity property. static bool isEqual(const Fortran::semantics::Symbol &x, const Fortran::semantics::Symbol &y) { return isEqual(&x, &y); } static bool isEqual(const Fortran::semantics::Symbol *x, const Fortran::semantics::Symbol *y) { return x == y; } template static bool isEqual(const Fortran::common::Indirection &x, const Fortran::common::Indirection &y) { return isEqual(x.value(), y.value()); } template static bool isEqual(const std::optional &x, const std::optional &y) { if (x.has_value() && y.has_value()) return isEqual(x.value(), y.value()); return !x.has_value() && !y.has_value(); } template static bool isEqual(const std::vector &x, const std::vector &y) { if (x.size() != y.size()) return false; const std::size_t size = x.size(); for (std::remove_const_t i = 0; i < size; ++i) if (!isEqual(x[i], y[i])) return false; return true; } static bool isEqual(const Fortran::evaluate::Subscript &x, const Fortran::evaluate::Subscript &y) { return Fortran::common::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } static bool isEqual(const Fortran::evaluate::Triplet &x, const Fortran::evaluate::Triplet &y) { return isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()) && isEqual(x.stride(), y.stride()); } static bool isEqual(const Fortran::evaluate::Component &x, const Fortran::evaluate::Component &y) { return isEqual(x.base(), y.base()) && isEqual(x.GetLastSymbol(), y.GetLastSymbol()); } static bool isEqual(const Fortran::evaluate::ArrayRef &x, const Fortran::evaluate::ArrayRef &y) { return isEqual(x.base(), y.base()) && isEqual(x.subscript(), y.subscript()); } static bool isEqual(const Fortran::evaluate::CoarrayRef &x, const Fortran::evaluate::CoarrayRef &y) { return isEqual(x.base(), y.base()) && isEqual(x.cosubscript(), y.cosubscript()) && isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team()); } static bool 
isEqual(const Fortran::evaluate::NamedEntity &x, const Fortran::evaluate::NamedEntity &y) { if (x.IsSymbol() && y.IsSymbol()) return isEqual(x.GetFirstSymbol(), y.GetFirstSymbol()); return !x.IsSymbol() && !y.IsSymbol() && isEqual(x.GetComponent(), y.GetComponent()); } static bool isEqual(const Fortran::evaluate::DataRef &x, const Fortran::evaluate::DataRef &y) { return Fortran::common::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } static bool isEqual(const Fortran::evaluate::ComplexPart &x, const Fortran::evaluate::ComplexPart &y) { return isEqual(x.complex(), y.complex()) && x.part() == y.part(); } template static bool isEqual(const Fortran::evaluate::Convert &x, const Fortran::evaluate::Convert &y) { return isEqual(x.left(), y.left()); } template static bool isEqual(const Fortran::evaluate::ComplexComponent &x, const Fortran::evaluate::ComplexComponent &y) { return isEqual(x.left(), y.left()) && x.isImaginaryPart == y.isImaginaryPart; } template static bool isEqual(const Fortran::evaluate::Parentheses &x, const Fortran::evaluate::Parentheses &y) { return isEqual(x.left(), y.left()); } template static bool isEqual(const Fortran::evaluate::Negate &x, const Fortran::evaluate::Negate &y) { return isEqual(x.left(), y.left()); } template static bool isBinaryEqual(const A &x, const A &y) { return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); } template static bool isEqual(const Fortran::evaluate::Add &x, const Fortran::evaluate::Add &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::Subtract &x, const Fortran::evaluate::Subtract &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::Multiply &x, const Fortran::evaluate::Multiply &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::Divide &x, const Fortran::evaluate::Divide &y) { return isBinaryEqual(x, y); } template static bool isEqual(const 
Fortran::evaluate::Power &x, const Fortran::evaluate::Power &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::Extremum &x, const Fortran::evaluate::Extremum &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::RealToIntPower &x, const Fortran::evaluate::RealToIntPower &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::ComplexConstructor &x, const Fortran::evaluate::ComplexConstructor &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::Concat &x, const Fortran::evaluate::Concat &y) { return isBinaryEqual(x, y); } template static bool isEqual(const Fortran::evaluate::SetLength &x, const Fortran::evaluate::SetLength &y) { return isBinaryEqual(x, y); } static bool isEqual(const Fortran::semantics::SymbolRef &x, const Fortran::semantics::SymbolRef &y) { return isEqual(x.get(), y.get()); } static bool isEqual(const Fortran::evaluate::Substring &x, const Fortran::evaluate::Substring &y) { return Fortran::common::visit( [&](const auto &p, const auto &q) { return isEqual(p, q); }, x.parent(), y.parent()) && isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()); } static bool isEqual(const Fortran::evaluate::StaticDataObject::Pointer &x, const Fortran::evaluate::StaticDataObject::Pointer &y) { return x->name() == y->name(); } static bool isEqual(const Fortran::evaluate::SpecificIntrinsic &x, const Fortran::evaluate::SpecificIntrinsic &y) { return x.name == y.name; } template static bool isEqual(const Fortran::evaluate::Constant &x, const Fortran::evaluate::Constant &y) { return x == y; } static bool isEqual(const Fortran::evaluate::ActualArgument &x, const Fortran::evaluate::ActualArgument &y) { if (const Fortran::evaluate::Symbol *xs = x.GetAssumedTypeDummy()) { if (const Fortran::evaluate::Symbol *ys = y.GetAssumedTypeDummy()) return isEqual(*xs, *ys); return false; } return !y.GetAssumedTypeDummy() && 
isEqual(*x.UnwrapExpr(), *y.UnwrapExpr()); } static bool isEqual(const Fortran::evaluate::ProcedureDesignator &x, const Fortran::evaluate::ProcedureDesignator &y) { return Fortran::common::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } static bool isEqual(const Fortran::evaluate::ProcedureRef &x, const Fortran::evaluate::ProcedureRef &y) { return isEqual(x.proc(), y.proc()) && isEqual(x.arguments(), y.arguments()); } template static bool isEqual(const Fortran::evaluate::ImpliedDo &x, const Fortran::evaluate::ImpliedDo &y) { return isEqual(x.values(), y.values()) && isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()) && isEqual(x.stride(), y.stride()); } template static bool isEqual(const Fortran::evaluate::ArrayConstructorValues &x, const Fortran::evaluate::ArrayConstructorValues &y) { using Expr = Fortran::evaluate::Expr; using ImpliedDo = Fortran::evaluate::ImpliedDo; for (const auto &[xValue, yValue] : llvm::zip(x, y)) { bool checkElement = Fortran::common::visit( common::visitors{ [&](const Expr &v, const Expr &w) { return isEqual(v, w); }, [&](const ImpliedDo &v, const ImpliedDo &w) { return isEqual(v, w); }, [&](const Expr &, const ImpliedDo &) { return false; }, [&](const ImpliedDo &, const Expr &) { return false; }, }, xValue.u, yValue.u); if (!checkElement) { return false; } } return true; } static bool isEqual(const Fortran::evaluate::SubscriptInteger &x, const Fortran::evaluate::SubscriptInteger &y) { return x == y; } template static bool isEqual(const Fortran::evaluate::ArrayConstructor &x, const Fortran::evaluate::ArrayConstructor &y) { bool checkCharacterType = true; if constexpr (A::category == Fortran::common::TypeCategory::Character) { checkCharacterType = isEqual(*x.LEN(), *y.LEN()); } using Base = Fortran::evaluate::ArrayConstructorValues; return isEqual((Base)x, (Base)y) && (x.GetType() == y.GetType() && checkCharacterType); } static bool isEqual(const Fortran::evaluate::ImpliedDoIndex &x, const 
Fortran::evaluate::ImpliedDoIndex &y) { return toStringRef(x.name) == toStringRef(y.name); } static bool isEqual(const Fortran::evaluate::TypeParamInquiry &x, const Fortran::evaluate::TypeParamInquiry &y) { return isEqual(x.base(), y.base()) && isEqual(x.parameter(), y.parameter()); } static bool isEqual(const Fortran::evaluate::DescriptorInquiry &x, const Fortran::evaluate::DescriptorInquiry &y) { return isEqual(x.base(), y.base()) && x.field() == y.field() && x.dimension() == y.dimension(); } static bool isEqual(const Fortran::evaluate::StructureConstructor &x, const Fortran::evaluate::StructureConstructor &y) { const auto &xValues = x.values(); const auto &yValues = y.values(); if (xValues.size() != yValues.size()) return false; if (x.derivedTypeSpec() != y.derivedTypeSpec()) return false; for (const auto &[xSymbol, xValue] : xValues) { auto yIt = yValues.find(xSymbol); // This should probably never happen, since the derived type // should be the same. if (yIt == yValues.end()) return false; if (!isEqual(xValue, yIt->second)) return false; } return true; } template static bool isEqual(const Fortran::evaluate::Not &x, const Fortran::evaluate::Not &y) { return isEqual(x.left(), y.left()); } template static bool isEqual(const Fortran::evaluate::LogicalOperation &x, const Fortran::evaluate::LogicalOperation &y) { return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); } template static bool isEqual(const Fortran::evaluate::Relational &x, const Fortran::evaluate::Relational &y) { return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); } template static bool isEqual(const Fortran::evaluate::Expr &x, const Fortran::evaluate::Expr &y) { return Fortran::common::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } static bool isEqual(const Fortran::evaluate::Relational &x, const Fortran::evaluate::Relational &y) { return Fortran::common::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } 
template static bool isEqual(const Fortran::evaluate::Designator &x, const Fortran::evaluate::Designator &y) { return Fortran::common::visit( [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); } template static bool isEqual(const Fortran::evaluate::value::Integer &x, const Fortran::evaluate::value::Integer &y) { return x == y; } static bool isEqual(const Fortran::evaluate::NullPointer &x, const Fortran::evaluate::NullPointer &y) { return true; } template , bool> = true> static bool isEqual(const A &, const B &) { return false; } }; unsigned getHashValue(const Fortran::lower::SomeExpr *x) { return HashEvaluateExpr::getHashValue(*x); } unsigned getHashValue(const Fortran::lower::ExplicitIterSpace::ArrayBases &x) { return Fortran::common::visit( [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x); } bool isEqual(const Fortran::lower::SomeExpr *x, const Fortran::lower::SomeExpr *y) { const auto *empty = llvm::DenseMapInfo::getEmptyKey(); const auto *tombstone = llvm::DenseMapInfo::getTombstoneKey(); if (x == empty || y == empty || x == tombstone || y == tombstone) return x == y; return x == y || IsEqualEvaluateExpr::isEqual(*x, *y); } bool isEqual(const Fortran::lower::ExplicitIterSpace::ArrayBases &x, const Fortran::lower::ExplicitIterSpace::ArrayBases &y) { return Fortran::common::visit( Fortran::common::visitors{ // Fortran::semantics::Symbol * are the exception here. These pointers // have identity; if two Symbol * values are the same (different) then // they are the same (different) logical symbol. [&](Fortran::lower::FrontEndSymbol p, Fortran::lower::FrontEndSymbol q) { return p == q; }, [&](const auto *p, const auto *q) { if constexpr (std::is_same_v) { return IsEqualEvaluateExpr::isEqual(*p, *q); } else { // Different subtree types are never equal. 
return false; } }}, x, y); } void copyFirstPrivateSymbol(lower::AbstractConverter &converter, const semantics::Symbol *sym, mlir::OpBuilder::InsertPoint *copyAssignIP) { if (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) || sym->test(semantics::Symbol::Flag::LocalityLocalInit)) converter.copyHostAssociateVar(*sym, copyAssignIP); } template void privatizeSymbol( lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, lower::SymMap &symTable, llvm::SetVector &allPrivatizedSymbols, llvm::SmallSet &mightHaveReadHostSym, const semantics::Symbol *symToPrivatize, OperandsStructType *clauseOps) { constexpr bool isDoConcurrent = std::is_same_v; mlir::OpBuilder::InsertPoint dcIP; if (isDoConcurrent) { dcIP = firOpBuilder.saveInsertionPoint(); firOpBuilder.setInsertionPoint( firOpBuilder.getRegion().getParentOfType()); } const semantics::Symbol *sym = isDoConcurrent ? &symToPrivatize->GetUltimate() : symToPrivatize; const lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(*sym); assert(hsb && "Host symbol box not found"); mlir::Location symLoc = hsb.getAddr().getLoc(); std::string privatizerName = sym->name().ToString() + ".privatizer"; bool emitCopyRegion = symToPrivatize->test(semantics::Symbol::Flag::OmpFirstPrivate) || symToPrivatize->test(semantics::Symbol::Flag::LocalityLocalInit); mlir::Value privVal = hsb.getAddr(); mlir::Type allocType = privVal.getType(); if (!mlir::isa(privVal.getType())) allocType = fir::unwrapRefType(privVal.getType()); if (auto poly = mlir::dyn_cast(allocType)) { if (!mlir::isa(poly.getEleTy()) && emitCopyRegion) TODO(symLoc, "create polymorphic host associated copy"); } // fir.array<> cannot be converted to any single llvm type and fir helpers // are not available in openmp to llvmir translation so we cannot generate // an alloca for a fir.array type there. Get around this by boxing all // arrays. 
if (mlir::isa(allocType)) { hlfir::Entity entity{hsb.getAddr()}; entity = genVariableBox(symLoc, firOpBuilder, entity); privVal = entity.getBase(); allocType = privVal.getType(); } if (mlir::isa(privVal.getType())) { // Boxes should be passed by reference into nested regions: auto oldIP = firOpBuilder.saveInsertionPoint(); firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); auto alloca = fir::AllocaOp::create(firOpBuilder, symLoc, privVal.getType()); firOpBuilder.restoreInsertionPoint(oldIP); fir::StoreOp::create(firOpBuilder, symLoc, privVal, alloca); privVal = alloca; } mlir::Type argType = privVal.getType(); OpType privatizerOp = [&]() { auto moduleOp = firOpBuilder.getModule(); auto uniquePrivatizerName = fir::getTypeAsString( allocType, converter.getKindMap(), converter.mangleName(*sym) + (emitCopyRegion ? "_firstprivate" : "_private")); if (auto existingPrivatizer = moduleOp.lookupSymbol(uniquePrivatizerName)) return existingPrivatizer; mlir::OpBuilder::InsertionGuard guard(firOpBuilder); firOpBuilder.setInsertionPointToStart(moduleOp.getBody()); OpType result; if constexpr (std::is_same_v) { result = OpType::create( firOpBuilder, symLoc, uniquePrivatizerName, allocType, emitCopyRegion ? mlir::omp::DataSharingClauseType::FirstPrivate : mlir::omp::DataSharingClauseType::Private); } else { result = OpType::create(firOpBuilder, symLoc, uniquePrivatizerName, allocType, emitCopyRegion ? fir::LocalitySpecifierType::LocalInit : fir::LocalitySpecifierType::Local); } fir::ExtendedValue symExV = converter.getSymbolExtendedValue(*sym); lower::SymMapScope outerScope(symTable); // Populate the `init` region. // We need to initialize in the following cases: // 1. The allocation was for a derived type which requires initialization // (this can be skipped if it will be initialized anyway by the copy // region, unless the derived type has allocatable components) // 2. The allocation was for any kind of box // 3. 
The allocation was for a boxed character const bool needsInitialization = (Fortran::lower::hasDefaultInitialization(sym->GetUltimate()) && (!emitCopyRegion || hlfir::mayHaveAllocatableComponent(allocType))) || mlir::isa(allocType) || mlir::isa(allocType); if (needsInitialization) { lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol( isDoConcurrent ? symToPrivatize->GetUltimate() : *symToPrivatize); assert(hsb && "Host symbol box not found"); hlfir::Entity entity{hsb.getAddr()}; bool cannotHaveNonDefaultLowerBounds = !entity.mayHaveNonDefaultLowerBounds(); mlir::Region &initRegion = result.getInitRegion(); mlir::Location symLoc = hsb.getAddr().getLoc(); mlir::Block *initBlock = firOpBuilder.createBlock( &initRegion, /*insertPt=*/{}, {argType, argType}, {symLoc, symLoc}); bool emitCopyRegion = symToPrivatize->test(semantics::Symbol::Flag::OmpFirstPrivate) || symToPrivatize->test( Fortran::semantics::Symbol::Flag::LocalityLocalInit); populateByRefInitAndCleanupRegions( converter, symLoc, argType, /*scalarInitValue=*/nullptr, initBlock, result.getInitPrivateArg(), result.getInitMoldArg(), result.getDeallocRegion(), emitCopyRegion ? DeclOperationKind::FirstPrivateOrLocalInit : DeclOperationKind::PrivateOrLocal, symToPrivatize, cannotHaveNonDefaultLowerBounds, isDoConcurrent); // TODO: currently there are false positives from dead uses of the mold // arg if (result.initReadsFromMold()) mightHaveReadHostSym.insert(symToPrivatize); } // Populate the `copy` region if this is a `firstprivate`. if (emitCopyRegion) { mlir::Region ©Region = result.getCopyRegion(); // First block argument corresponding to the original/host value while // second block argument corresponding to the privatized value. 
mlir::Block *copyEntryBlock = firOpBuilder.createBlock( ©Region, /*insertPt=*/{}, {argType, argType}, {symLoc, symLoc}); firOpBuilder.setInsertionPointToEnd(copyEntryBlock); auto addSymbol = [&](unsigned argIdx, const semantics::Symbol *symToMap, bool force = false) { symExV.match( [&](const fir::MutableBoxValue &box) { symTable.addSymbol( *symToMap, fir::substBase(box, copyRegion.getArgument(argIdx)), force); }, [&](const auto &box) { symTable.addSymbol(*symToMap, copyRegion.getArgument(argIdx), force); }); }; addSymbol(0, sym, true); lower::SymMapScope innerScope(symTable); addSymbol(1, symToPrivatize); auto ip = firOpBuilder.saveInsertionPoint(); copyFirstPrivateSymbol(converter, symToPrivatize, &ip); if constexpr (std::is_same_v) { mlir::omp::YieldOp::create( firOpBuilder, hsb.getAddr().getLoc(), symTable.shallowLookupSymbol(*symToPrivatize).getAddr()); } else { fir::YieldOp::create( firOpBuilder, hsb.getAddr().getLoc(), symTable.shallowLookupSymbol(*symToPrivatize).getAddr()); } } return result; }(); if (clauseOps) { clauseOps->privateSyms.push_back(mlir::SymbolRefAttr::get(privatizerOp)); clauseOps->privateVars.push_back(privVal); } if (isDoConcurrent) allPrivatizedSymbols.insert(symToPrivatize); if (isDoConcurrent) firOpBuilder.restoreInsertionPoint(dcIP); } template void privatizeSymbol( lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, lower::SymMap &symTable, llvm::SetVector &allPrivatizedSymbols, llvm::SmallSet &mightHaveReadHostSym, const semantics::Symbol *symToPrivatize, mlir::omp::PrivateClauseOps *clauseOps); template void privatizeSymbol( lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, lower::SymMap &symTable, llvm::SetVector &allPrivatizedSymbols, llvm::SmallSet &mightHaveReadHostSym, const semantics::Symbol *symToPrivatize, fir::LocalitySpecifierOperands *clauseOps); } // end namespace Fortran::lower