//===-- PerThreadTable.h -- PerThread Storage Structure ----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Table indexed with one entry per thread. // //===----------------------------------------------------------------------===// #ifndef OFFLOAD_PERTHREADTABLE_H #define OFFLOAD_PERTHREADTABLE_H #include #include #include #include #include #include template class PerThread { std::mutex Mutex; llvm::SmallVector> ThreadDataList; ObjectType &getThreadData() { static thread_local std::shared_ptr ThreadData = nullptr; if (!ThreadData) { ThreadData = std::make_shared(); std::lock_guard Lock(Mutex); ThreadDataList.push_back(ThreadData); } return *ThreadData; } public: // Define default constructors, disable copy and move constructors. PerThread() = default; PerThread(const PerThread &) = delete; PerThread(PerThread &&) = delete; PerThread &operator=(const PerThread &) = delete; PerThread &operator=(PerThread &&) = delete; ~PerThread() { assert(Mutex.try_lock() && (Mutex.unlock(), true) && "Cannot be deleted while other threads are adding entries"); ThreadDataList.clear(); } ObjectType &get() { return getThreadData(); } template void clear(ClearFuncTy ClearFunc) { assert(Mutex.try_lock() && (Mutex.unlock(), true) && "Clear cannot be called while other threads are adding entries"); for (std::shared_ptr ThreadData : ThreadDataList) { if (!ThreadData) continue; ClearFunc(*ThreadData); } ThreadDataList.clear(); } }; template struct ContainerConcepts { template class, typename = std::void_t<>> struct has : std::false_type {}; template class Op> struct has>> : std::true_type {}; template using IteratorTypeCheck = typename Ty::iterator; template using MappedTypeCheck = typename Ty::mapped_type; template using ValueTypeCheck = typename Ty::value_type; template using KeyTypeCheck = typename Ty::key_type; template using SizeTypeCheck = typename Ty::size_type; template using ClearCheck = decltype(std::declval().clear()); template using ReserveCheck = decltype(std::declval().reserve(1)); template using ResizeCheck = decltype(std::declval().resize(1)); static constexpr bool hasIterator = has::value; static constexpr bool hasClear = has::value; static constexpr bool isAssociative = has::value; static constexpr bool hasReserve = has::value; static constexpr bool hasResize = has::value; template class, typename = std::void_t<>> struct has_type { using type = void; }; template class Op> struct has_type>> { using type = Op; }; using iterator = typename has_type::type; using value_type = typename std::conditional_t< isAssociative, typename has_type::type, typename has_type::type>; using key_type = typename std::conditional_t< isAssociative, typename has_type::type, typename has_type::type>; }; // Using an STL container (such as std::vector) indexed by thread ID has // too many race conditions issues so we store each thread entry into a // thread_local variable. // ContainerType is the container type used to store the objects, e.g., // std::vector, std::set, etc. by each thread. ObjectType is the type of the // stored objects e.g., omp_interop_val_t *, ... template class PerThreadTable { using iterator = typename ContainerConcepts::iterator; struct PerThreadData { size_t Size = 0; std::unique_ptr ThreadEntry; }; std::mutex Mutex; llvm::SmallVector> ThreadDataList; PerThreadData &getThreadData() { static thread_local std::shared_ptr ThreadData = nullptr; if (!ThreadData) { ThreadData = std::make_shared(); std::lock_guard Lock(Mutex); ThreadDataList.push_back(ThreadData); } return *ThreadData; } protected: ContainerType &getThreadEntry() { PerThreadData &ThreadData = getThreadData(); if (ThreadData.ThreadEntry) return *ThreadData.ThreadEntry; ThreadData.ThreadEntry = std::make_unique(); return *ThreadData.ThreadEntry; } size_t &getThreadSize() { PerThreadData &ThreadData = getThreadData(); return ThreadData.Size; } void setSize(size_t Size) { size_t &SizeRef = getThreadSize(); SizeRef = Size; } public: // define default constructors, disable copy and move constructors. PerThreadTable() = default; PerThreadTable(const PerThreadTable &) = delete; PerThreadTable(PerThreadTable &&) = delete; PerThreadTable &operator=(const PerThreadTable &) = delete; PerThreadTable &operator=(PerThreadTable &&) = delete; ~PerThreadTable() { assert(Mutex.try_lock() && (Mutex.unlock(), true) && "Cannot be deleted while other threads are adding entries"); ThreadDataList.clear(); } void add(ObjectType obj) { ContainerType &Entry = getThreadEntry(); size_t &SizeRef = getThreadSize(); SizeRef++; Entry.add(obj); } iterator erase(iterator it) { ContainerType &Entry = getThreadEntry(); size_t &SizeRef = getThreadSize(); SizeRef--; return Entry.erase(it); } size_t size() { return getThreadSize(); } // Iterators to traverse objects owned by // the current thread. iterator begin() { ContainerType &Entry = getThreadEntry(); return Entry.begin(); } iterator end() { ContainerType &Entry = getThreadEntry(); return Entry.end(); } template void clear(ClearFuncTy ClearFunc) { assert(Mutex.try_lock() && (Mutex.unlock(), true) && "Clear cannot be called while other threads are adding entries"); for (std::shared_ptr ThreadData : ThreadDataList) { if (!ThreadData->ThreadEntry || ThreadData->Size == 0) continue; if constexpr (ContainerConcepts::hasIterator && ContainerConcepts::hasClear) { for (auto &Obj : *ThreadData->ThreadEntry) { if constexpr (ContainerConcepts::isAssociative) { ClearFunc(Obj.second); } else { ClearFunc(Obj); } } ThreadData->ThreadEntry->clear(); } else { static_assert(true, "Container type not supported"); } ThreadData->Size = 0; } ThreadDataList.clear(); } template llvm::Error deinit(DeinitFuncTy DeinitFunc) { assert(Mutex.try_lock() && (Mutex.unlock(), true) && "Deinit cannot be called while other threads are adding entries"); for (std::shared_ptr ThreadData : ThreadDataList) { if (!ThreadData->ThreadEntry || ThreadData->Size == 0) continue; for (auto &Obj : *ThreadData->ThreadEntry) { if constexpr (ContainerConcepts::isAssociative) { if (auto Err = DeinitFunc(Obj.second)) return Err; } else { if (auto Err = DeinitFunc(Obj)) return Err; } } } return llvm::Error::success(); } }; template class PerThreadContainer : public PerThreadTable::value_type> { using IndexType = typename ContainerConcepts::key_type; using ObjectType = typename ContainerConcepts::value_type; public: // Get the object for the given index in the current thread. ObjectType &get(IndexType Index) { ContainerType &Entry = this->getThreadEntry(); // Specialized code for vector-like containers. if constexpr (ContainerConcepts::hasResize) { if (Index >= Entry.size()) { if constexpr (ContainerConcepts::hasReserve && ReserveSize > 0) Entry.reserve(ReserveSize); // If the index is out of bounds, try resize the container. Entry.resize(Index + 1); } } ObjectType &Ret = Entry[Index]; this->setSize(Entry.size()); return Ret; } }; #endif