aboutsummaryrefslogtreecommitdiff
path: root/offload/plugins-nextgen/level_zero/include/L0Plugin.h
blob: cd964a0d468907ea5431a711a39fab9fd29bce85 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
//===--- Level Zero Target RTL Implementation -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Plugin interface for SPIR-V/Xe machine.
//
//===----------------------------------------------------------------------===//

#ifndef OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_LEVEL_ZERO_L0PLUGIN_H
#define OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_LEVEL_ZERO_L0PLUGIN_H

#include "AsyncQueue.h"
#include "L0Defs.h"
#include "L0Device.h"
#include "L0Memory.h"
#include "L0Options.h"
#include "L0Program.h"
#include "TLS.h"

namespace llvm::omp::target::plugin {

/// Class implementing the LevelZero specific functionalities of the plugin.
class LevelZeroPluginTy final : public GenericPluginTy {
private:
  struct DeviceInfoTy {
    L0DeviceIdTy Id;
    L0ContextTy *Driver;
    bool isRoot() const { return Id.SubId < 0 && Id.CCSId < 0; }
  };
  llvm::SmallVector<DeviceInfoTy> DetectedDevices;

  /// Context (and Driver) specific data.
  std::list<L0ContextTy> ContextList;

  // Table containing per-thread information using TLS.
  L0ThreadTblTy ThreadTLSTable;
  // Table containing per-thread information for each device using TLS.
  L0DeviceTLSTableTy DeviceTLSTable;
  // Table containing per-thread information for each Context using TLS.
  L0ContextTLSTableTy ContextTLSTable;

  /// L0 plugin options.
  L0OptionsTy Options;

  /// Common pool of AsyncQueue.
  AsyncQueuePoolTy AsyncQueuePool;

  L0ThreadTLSTy &getTLS() { return ThreadTLSTable.get(); }

  /// Find L0 devices and initialize device properties.
  /// Returns number of devices reported to omptarget.
  Expected<int32_t> findDevices();

public:
  LevelZeroPluginTy() : GenericPluginTy(getTripleArch()) {}
  virtual ~LevelZeroPluginTy() = default;

  L0DeviceTLSTy &getDeviceTLS(int32_t DeviceId) {
    return DeviceTLSTable.get(DeviceId);
  }
  L0ContextTLSTy &getContextTLS(ze_context_handle_t Context) {
    return ContextTLSTable.get(Context);
  }

  const L0OptionsTy &getOptions() { return Options; }

  const L0DeviceTy &getDeviceFromId(int32_t DeviceId) const {
    return static_cast<const L0DeviceTy &>(getDevice(DeviceId));
  }
  L0DeviceTy &getDeviceFromId(int32_t DeviceId) {
    return static_cast<L0DeviceTy &>(getDevice(DeviceId));
  }

  AsyncQueueTy *getAsyncQueue() {
    auto *Queue = getTLS().getAsyncQueue();
    if (!Queue)
      Queue = AsyncQueuePool.get();
    return Queue;
  }

  void releaseAsyncQueue(AsyncQueueTy *Queue) {
    if (!Queue)
      return;
    Queue->reset();
    if (!getTLS().releaseAsyncQueue(Queue))
      AsyncQueuePool.release(Queue);
  }

  // Plugin interface.
  Expected<int32_t> initImpl() override;
  Error deinitImpl() override;
  GenericDeviceTy *createDevice(GenericPluginTy &Plugin, int32_t DeviceId,
                                int32_t NumDevices) override;
  GenericGlobalHandlerTy *createGlobalHandler() override;

  uint16_t getMagicElfBits() const override { return ELF::EM_INTELGT; }
  Triple::ArchType getTripleArch() const override { return Triple::spirv64; }
  const char *getName() const override { return GETNAME(TARGET_NAME); }

  Expected<bool> isELFCompatible(uint32_t DeviceId,
                                 StringRef Image) const override;

  Error flushQueueImpl(omp_interop_val_t *Interop) override;
  Error syncBarrierImpl(omp_interop_val_t *Interop) override;
  Error asyncBarrierImpl(omp_interop_val_t *Interop) override;
};

} // namespace llvm::omp::target::plugin

#endif // OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_LEVEL_ZERO_L0PLUGIN_H