aboutsummaryrefslogtreecommitdiff
path: root/offload/plugins-nextgen/level_zero/include/L0Kernel.h
blob: 1d5a014d9d0a56c12ffc3f1405770bd15cd697a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
//===--- Level Zero Target RTL Implementation -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// GenericKernel implementation for SPIR-V/Xe machine.
//
//===----------------------------------------------------------------------===//

#ifndef OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_LEVEL_ZERO_L0KERNEL_H
#define OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_LEVEL_ZERO_L0KERNEL_H

#include "AsyncQueue.h"
#include "L0Defs.h"
#include "L0Trace.h"
#include "PluginInterface.h"

namespace llvm::omp::target::plugin {

class L0DeviceTy;
class L0ProgramTy;

/// Loop descriptor: lower bound, upper bound and stride of a single loop.
struct TgtLoopDescTy {
  int64_t Lb = 0;     // The lower bound of the i-th loop.
  int64_t Ub = 0;     // The upper bound of the i-th loop.
  int64_t Stride = 0; // The stride of the i-th loop.

  /// Two descriptors compare equal when all three fields match.
  bool operator==(const TgtLoopDescTy &Other) const {
    return Lb == Other.Lb && Ub == Other.Ub && Stride == Other.Stride;
  }
  /// Negation of operator==; provided so loop descriptors offer the same
  /// ==/!= pair as TgtNDRangeDescTy.
  bool operator!=(const TgtLoopDescTy &Other) const {
    return !(*this == Other);
  }
};

/// ND-range descriptor assembled from up to three loop descriptors.
struct TgtNDRangeDescTy {
  int32_t NumLoops = 0;      // How many loops/dimensions are described.
  int32_t DistributeDim = 0; // Every dimension lower than this one
                             // has to end up in a single WG.
  TgtLoopDescTy Levels[3];   // Per-loop descriptors; at most 3 loops.

  /// Equality: both scalar fields and all three entries of Levels must
  /// match. Every slot of Levels is compared, independent of NumLoops.
  bool operator==(const TgtNDRangeDescTy &RHS) const {
    if (NumLoops != RHS.NumLoops || DistributeDim != RHS.DistributeDim)
      return false;
    for (int I = 0; I < 3; ++I)
      if (!(Levels[I] == RHS.Levels[I]))
        return false;
    return true;
  }
  /// Inequality, defined as the negation of operator==.
  bool operator!=(const TgtNDRangeDescTy &RHS) const {
    return !operator==(RHS);
  }
};

/// Forward declaration.
struct L0LaunchEnvTy;

/// Properties of one kernel, plus the launch parameters remembered from the
/// previous launch.
struct KernelPropertiesTy {
  uint32_t Width = 0;
  uint32_t SIMDWidth = 0;
  uint32_t MaxThreadGroupSize = 0;

  /// Input parameters remembered from the previous launch. Both default
  /// to -1.
  int32_t NumTeams = -1;
  int32_t ThreadLimit = -1;

  /// Parameters remembered from the previous launch. IndirectAccessFlags
  /// starts at the largest value its type can represent; the group sizes and
  /// group counts start at zero.
  ze_kernel_indirect_access_flags_t IndirectAccessFlags =
      std::numeric_limits<ze_kernel_indirect_access_flags_t>::max();
  uint32_t GroupSizes[3] = {};
  ze_group_count_t GroupCounts = {0, 0, 0};

  // NOTE(review): presumably protects the cached fields above; nothing in
  // this header locks it, so confirm the locking protocol at the call sites.
  std::mutex Mtx;

  /// Check whether the cached group parameters can be reused for a launch
  /// with the given NumTeamsIn / ThreadLimitIn. Defined out of line.
  /// NOTE(review): GroupSizesOut and KEnv look like outputs filled in on
  /// reuse -- confirm against the definition.
  bool reuseGroupParams(int32_t NumTeamsIn, int32_t ThreadLimitIn,
                        uint32_t *GroupSizesOut, L0LaunchEnvTy &KEnv) const;

  /// Update the cached group parameters from the given launch values.
  /// Defined out of line.
  void cacheGroupParams(int32_t NumTeamsIn, int32_t ThreadLimitIn,
                        const uint32_t *GroupSizesIn, L0LaunchEnvTy &KEnv);
};

/// Per-launch state that is passed by reference to the kernel launch helpers.
struct L0LaunchEnvTy {
  // Set from the constructor arguments.
  bool IsAsync;
  AsyncQueueTy *AsyncQueue;
  // Group counts for this launch; zero-initialized.
  ze_group_count_t GroupCounts{0, 0, 0};
  // Properties of the kernel being launched (not owned, held by reference).
  KernelPropertiesTy &KernelPR;
  // Both flags are off by default.
  bool HalfNumThreads = false;
  bool IsTeamsNDRange = false;

