aboutsummaryrefslogtreecommitdiff
path: root/offload/unittests/OffloadAPI/common/Environment.cpp
blob: f07a66cda21892804783887568b2eb074d2987fa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
//===------- Offload API tests - gtest environment ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Environment.hpp"
#include "Fixtures.hpp"
#include "llvm/Support/CommandLine.h"
#include <OffloadAPI.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

using namespace llvm;

// Wrapper so we don't have to constantly init and shut down Offload in every
// test, while having sensible lifetime for the platform environment: olInit()
// runs when this file's static objects are constructed, and olShutDown() runs
// when they are destroyed at program exit.
struct OffloadInitWrapper {
  OffloadInitWrapper() { olInit(); }
  ~OffloadInitWrapper() { olShutDown(); }

  // Exactly one object owns the init/shutdown pair; a copy would call
  // olShutDown() a second time when it is destroyed.
  OffloadInitWrapper(const OffloadInitWrapper &) = delete;
  OffloadInitWrapper &operator=(const OffloadInitWrapper &) = delete;
};
static OffloadInitWrapper Wrapper{};

// `--platform=<name>`: when non-empty, getPlatform() uses only the platform
// whose reported name matches this string exactly.
static cl::opt<std::string> SelectedPlatform(
    "platform", cl::value_desc("platform"),
    cl::desc("Only test the specified platform"));

std::ostream &operator<<(std::ostream &Out,
                         const ol_platform_handle_t &Platform) {
  size_t Size;
  olGetPlatformInfoSize(Platform, OL_PLATFORM_INFO_NAME, &Size);
  std::vector<char> Name(Size);
  olGetPlatformInfo(Platform, OL_PLATFORM_INFO_NAME, Size, Name.data());
  Out << Name.data();
  return Out;
}

std::ostream &operator<<(std::ostream &Out,
                         const std::vector<ol_platform_handle_t> &Platforms) {
  for (auto Platform : Platforms) {
    Out << "\n  * \"" << Platform << "\"";
  }
  return Out;
}

const std::vector<ol_platform_handle_t> &TestEnvironment::getPlatforms() {
  static std::vector<ol_platform_handle_t> Platforms{};

  if (Platforms.empty()) {
    uint32_t PlatformCount = 0;
    olGetPlatformCount(&PlatformCount);
    if (PlatformCount > 0) {
      Platforms.resize(PlatformCount);
      olGetPlatform(PlatformCount, Platforms.data());
    }
  }

  return Platforms;
}

// Get a single platform, which may be selected by the user.
//
// With --platform, returns the platform whose streamed name equals the option
// exactly. If none matches, it prints the available names to stdout and exits
// with status 1. Otherwise it prefers the first platform reporting at least one
// device, falling back to the first platform.
//
// A successful selection is cached in a function-local static. If no platforms
// exist (and --platform is unset), returns nullptr without caching, so a later
// call can retry.
ol_platform_handle_t TestEnvironment::getPlatform() {
  static ol_platform_handle_t Platform = nullptr;
  if (Platform)
    return Platform;

  const auto &Platforms = getPlatforms();

  if (!SelectedPlatform.empty()) {
    for (const auto CandidatePlatform : Platforms) {
      std::stringstream PlatformName;
      PlatformName << CandidatePlatform;
      if (SelectedPlatform == PlatformName.str()) {
        Platform = CandidatePlatform;
        return Platform;
      }
    }
    std::cout << "No platform found with the name \"" << SelectedPlatform
              << "\". Choose from:" << Platforms << "\n";
    std::exit(1);
  }

  // Indexing Platforms[0] below would be undefined on an empty list.
  if (Platforms.empty())
    return nullptr;

  // Pick a single platform. We prefer one that has available devices, but
  // just pick the first initially in case none have any devices.
  Platform = Platforms[0];
  for (auto CandidatePlatform : Platforms) {
    uint32_t NumDevices = 0;
    if (olGetDeviceCount(CandidatePlatform, &NumDevices) == OL_SUCCESS &&
        NumDevices > 0) {
      Platform = CandidatePlatform;
      break;
    }
  }

  return Platform;
}