aboutsummaryrefslogtreecommitdiff
path: root/offload/include/PerThreadTable.h
blob: 45b196171b4c891069bcad1b041bd0d80171ed01 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
//===-- PerThreadTable.h -- PerThread Storage Structure ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Table indexed with one entry per thread.
//
//===----------------------------------------------------------------------===//

#ifndef OFFLOAD_PERTHREADTABLE_H
#define OFFLOAD_PERTHREADTABLE_H

#include <list>
#include <memory>
#include <mutex>
#include <unordered_map>

// Using an STL container (such as std::vector) indexed by thread ID has
// too many race-condition issues, so each thread's entry is kept in
// thread_local storage instead.
// ContainerType is the per-thread container that stores the objects. It must
// be default-constructible and provide an `iterator` type, add(ObjectType),
// erase(iterator), begin(), end() and clear(F). ObjectType is the type of the
// stored objects, e.g. omp_interop_val_t *.

/// Thread-indexed table: each calling thread sees its own ContainerType
/// instance and element count, while clear() visits every thread's entry.
///
/// Ownership: every per-thread record is owned by this table's
/// ThreadDataList. A thread keeps only a non-owning (weak) handle to its
/// record, keyed by the table's address. Distinct tables therefore never share
/// records, and a thread that uses the table after clear() gets a fresh record
/// that is registered again.
///
/// NOTE: add/erase/size/begin/end only touch the calling thread's record.
/// clear() and the destructor touch (and release) every thread's record, so
/// they must not run while other threads are still using the table.
template <typename ContainerType, typename ObjectType> struct PerThreadTable {
  using iterator = typename ContainerType::iterator;

  struct PerThreadData {
    size_t NElements = 0;
    std::unique_ptr<ContainerType> ThEntry;
  };

  // Guards ThreadDataList.
  std::mutex Mtx;
  // Sole owner of the records of every thread registered with this table.
  std::list<std::shared_ptr<PerThreadData>> ThreadDataList;

  // Default-constructible. Neither copyable nor movable, because threads find
  // their records through the table's address.
  PerThreadTable() = default;
  PerThreadTable(const PerThreadTable &) = delete;
  PerThreadTable(PerThreadTable &&) = delete;
  PerThreadTable &operator=(const PerThreadTable &) = delete;
  PerThreadTable &operator=(PerThreadTable &&) = delete;
  ~PerThreadTable() {
    std::lock_guard<std::mutex> Lock(Mtx);
    ThreadDataList.clear();
  }

private:
  // Returns the calling thread's record for *this* table. The record is
  // created and registered on first use, and again after clear() released it.
  PerThreadData &getThreadData() {
    // A function-local static is shared by all tables of this specialization,
    // so records must be keyed by table address. The handles are weak: the
    // table's list owns the records, so clear() and destruction expire them.
    static thread_local std::unordered_map<const PerThreadTable *,
                                           std::weak_ptr<PerThreadData>>
        ThreadDataMap;

    auto It = ThreadDataMap.find(this);
    if (It != ThreadDataMap.end()) {
      if (auto Live = It->second.lock())
        return *Live; // Still kept alive by ThreadDataList.
    }

    // Drop handles to records released by cleared or destroyed tables, so the
    // map does not grow without bound on long-lived threads.
    for (auto I = ThreadDataMap.begin(); I != ThreadDataMap.end();) {
      if (I->second.expired())
        I = ThreadDataMap.erase(I);
      else
        ++I;
    }

    auto ThData = std::make_shared<PerThreadData>();
    ThreadDataMap[this] = ThData;
    std::lock_guard<std::mutex> Lock(Mtx);
    ThreadDataList.push_back(ThData);
    return *ThData;
  }

protected:
  // Returns the calling thread's container, allocating it on first use.
  ContainerType &getThreadEntry() {
    auto &ThData = getThreadData();
    if (!ThData.ThEntry)
      ThData.ThEntry = std::make_unique<ContainerType>();
    return *ThData.ThEntry;
  }

  // Returns a reference to the calling thread's element count.
  size_t &getThreadNElements() { return getThreadData().NElements; }

public:
  // Adds obj to the calling thread's container.
  void add(ObjectType obj) {
    auto &Entry = getThreadEntry();
    Entry.add(obj);
    // Count only after the container accepted the object, so a throwing
    // add() cannot leave the count out of sync.
    ++getThreadNElements();
  }

  // Erases it from the calling thread's container and returns the iterator
  // produced by ContainerType::erase.
  iterator erase(iterator it) {
    auto &Entry = getThreadEntry();
    iterator Next = Entry.erase(it);
    --getThreadNElements();
    return Next;
  }

  // Number of elements owned by the calling thread.
  size_t size() { return getThreadNElements(); }

  // Iterators to traverse objects owned by the calling thread.
  iterator begin() { return getThreadEntry().begin(); }
  iterator end() { return getThreadEntry().end(); }

  // Passes f to every registered thread's ContainerType::clear, then releases
  // all records. Threads that use the table afterwards start with a fresh,
  // newly registered record. See the class NOTE on concurrency.
  template <class F> void clear(F f) {
    std::lock_guard<std::mutex> Lock(Mtx);
    for (const auto &ThData : ThreadDataList) {
      if (!ThData->ThEntry || ThData->NElements == 0)
        continue;
      ThData->ThEntry->clear(f);
      ThData->NElements = 0;
    }
    ThreadDataList.clear();
  }
};

#endif