aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Frontend/OpenMP
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Frontend/OpenMP')
-rw-r--r--llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp43
1 files changed, 27 insertions, 16 deletions
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index fff9a81..18a4f0a 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -530,7 +530,13 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
auto Int32Ty = Type::getInt32Ty(Builder.getContext());
constexpr size_t MaxDim = 3;
Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, MaxDim));
- Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);
+
+ Value *HasNoWaitFlag = Builder.getInt64(KernelArgs.HasNoWait);
+
+ Value *DynCGroupMemFallbackFlag =
+ Builder.getInt64(static_cast<uint64_t>(KernelArgs.DynCGroupMemFallback));
+ DynCGroupMemFallbackFlag = Builder.CreateShl(DynCGroupMemFallbackFlag, 2);
+ Value *Flags = Builder.CreateOr(HasNoWaitFlag, DynCGroupMemFallbackFlag);
assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());
@@ -559,7 +565,7 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
Flags,
NumTeams3D,
NumThreads3D,
- KernelArgs.DynCGGroupMem};
+ KernelArgs.DynCGroupMem};
}
void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
@@ -8224,7 +8230,8 @@ static void emitTargetCall(
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
- bool HasNoWait) {
+ bool HasNoWait, Value *DynCGroupMem,
+ OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
// Generate a function call to the host fallback implementation of the target
// region. This is called by the host when no offload entry was generated for
// the target region and when the offloading call fails at runtime.
@@ -8360,12 +8367,13 @@ static void emitTargetCall(
/*isSigned=*/false)
: Builder.getInt64(0);
- // TODO: Use correct DynCGGroupMem
- Value *DynCGGroupMem = Builder.getInt32(0);
+ // Request zero groupprivate bytes by default.
+ if (!DynCGroupMem)
+ DynCGroupMem = Builder.getInt32(0);
- KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
- NumTeamsC, NumThreadsC,
- DynCGGroupMem, HasNoWait);
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(
+ NumTargetItems, RTArgs, TripCount, NumTeamsC, NumThreadsC, DynCGroupMem,
+ HasNoWait, DynCGroupMemFallback);
// Assume no error was returned because TaskBodyCB and
// EmitTargetCallFallbackCB don't produce any.
@@ -8414,7 +8422,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
CustomMapperCallbackTy CustomMapperCB,
- const SmallVector<DependData> &Dependencies, bool HasNowait) {
+ const SmallVector<DependData> &Dependencies, bool HasNowait,
+ Value *DynCGroupMem, OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -8437,7 +8446,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
if (!Config.isTargetDevice())
emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
- CustomMapperCB, Dependencies, HasNowait);
+ CustomMapperCB, Dependencies, HasNowait, DynCGroupMem,
+ DynCGroupMemFallback);
return Builder.saveIP();
}
@@ -8460,9 +8470,8 @@ OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
Config.separator());
}
-GlobalVariable *
-OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
- unsigned AddressSpace) {
+GlobalVariable *OpenMPIRBuilder::getOrCreateInternalVariable(
+ Type *Ty, const StringRef &Name, std::optional<unsigned> AddressSpace) {
auto &Elem = *InternalVars.try_emplace(Name, nullptr).first;
if (Elem.second) {
assert(Elem.second->getValueType() == Ty &&
@@ -8472,16 +8481,18 @@ OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
// variable for possibly changing that to internal or private, or maybe
// create different versions of the function for different OMP internal
// variables.
+ const DataLayout &DL = M.getDataLayout();
+ unsigned AddressSpaceVal =
+ AddressSpace ? *AddressSpace : DL.getDefaultGlobalsAddressSpace();
auto Linkage = this->M.getTargetTriple().getArch() == Triple::wasm32
? GlobalValue::InternalLinkage
: GlobalValue::CommonLinkage;
auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
Constant::getNullValue(Ty), Elem.first(),
/*InsertBefore=*/nullptr,
- GlobalValue::NotThreadLocal, AddressSpace);
- const DataLayout &DL = M.getDataLayout();
+ GlobalValue::NotThreadLocal, AddressSpaceVal);
const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
- const llvm::Align PtrAlign = DL.getPointerABIAlignment(AddressSpace);
+ const llvm::Align PtrAlign = DL.getPointerABIAlignment(AddressSpaceVal);
GV->setAlignment(std::max(TypeAlign, PtrAlign));
Elem.second = GV;
}