diff options
author | Akash Banerjee <Akash.Banerjee@amd.com> | 2023-06-19 12:46:15 +0100 |
---|---|---|
committer | Akash Banerjee <Akash.Banerjee@amd.com> | 2023-06-19 13:09:35 +0100 |
commit | a032dc139ddaf4bfccdc4d2dfe073411118cb7e0 (patch) | |
tree | 07867ee3da3ba603bc9138b2a84323971e5f70aa /llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | |
parent | 7e229217f4215b519b886e7881bae4da3742a7d2 (diff) | |
download | llvm-a032dc139ddaf4bfccdc4d2dfe073411118cb7e0.zip llvm-a032dc139ddaf4bfccdc4d2dfe073411118cb7e0.tar.gz llvm-a032dc139ddaf4bfccdc4d2dfe073411118cb7e0.tar.bz2 |
[MLIR][OpenMP] Refactoring createTargetData in OMPIRBuilder
Key changes:
- Refactor the createTargetData function to make use of the emitOffloadingArrays and emitOffloadingArraysArgument functions to generate code.
- Added a new emitIfClause helper function to allow handling if clauses in a similar fashion to Clang.
- Updated the MLIR side of code to account for changes to createTargetData.
Depends on D149872
Reviewed By: jdoerfert
Differential Revision: https://reviews.llvm.org/D146557
Diffstat (limited to 'llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp')
-rw-r--r-- | llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 214 |
1 files changed, 81 insertions, 133 deletions
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 3897329..776b04d 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4886,18 +4886,9 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) { OMPBuilder.initialize(); F->setName("func"); IRBuilder<> Builder(BB); - OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); - unsigned NumDataOperands = 1; int64_t DeviceID = 2; - struct OpenMPIRBuilder::MapperAllocas MapperAllocas; - SmallVector<uint64_t> MapTypeFlagsTo = {1}; - SmallVector<Constant *> MapNames; - auto *I8PtrTy = Builder.getInt8PtrTy(); - auto *ArrI8PtrTy = ArrayType::get(I8PtrTy, NumDataOperands); - auto *I64Ty = Builder.getInt64Ty(); - auto *ArrI64Ty = ArrayType::get(I64Ty, NumDataOperands); AllocaInst *Val1 = Builder.CreateAlloca(Builder.getInt32Ty(), Builder.getInt64(1)); @@ -4905,44 +4896,34 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) { IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(), F->getEntryBlock().getFirstInsertionPt()); - OMPBuilder.createMapperAllocas(Builder.saveIP(), AllocaIP, NumDataOperands, - MapperAllocas); + llvm::OpenMPIRBuilder::MapInfosTy CombinedInfo; using InsertPointTy = OpenMPIRBuilder::InsertPointTy; - auto ProcessMapOpCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { - Value *DataValue = Val1; - Value *DataPtrBase; - Value *DataPtr; - DataPtrBase = DataValue; - DataPtr = DataValue; - Builder.restoreIP(CodeGenIP); + auto GenMapInfoCB = + [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & { + // Get map clause information. + Builder.restoreIP(codeGenIP); - Value *Null = Constant::getNullValue(DataValue->getType()->getPointerTo()); - Value *SizeGep = - Builder.CreateGEP(DataValue->getType(), Null, Builder.getInt32(1)); - Value *SizePtrToInt = Builder.CreatePtrToInt(SizeGep, I64Ty); - - Value *PtrBaseGEP = - Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase, - {Builder.getInt32(0), Builder.getInt32(0)}); - Value *PtrBaseCast = Builder.CreateBitCast( - PtrBaseGEP, DataPtrBase->getType()->getPointerTo()); - Builder.CreateStore(DataPtrBase, PtrBaseCast); - Value *PtrGEP = - Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args, - {Builder.getInt32(0), Builder.getInt32(0)}); - Value *PtrCast = - Builder.CreateBitCast(PtrGEP, DataPtr->getType()->getPointerTo()); - Builder.CreateStore(DataPtr, PtrCast); - Value *SizeGEP = - Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes, - {Builder.getInt32(0), Builder.getInt32(0)}); - Builder.CreateStore(SizePtrToInt, SizeGEP); + CombinedInfo.BasePointers.emplace_back(Val1); + CombinedInfo.Pointers.emplace_back(Val1); + CombinedInfo.Sizes.emplace_back(Builder.getInt64(4)); + CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(1)); + uint32_t temp; + CombinedInfo.Names.emplace_back( + OMPBuilder.getOrCreateSrcLocStr("unknown", temp)); + return CombinedInfo; }; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + + OMPBuilder.Config.setIsTargetCodegen(true); + + llvm::omp::RuntimeFunction RTLFunc = OMPRTL___tgt_target_data_begin_mapper; Builder.restoreIP(OMPBuilder.createTargetData( - Loc, Builder.saveIP(), MapTypeFlagsTo, MapNames, MapperAllocas, - /* IsBegin= */ true, DeviceID, /* IfCond= */ nullptr, ProcessMapOpCB)); + Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), + /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc)); CallInst *TargetDataCall = dyn_cast<CallInst>(&BB->back()); EXPECT_NE(TargetDataCall, nullptr); @@ -4962,18 +4943,9 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) { OMPBuilder.initialize(); F->setName("func"); IRBuilder<> Builder(BB); - OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); - unsigned NumDataOperands = 1; int64_t DeviceID = 2; - struct OpenMPIRBuilder::MapperAllocas MapperAllocas; - SmallVector<uint64_t> MapTypeFlagsFrom = {2}; - SmallVector<Constant *> MapNames; - auto *I8PtrTy = Builder.getInt8PtrTy(); - auto *ArrI8PtrTy = ArrayType::get(I8PtrTy, NumDataOperands); - auto *I64Ty = Builder.getInt64Ty(); - auto *ArrI64Ty = ArrayType::get(I64Ty, NumDataOperands); AllocaInst *Val1 = Builder.CreateAlloca(Builder.getInt32Ty(), Builder.getInt64(1)); @@ -4981,44 +4953,34 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) { IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(), F->getEntryBlock().getFirstInsertionPt()); - OMPBuilder.createMapperAllocas(Builder.saveIP(), AllocaIP, NumDataOperands, - MapperAllocas); + llvm::OpenMPIRBuilder::MapInfosTy CombinedInfo; using InsertPointTy = OpenMPIRBuilder::InsertPointTy; - auto ProcessMapOpCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { - Value *DataValue = Val1; - Value *DataPtrBase; - Value *DataPtr; - DataPtrBase = DataValue; - DataPtr = DataValue; - Builder.restoreIP(CodeGenIP); + auto GenMapInfoCB = + [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & { + // Get map clause information. + Builder.restoreIP(codeGenIP); - Value *Null = Constant::getNullValue(DataValue->getType()->getPointerTo()); - Value *SizeGep = - Builder.CreateGEP(DataValue->getType(), Null, Builder.getInt32(1)); - Value *SizePtrToInt = Builder.CreatePtrToInt(SizeGep, I64Ty); - - Value *PtrBaseGEP = - Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase, - {Builder.getInt32(0), Builder.getInt32(0)}); - Value *PtrBaseCast = Builder.CreateBitCast( - PtrBaseGEP, DataPtrBase->getType()->getPointerTo()); - Builder.CreateStore(DataPtrBase, PtrBaseCast); - Value *PtrGEP = - Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args, - {Builder.getInt32(0), Builder.getInt32(0)}); - Value *PtrCast = - Builder.CreateBitCast(PtrGEP, DataPtr->getType()->getPointerTo()); - Builder.CreateStore(DataPtr, PtrCast); - Value *SizeGEP = - Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes, - {Builder.getInt32(0), Builder.getInt32(0)}); - Builder.CreateStore(SizePtrToInt, SizeGEP); + CombinedInfo.BasePointers.emplace_back(Val1); + CombinedInfo.Pointers.emplace_back(Val1); + CombinedInfo.Sizes.emplace_back(Builder.getInt64(4)); + CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(2)); + uint32_t temp; + CombinedInfo.Names.emplace_back( + OMPBuilder.getOrCreateSrcLocStr("unknown", temp)); + return CombinedInfo; }; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + + OMPBuilder.Config.setIsTargetCodegen(true); + + llvm::omp::RuntimeFunction RTLFunc = OMPRTL___tgt_target_data_end_mapper; Builder.restoreIP(OMPBuilder.createTargetData( - Loc, Builder.saveIP(), MapTypeFlagsFrom, MapNames, MapperAllocas, - /* IsBegin= */ false, DeviceID, /* IfCond= */ nullptr, ProcessMapOpCB)); + Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), + /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc)); CallInst *TargetDataCall = dyn_cast<CallInst>(&BB->back()); EXPECT_NE(TargetDataCall, nullptr); @@ -5038,18 +5000,9 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) { OMPBuilder.initialize(); F->setName("func"); IRBuilder<> Builder(BB); - OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); - unsigned NumDataOperands = 1; int64_t DeviceID = 2; - struct OpenMPIRBuilder::MapperAllocas MapperAllocas; - SmallVector<uint64_t> MapTypeFlagsToFrom = {3}; - SmallVector<Constant *> MapNames; - auto *I8PtrTy = Builder.getInt8PtrTy(); - auto *ArrI8PtrTy = ArrayType::get(I8PtrTy, NumDataOperands); - auto *I64Ty = Builder.getInt64Ty(); - auto *ArrI64Ty = ArrayType::get(I64Ty, NumDataOperands); AllocaInst *Val1 = Builder.CreateAlloca(Builder.getInt32Ty(), Builder.getInt64(1)); @@ -5057,57 +5010,52 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) { IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(), F->getEntryBlock().getFirstInsertionPt()); - OMPBuilder.createMapperAllocas(Builder.saveIP(), AllocaIP, NumDataOperands, - MapperAllocas); + llvm::OpenMPIRBuilder::MapInfosTy CombinedInfo; using InsertPointTy = OpenMPIRBuilder::InsertPointTy; - auto ProcessMapOpCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { - Value *DataValue = Val1; - Value *DataPtrBase; - Value *DataPtr; - DataPtrBase = DataValue; - DataPtr = DataValue; - Builder.restoreIP(CodeGenIP); + auto GenMapInfoCB = + [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & { + // Get map clause information. + Builder.restoreIP(codeGenIP); - Value *Null = Constant::getNullValue(DataValue->getType()->getPointerTo()); - Value *SizeGep = - Builder.CreateGEP(DataValue->getType(), Null, Builder.getInt32(1)); - Value *SizePtrToInt = Builder.CreatePtrToInt(SizeGep, I64Ty); - - Value *PtrBaseGEP = - Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase, - {Builder.getInt32(0), Builder.getInt32(0)}); - Value *PtrBaseCast = Builder.CreateBitCast( - PtrBaseGEP, DataPtrBase->getType()->getPointerTo()); - Builder.CreateStore(DataPtrBase, PtrBaseCast); - Value *PtrGEP = - Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args, - {Builder.getInt32(0), Builder.getInt32(0)}); - Value *PtrCast = - Builder.CreateBitCast(PtrGEP, DataPtr->getType()->getPointerTo()); - Builder.CreateStore(DataPtr, PtrCast); - Value *SizeGEP = - Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes, - {Builder.getInt32(0), Builder.getInt32(0)}); - Builder.CreateStore(SizePtrToInt, SizeGEP); + CombinedInfo.BasePointers.emplace_back(Val1); + CombinedInfo.Pointers.emplace_back(Val1); + CombinedInfo.Sizes.emplace_back(Builder.getInt64(4)); + CombinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(3)); + uint32_t temp; + CombinedInfo.Names.emplace_back( + OMPBuilder.getOrCreateSrcLocStr("unknown", temp)); + return CombinedInfo; }; - auto BodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) { - Builder.restoreIP(codeGenIP); - auto *SI = Builder.CreateStore(Builder.getInt32(99), Val1); - auto *newBB = SplitBlock(Builder.GetInsertBlock(), SI); - Builder.SetInsertPoint(newBB); - auto *UI = &Builder.GetInsertBlock()->back(); - SplitBlock(Builder.GetInsertBlock(), UI); + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + + OMPBuilder.Config.setIsTargetCodegen(true); + + auto BodyCB = [&](InsertPointTy CodeGenIP, int BodyGenType) { + if (BodyGenType == 3) { + Builder.restoreIP(CodeGenIP); + CallInst *TargetDataCall = dyn_cast<CallInst>(&BB->back()); + EXPECT_NE(TargetDataCall, nullptr); + EXPECT_EQ(TargetDataCall->arg_size(), 9U); + EXPECT_EQ(TargetDataCall->getCalledFunction()->getName(), + "__tgt_target_data_begin_mapper"); + EXPECT_TRUE(TargetDataCall->getOperand(1)->getType()->isIntegerTy(64)); + EXPECT_TRUE(TargetDataCall->getOperand(2)->getType()->isIntegerTy(32)); + EXPECT_TRUE(TargetDataCall->getOperand(8)->getType()->isPointerTy()); + Builder.restoreIP(CodeGenIP); + Builder.CreateStore(Builder.getInt32(99), Val1); + } + return Builder.saveIP(); }; Builder.restoreIP(OMPBuilder.createTargetData( - Loc, Builder.saveIP(), MapTypeFlagsToFrom, MapNames, MapperAllocas, - /* IsBegin= */ false, DeviceID, /* IfCond= */ nullptr, ProcessMapOpCB, - BodyCB)); + Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), + /* IfCond= */ nullptr, Info, GenMapInfoCB, nullptr, BodyCB)); - CallInst *TargetDataCall = - dyn_cast<CallInst>(&Builder.GetInsertBlock()->back()); + CallInst *TargetDataCall = dyn_cast<CallInst>(&BB->back()); EXPECT_NE(TargetDataCall, nullptr); EXPECT_EQ(TargetDataCall->arg_size(), 9U); EXPECT_EQ(TargetDataCall->getCalledFunction()->getName(), |