diff options
Diffstat (limited to 'mlir/lib/Dialect/LLVMIR')
| -rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 82 |
1 files changed, 72 insertions, 10 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 262d9b7..d43f881 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1752,15 +1752,21 @@ std::string NVVM::MBarrierInitOp::getPtx() { // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// +static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) { + auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType()); + return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS); +} + +static bool isPtrInSharedCTASpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared); +} + mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::MBarrierInitOp>(op); - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) - .getAddressSpace(); - llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) - ? llvm::Intrinsic::nvvm_mbarrier_init_shared - : llvm::Intrinsic::nvvm_mbarrier_init; + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared + : llvm::Intrinsic::nvvm_mbarrier_init; // Fill the Intrinsic Args llvm::SmallVector<llvm::Value *> args; @@ -1773,16 +1779,72 @@ mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::MBarrierInvalOp>(op); - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) - .getAddressSpace(); - llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_inval_shared : llvm::Intrinsic::nvvm_mbarrier_inval; return {id, {mt.lookupValue(thisOp.getAddr())}}; } +mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared + ? llvm::Intrinsic::nvvm_mbarrier_arrive_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive; + + return {id, {mt.lookupValue(thisOp.getAddr())}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = + isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete; + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared + ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared + : llvm::Intrinsic::nvvm_mbarrier_test_wait; + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getState())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + + llvm::Intrinsic::ID id; + if (thisOp.getNoinc()) { + id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared + : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc; + } else { + id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared + : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive; + } + + return {id, {mt.lookupValue(thisOp.getAddr())}}; +} + #define CP_ASYNC_ID_IMPL(mod, size, suffix) \ llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix |
