aboutsummaryrefslogtreecommitdiff
path: root/offload/liboffload/src/Helpers.hpp
blob: 62e55e500fac77deee8077567702f03b912c32f4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
//===- Helpers.hpp - GetInfo return helpers for the new LLVM/Offload API --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The getInfo* helpers and the InfoWriter class provide a shortcut for writing
// the returned data and its size for the various getInfo APIs. Based on the
// equivalent implementations in Unified Runtime.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "OffloadAPI.h"
#include "OffloadError.h"
#include "llvm/Support/Error.h"

#include <cstring>

/// Shared implementation behind every getInfo* helper.
///
/// \param ParamValueSize    Capacity, in bytes, of the caller's ParamValue
///                          buffer.
/// \param ParamValue        Caller's output buffer; may be null when only the
///                          size is being queried.
/// \param ParamValueSizeRet If non-null, receives ValueSize.
/// \param Value             The value to hand to AssignFunc.
/// \param ValueSize         Number of bytes the result occupies.
/// \param AssignFunc        Called as AssignFunc(ParamValue, Value, ValueSize)
///                          to store the result into ParamValue.
/// \returns INVALID_NULL_POINTER if both outputs are null, INVALID_SIZE if
///          ParamValue is non-null but smaller than ValueSize (in which case
///          nothing is written, not even the size), success otherwise.
template <typename T, typename Assign>
llvm::Error getInfoImpl(size_t ParamValueSize, void *ParamValue,
                        size_t *ParamValueSizeRet, T Value, size_t ValueSize,
                        Assign &&AssignFunc) {
  const bool WantsValue = ParamValue != nullptr;
  const bool WantsSize = ParamValueSizeRet != nullptr;

  // The caller must ask for at least one of the two outputs.
  if (!(WantsValue || WantsSize))
    return error::createOffloadError(error::ErrorCode::INVALID_NULL_POINTER,
                                     "value and size outputs are nullptr");

  // Validate the buffer before touching either output so a failed call has
  // no side effects.
  if (WantsValue && ParamValueSize < ValueSize)
    return error::createOffloadError(error::ErrorCode::INVALID_SIZE,
                                     "provided size is invalid");

  if (WantsValue)
    AssignFunc(ParamValue, Value, ValueSize);

  if (WantsSize)
    *ParamValueSizeRet = ValueSize;

  return llvm::Error::success();
}

/// Writes a single scalar of type T into the caller's buffer and/or reports
/// sizeof(T) as the required size. See getInfoImpl for the error contract.
template <typename T>
llvm::Error getInfo(size_t ParamValueSize, void *ParamValue,
                    size_t *ParamValueSizeRet, T Value) {
  // The byte count passed by getInfoImpl is always sizeof(T) here, so the
  // store ignores it and assigns through a typed pointer.
  auto StoreScalar = [](void *Dest, T Scalar, size_t /*Bytes*/) {
    T *Typed = static_cast<T *>(Dest);
    *Typed = Scalar;
  };

  const size_t ScalarBytes = sizeof(T);
  return getInfoImpl(ParamValueSize, ParamValue, ParamValueSizeRet, Value,
                     ScalarBytes, StoreScalar);
}

/// Copies \p array_length contiguous elements of type T from \p Value into the
/// caller's buffer and/or reports array_length * sizeof(T) as the required
/// size. See getInfoImpl for the error contract.
template <typename T>
llvm::Error getInfoArray(size_t array_length, size_t ParamValueSize,
                         void *ParamValue, size_t *ParamValueSizeRet,
                         const T *Value) {
  // Wrap std::memcpy in a lambda instead of passing the bare name `memcpy`:
  // <cstring> only guarantees the std:: spelling, and taking the address of
  // a standard library function is not portable.
  auto CopyElements = [](void *Dest, const T *Src, size_t Bytes) {
    // memcpy with a null source is undefined even for a zero-byte copy, so
    // empty arrays skip the call entirely.
    if (Bytes != 0)
      std::memcpy(Dest, Src, Bytes);
  };

  return getInfoImpl(ParamValueSize, ParamValue, ParamValueSizeRet, Value,
                     array_length * sizeof(T), CopyElements);
}

llvm::Error getInfoString(size_t ParamValueSize, void *ParamValue,
                          size_t *ParamValueSizeRet, llvm::StringRef Value) {
  return getInfoArray(Value.size() + 1, ParamValueSize, ParamValue,
                      ParamValueSizeRet, Value.data());
}

/// Bundles the (size, output buffer, size-return) triple passed to a
/// getInfo-style API so that the result can be written with a single call to
/// write/writeArray/writeString, which forward to getInfo, getInfoArray and
/// getInfoString respectively.
class InfoWriter {
public:
  InfoWriter(size_t Size, void *Target, size_t *SizeRet)
      : Size(Size), Target(Target), SizeRet(SizeRet) {}
  InfoWriter() = delete;
  // Non-copyable. The copy constructor takes a const reference so copies from
  // const objects are also rejected, and copy assignment is deleted
  // explicitly because declaring a copy constructor does not suppress the
  // implicitly generated copy-assignment operator.
  InfoWriter(const InfoWriter &) = delete;
  InfoWriter &operator=(const InfoWriter &) = delete;
  ~InfoWriter() = default;

  /// Writes a single value of type T (see getInfo).
  template <typename T> llvm::Error write(T Val) {
    return getInfo(Size, Target, SizeRet, Val);
  }

  /// Writes \p Elems elements starting at pointer \p Val (see getInfoArray).
  template <typename T> llvm::Error writeArray(T Val, size_t Elems) {
    return getInfoArray(Elems, Size, Target, SizeRet, Val);
  }

  /// Writes \p Val as a NUL-terminated string (see getInfoString).
  llvm::Error writeString(llvm::StringRef Val) {
    return getInfoString(Size, Target, SizeRet, Val);
  }

private:
  size_t Size;     // Capacity of Target in bytes.
  void *Target;    // Caller's output buffer; may be null (size query only).
  size_t *SizeRet; // Receives the required size when non-null.
};