aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Lower/OpenACC.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower/OpenACC.cpp')
-rw-r--r--flang/lib/Lower/OpenACC.cpp268
1 files changed, 206 insertions, 62 deletions
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index cfb1891..af4f420 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -30,10 +30,12 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"
+#include "flang/Parser/tools.h"
#include "flang/Semantics/expression.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LLVM.h"
@@ -296,7 +298,7 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
if (const auto *designator =
std::get_if<Fortran::parser::Designator>(&accObject.u)) {
if (const auto *name =
- Fortran::semantics::getDesignatorNameIfDataRef(*designator))
+ Fortran::parser::GetDesignatorNameIfDataRef(*designator))
return *name->symbol;
if (const auto *arrayElement =
Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
@@ -327,7 +329,8 @@ genAtomicCaptureStatement(Fortran::lower::AbstractConverter &converter,
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::acc::AtomicReadOp::create(firOpBuilder, loc, fromAddress, toAddress,
- mlir::TypeAttr::get(elementType));
+ mlir::TypeAttr::get(elementType),
+ /*ifCond=*/mlir::Value{});
}
/// Used to generate atomic.write operation which is created in existing
@@ -347,7 +350,8 @@ genAtomicWriteStatement(Fortran::lower::AbstractConverter &converter,
rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr);
firOpBuilder.restoreInsertionPoint(insertionPoint);
- mlir::acc::AtomicWriteOp::create(firOpBuilder, loc, lhsAddr, rhsExpr);
+ mlir::acc::AtomicWriteOp::create(firOpBuilder, loc, lhsAddr, rhsExpr,
+ /*ifCond=*/mlir::Value{});
}
/// Used to generate atomic.update operation which is created in existing
@@ -463,7 +467,8 @@ static inline void genAtomicUpdateStatement(
mlir::Operation *atomicUpdateOp = nullptr;
atomicUpdateOp =
- mlir::acc::AtomicUpdateOp::create(firOpBuilder, currentLocation, lhsAddr);
+ mlir::acc::AtomicUpdateOp::create(firOpBuilder, currentLocation, lhsAddr,
+ /*ifCond=*/mlir::Value{});
llvm::SmallVector<mlir::Type> varTys = {varType};
llvm::SmallVector<mlir::Location> locs = {currentLocation};
@@ -588,7 +593,9 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
mlir::Operation *atomicCaptureOp = nullptr;
- atomicCaptureOp = mlir::acc::AtomicCaptureOp::create(firOpBuilder, loc);
+ atomicCaptureOp =
+ mlir::acc::AtomicCaptureOp::create(firOpBuilder, loc,
+ /*ifCond=*/mlir::Value{});
firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0)));
mlir::Block &block = atomicCaptureOp->getRegion(0).back();
@@ -711,6 +718,84 @@ static void genDataOperandOperations(
}
}
+template <typename GlobalCtorOrDtorOp, typename EntryOp, typename DeclareOp,
+ typename ExitOp>
+static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
+ fir::FirOpBuilder &builder,
+ mlir::Location loc, fir::GlobalOp globalOp,
+ mlir::acc::DataClause clause,
+ const std::string &declareGlobalName,
+ bool implicit, std::stringstream &asFortran) {
+ GlobalCtorOrDtorOp declareGlobalOp =
+ GlobalCtorOrDtorOp::create(modBuilder, loc, declareGlobalName);
+ builder.createBlock(&declareGlobalOp.getRegion(),
+ declareGlobalOp.getRegion().end(), {}, {});
+ builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back());
+
+ fir::AddrOfOp addrOp = fir::AddrOfOp::create(
+ builder, loc, fir::ReferenceType::get(globalOp.getType()),
+ globalOp.getSymbol());
+ addDeclareAttr(builder, addrOp, clause);
+
+ llvm::SmallVector<mlir::Value> bounds;
+ EntryOp entryOp = createDataEntryOp<EntryOp>(
+ builder, loc, addrOp.getResTy(), asFortran, bounds,
+ /*structured=*/false, implicit, clause, addrOp.getResTy().getType(),
+ /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
+ if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>)
+ DeclareOp::create(builder, loc,
+ mlir::acc::DeclareTokenType::get(entryOp.getContext()),
+ mlir::ValueRange(entryOp.getAccVar()));
+ else
+ DeclareOp::create(builder, loc, mlir::Value{},
+ mlir::ValueRange(entryOp.getAccVar()));
+ if constexpr (std::is_same_v<GlobalCtorOrDtorOp,
+ mlir::acc::GlobalDestructorOp>) {
+ if constexpr (std::is_same_v<ExitOp, mlir::acc::DeclareLinkOp>) {
+ // No destructor emission for declare link in this path to avoid
+ // complex var/varType/varPtrPtr signatures. The ctor registers the link.
+ } else if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
+ std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>) {
+ ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(),
+ entryOp.getVar(), entryOp.getVarType(),
+ entryOp.getBounds(), entryOp.getAsyncOperands(),
+ entryOp.getAsyncOperandsDeviceTypeAttr(),
+ entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
+ /*structured=*/false, /*implicit=*/false,
+ builder.getStringAttr(*entryOp.getName()));
+ } else {
+ ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(),
+ entryOp.getBounds(), entryOp.getAsyncOperands(),
+ entryOp.getAsyncOperandsDeviceTypeAttr(),
+ entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
+ /*structured=*/false, /*implicit=*/false,
+ builder.getStringAttr(*entryOp.getName()));
+ }
+ }
+ mlir::acc::TerminatorOp::create(builder, loc);
+ modBuilder.setInsertionPointAfter(declareGlobalOp);
+}
+
+template <typename EntryOp, typename ExitOp>
+static void
+emitCtorDtorPair(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder,
+ mlir::Location operandLocation, fir::GlobalOp globalOp,
+ mlir::acc::DataClause clause, std::stringstream &asFortran,
+ const std::string &ctorName) {
+ createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, EntryOp,
+ mlir::acc::DeclareEnterOp, ExitOp>(
+ modBuilder, builder, operandLocation, globalOp, clause, ctorName,
+ /*implicit=*/false, asFortran);
+
+ std::stringstream dtorName;
+ dtorName << globalOp.getSymName().str() << "_acc_dtor";
+ createDeclareGlobalOp<mlir::acc::GlobalDestructorOp,
+ mlir::acc::GetDevicePtrOp, mlir::acc::DeclareExitOp,
+ ExitOp>(modBuilder, builder, operandLocation, globalOp,
+ clause, dtorName.str(),
+ /*implicit=*/false, asFortran);
+}
+
template <typename EntryOp, typename ExitOp>
static void genDeclareDataOperandOperations(
const Fortran::parser::AccObjectList &objectList,
@@ -726,6 +811,37 @@ static void genDeclareDataOperandOperations(
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
+ // Handle COMMON/global symbols via module-level ctor/dtor path.
+ if (symbol.detailsIf<Fortran::semantics::CommonBlockDetails>() ||
+ Fortran::semantics::FindCommonBlockContaining(symbol)) {
+ emitCommonGlobal(
+ converter, builder, accObject, dataClause,
+ [&](mlir::OpBuilder &modBuilder, mlir::Location loc,
+ fir::GlobalOp globalOp, mlir::acc::DataClause clause,
+ std::stringstream &asFortranStr, const std::string &ctorName) {
+ if constexpr (std::is_same_v<EntryOp, mlir::acc::DeclareLinkOp>) {
+ createDeclareGlobalOp<
+ mlir::acc::GlobalConstructorOp, mlir::acc::DeclareLinkOp,
+ mlir::acc::DeclareEnterOp, mlir::acc::DeclareLinkOp>(
+ modBuilder, builder, loc, globalOp, clause, ctorName,
+ /*implicit=*/false, asFortranStr);
+ } else if constexpr (std::is_same_v<EntryOp, mlir::acc::CreateOp> ||
+ std::is_same_v<EntryOp, mlir::acc::CopyinOp> ||
+ std::is_same_v<
+ EntryOp,
+ mlir::acc::DeclareDeviceResidentOp> ||
+ std::is_same_v<ExitOp, mlir::acc::CopyoutOp>) {
+ emitCtorDtorPair<EntryOp, ExitOp>(modBuilder, builder, loc,
+ globalOp, clause, asFortranStr,
+ ctorName);
+ } else {
+ // No module-level ctor/dtor for this clause (e.g., deviceptr,
+ // present). Handled via structured declare region only.
+ return;
+ }
+ });
+ continue;
+ }
Fortran::semantics::MaybeExpr designator = Fortran::common::visit(
[&](auto &&s) { return ea.Analyze(s); }, accObject.u);
fir::factory::AddrAndBoundsInfo info =
@@ -2712,12 +2828,19 @@ genACC(Fortran::lower::AbstractConverter &converter,
const auto &loopDirective =
std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
+ mlir::Location currentLocation =
+ converter.genLocation(beginLoopDirective.source);
bool needEarlyExitHandling = false;
- if (eval.lowerAsUnstructured())
+ if (eval.lowerAsUnstructured()) {
needEarlyExitHandling = hasEarlyReturn(eval);
+ // If the loop is lowered in an unstructured fashion, lowering generates
+ // explicit control flow that duplicates the looping semantics of the
+ // loops.
+ if (!needEarlyExitHandling)
+ TODO(currentLocation,
+ "loop with early exit inside OpenACC loop construct");
+ }
- mlir::Location currentLocation =
- converter.genLocation(beginLoopDirective.source);
Fortran::lower::StatementContext stmtCtx;
assert(loopDirective.v == llvm::acc::ACCD_loop &&
@@ -2900,7 +3023,7 @@ static Op createComputeOp(
if (const auto *designator =
std::get_if<Fortran::parser::Designator>(&accObject.u)) {
if (const auto *name =
- Fortran::semantics::getDesignatorNameIfDataRef(
+ Fortran::parser::GetDesignatorNameIfDataRef(
*designator)) {
auto cond = converter.getSymbolAddress(*name->symbol);
selfCond = builder.createConvert(clauseLocation,
@@ -3516,6 +3639,10 @@ genACC(Fortran::lower::AbstractConverter &converter,
converter.genLocation(beginCombinedDirective.source);
Fortran::lower::StatementContext stmtCtx;
+ if (eval.lowerAsUnstructured())
+ TODO(currentLocation,
+ "loop with early exit inside OpenACC combined construct");
+
if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
createComputeOp<mlir::acc::KernelsOp>(
converter, currentLocation, eval, semanticsContext, stmtCtx,
@@ -4080,49 +4207,6 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
waitOp.setAsyncAttr(firOpBuilder.getUnitAttr());
}
-template <typename GlobalOp, typename EntryOp, typename DeclareOp,
- typename ExitOp>
-static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
- fir::FirOpBuilder &builder,
- mlir::Location loc, fir::GlobalOp globalOp,
- mlir::acc::DataClause clause,
- const std::string &declareGlobalName,
- bool implicit, std::stringstream &asFortran) {
- GlobalOp declareGlobalOp =
- GlobalOp::create(modBuilder, loc, declareGlobalName);
- builder.createBlock(&declareGlobalOp.getRegion(),
- declareGlobalOp.getRegion().end(), {}, {});
- builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back());
-
- fir::AddrOfOp addrOp = fir::AddrOfOp::create(
- builder, loc, fir::ReferenceType::get(globalOp.getType()),
- globalOp.getSymbol());
- addDeclareAttr(builder, addrOp, clause);
-
- llvm::SmallVector<mlir::Value> bounds;
- EntryOp entryOp = createDataEntryOp<EntryOp>(
- builder, loc, addrOp.getResTy(), asFortran, bounds,
- /*structured=*/false, implicit, clause, addrOp.getResTy().getType(),
- /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
- if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>)
- DeclareOp::create(builder, loc,
- mlir::acc::DeclareTokenType::get(entryOp.getContext()),
- mlir::ValueRange(entryOp.getAccVar()));
- else
- DeclareOp::create(builder, loc, mlir::Value{},
- mlir::ValueRange(entryOp.getAccVar()));
- if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) {
- ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(),
- entryOp.getBounds(), entryOp.getAsyncOperands(),
- entryOp.getAsyncOperandsDeviceTypeAttr(),
- entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
- /*structured=*/false, /*implicit=*/false,
- builder.getStringAttr(*entryOp.getName()));
- }
- mlir::acc::TerminatorOp::create(builder, loc);
- modBuilder.setInsertionPointAfter(declareGlobalOp);
-}
-
template <typename EntryOp>
static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
fir::FirOpBuilder &builder,
@@ -4261,8 +4345,7 @@ static void genGlobalCtors(Fortran::lower::AbstractConverter &converter,
Fortran::common::visitors{
[&](const Fortran::parser::Designator &designator) {
if (const auto *name =
- Fortran::semantics::getDesignatorNameIfDataRef(
- designator)) {
+ Fortran::parser::GetDesignatorNameIfDataRef(designator)) {
genCtors(operandLocation, *name->symbol);
}
},
@@ -4300,6 +4383,66 @@ genGlobalCtorsWithModifier(Fortran::lower::AbstractConverter &converter,
dataClause);
}
+static fir::GlobalOp
+lookupGlobalBySymbolOrEquivalence(Fortran::lower::AbstractConverter &converter,
+ fir::FirOpBuilder &builder,
+ const Fortran::semantics::Symbol &sym) {
+ const Fortran::semantics::Symbol *commonBlock =
+ Fortran::semantics::FindCommonBlockContaining(sym);
+ std::string globalName = commonBlock ? converter.mangleName(*commonBlock)
+ : converter.mangleName(sym);
+ if (fir::GlobalOp g = builder.getNamedGlobal(globalName)) {
+ return g;
+ }
+ // Not found: if not a COMMON member, try equivalence members
+ if (!commonBlock) {
+ if (const Fortran::semantics::EquivalenceSet *eqSet =
+ Fortran::semantics::FindEquivalenceSet(sym)) {
+ for (const Fortran::semantics::EquivalenceObject &eqObj : *eqSet) {
+ std::string eqName = converter.mangleName(eqObj.symbol);
+ if (fir::GlobalOp g = builder.getNamedGlobal(eqName))
+ return g;
+ }
+ }
+ }
+ return {};
+}
+
+template <typename EmitterFn>
+static void emitCommonGlobal(Fortran::lower::AbstractConverter &converter,
+ fir::FirOpBuilder &builder,
+ const Fortran::parser::AccObject &obj,
+ mlir::acc::DataClause clause,
+ EmitterFn &&emitCtorDtor) {
+ Fortran::semantics::Symbol &sym = getSymbolFromAccObject(obj);
+ if (!(sym.detailsIf<Fortran::semantics::CommonBlockDetails>() ||
+ Fortran::semantics::FindCommonBlockContaining(sym)))
+ return;
+
+ fir::GlobalOp globalOp =
+ lookupGlobalBySymbolOrEquivalence(converter, builder, sym);
+ if (!globalOp)
+ llvm::report_fatal_error("could not retrieve global symbol");
+
+ std::stringstream ctorName;
+ ctorName << globalOp.getSymName().str() << "_acc_ctor";
+ if (builder.getModule().lookupSymbol<mlir::acc::GlobalConstructorOp>(
+ ctorName.str()))
+ return;
+
+ mlir::Location operandLocation = genOperandLocation(converter, obj);
+ addDeclareAttr(builder, globalOp.getOperation(), clause);
+ mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
+ modBuilder.setInsertionPointAfter(globalOp);
+ std::stringstream asFortran;
+ asFortran << sym.name().ToString();
+
+ auto savedIP = builder.saveInsertionPoint();
+ emitCtorDtor(modBuilder, operandLocation, globalOp, clause, asFortran,
+ ctorName.str());
+ builder.restoreInsertionPoint(savedIP);
+}
+
static void
genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
@@ -4325,11 +4468,9 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
dataClauseOperands.end());
} else if (const auto *createClause =
std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
- const Fortran::parser::AccObjectListWithModifier &listWithModifier =
- createClause->v;
- const auto &accObjectList =
- std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
auto crtDataStart = dataClauseOperands.size();
+ const auto &accObjectList =
+ std::get<Fortran::parser::AccObjectList>(createClause->v.t);
genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
accObjectList, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_create,
@@ -4361,11 +4502,9 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
} else if (const auto *copyoutClause =
std::get_if<Fortran::parser::AccClause::Copyout>(
&clause.u)) {
- const Fortran::parser::AccObjectListWithModifier &listWithModifier =
- copyoutClause->v;
- const auto &accObjectList =
- std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
auto crtDataStart = dataClauseOperands.size();
+ const auto &accObjectList =
+ std::get<Fortran::parser::AccObjectList>(copyoutClause->v.t);
genDeclareDataOperandOperations<mlir::acc::CreateOp,
mlir::acc::CopyoutOp>(
accObjectList, converter, semanticsContext, stmtCtx,
@@ -4406,6 +4545,11 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
}
}
+ // If no structured operands were generated (all objects were COMMON),
+ // do not create a declare region.
+ if (dataClauseOperands.empty())
+ return;
+
mlir::func::FuncOp funcOp = builder.getFunction();
auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>();
mlir::Value declareToken;