  /// Bind the launch environment to a kernel's properties and record the
  /// async flag and queue; the remaining members keep their defaults.
  L0LaunchEnvTy(bool Async, AsyncQueueTy *Queue, KernelPropertiesTy &Props)
      : IsAsync(Async), AsyncQueue(Queue), KernelPR(Props) {}
};

/// Level Zero kernel: a GenericKernelTy that holds a ze_kernel_handle_t and
/// the KernelPropertiesTy associated with it. Neither copyable nor movable.
class L0KernelTy : public GenericKernelTy {
  // L0 kernel handle. Null after construction and again after a successful
  // deinit().
  ze_kernel_handle_t zeKernel;
  // Kernel properties. Mutable so that const members (and getProperties())
  // can hand out a modifiable reference.
  mutable KernelPropertiesTy Properties;

  // The helpers below are defined out of line; only their signatures are
  // visible here.
  // NOTE(review): presumably picks the group sizes/counts for a launch and
  // stores them in GroupSizes / KEnv -- confirm against the definition.
  void decideKernelGroupArguments(L0DeviceTy &Device, uint32_t NumTeams,
                                  uint32_t ThreadLimit, uint32_t *GroupSizes,
                                  L0LaunchEnvTy &KEnv) const;

  Error buildKernel(L0ProgramTy &Program);
  Error readKernelProperties(L0ProgramTy &Program);

  Error setKernelGroups(L0DeviceTy &l0Device, L0LaunchEnvTy &KEnv,
                        uint32_t NumThreads[3], uint32_t NumBlocks[3]) const;
  Error setIndirectFlags(L0DeviceTy &l0Device, L0LaunchEnvTy &KEnv) const;

public:
  /// Create a L0 kernel with the given name; the kernel handle starts out
  /// null.
  L0KernelTy(const char *Name) : GenericKernelTy(Name), zeKernel(nullptr) {}
  ~L0KernelTy() = default;
  L0KernelTy(const L0KernelTy &) = delete;
  L0KernelTy(L0KernelTy &&) = delete;
  L0KernelTy &operator=(const L0KernelTy &) = delete;
  // Move assignment takes a non-const rvalue reference.
  L0KernelTy &operator=(L0KernelTy &&) = delete;

  /// Access the kernel properties. Returns a non-const reference even though
  /// the method is const, because Properties is mutable.
  KernelPropertiesTy &getProperties() const { return Properties; }

  /// Initialize the L0 kernel.
  Error initImpl(GenericDeviceTy &GenericDevice, DeviceImageTy &Image) override;
  /// Launch the L0 kernel function.
  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
                   uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
                   KernelLaunchParamsTy LaunchParams,
                   AsyncInfoWrapperTy &AsyncInfoWrapper) const override;

  /// Destroy the L0 kernel handle, if one is held.
  ///
  /// Returns success when there is no handle to destroy, so calling this on a
  /// kernel that was never built, or calling it twice, is harmless.
  Error deinit() {
    // Nothing to release; do not pass a null handle to zeKernelDestroy.
    if (!zeKernel)
      return Plugin::success();
    // NOTE(review): CALL_ZE_RET_ERROR is defined elsewhere; it presumably
    // returns an Error from this function when the call fails, in which case
    // the handle is left untouched.
    CALL_ZE_RET_ERROR(zeKernelDestroy, zeKernel);
    // Forget the destroyed handle so getZeKernel() cannot expose it and a
    // repeated deinit() does not destroy it a second time.
    zeKernel = nullptr;
    return Plugin::success();
  }

  /// Not implemented for Level Zero yet: always returns an UNIMPLEMENTED
  /// error; both arguments are ignored.
  Expected<uint64_t> maxGroupSize(GenericDeviceTy &GenericDevice,
                                  uint64_t DynamicMemSize) const override {
    return Plugin::error(ErrorCode::UNIMPLEMENTED,
                         "maxGroupSize not implemented yet");
  }

  /// Raw L0 kernel handle; null when no kernel is currently held.
  ze_kernel_handle_t getZeKernel() const { return zeKernel; }

  /// Defined out of line.
  /// NOTE(review): presumably computes the group shape for the requested
  /// NumTeams / ThreadLimit into GroupSizes and KEnv -- confirm against the
  /// definition.
  Error getGroupsShape(L0DeviceTy &Device, int32_t NumTeams,
                       int32_t ThreadLimit, uint32_t *GroupSizes,
                       L0LaunchEnvTy &KEnv) const;
};

} // namespace llvm::omp::target::plugin

#endif // OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_LEVEL_ZERO_L0KERNEL_H