aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.h
blob: 469de2a3665a069edf5243818b21bab0317358a8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
//===------ OrcTestCommon.h - Utilities for Orc Unit Tests ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Common utilities for the Orc unit tests.
//
//===----------------------------------------------------------------------===//


#ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_ORCTESTCOMMON_H
#define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_ORCTESTCOMMON_H

#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
#include "llvm/ExecutionEngine/Orc/InProcessMemoryAccess.h"
#include "llvm/ExecutionEngine/Orc/IndirectionUtils.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/TargetSelect.h"
#include "gtest/gtest.h"

#include <memory>

namespace llvm {

namespace orc {
// CoreAPIsStandardTest that saves a bunch of boilerplate by providing the
// following:
//
// (1) ES -- An ExecutionSession
// (2) Foo, Bar, Baz, Qux -- SymbolStringPtrs for strings "foo", "bar", "baz",
//     and "qux" respectively.
// (3) FooAddr, BarAddr, BazAddr, QuxAddr -- Dummy addresses. Guaranteed
//     distinct and non-null.
// (4) FooSym, BarSym, BazSym, QuxSym -- JITEvaluatedSymbols with FooAddr,
//     BarAddr, BazAddr, and QuxAddr respectively. All with default strong,
//     linkage and non-hidden visibility.
// (5) V -- A JITDylib associated with ES.
class CoreAPIsBasedStandardTest : public testing::Test {
public:
  ~CoreAPIsBasedStandardTest() {
    if (auto Err = ES.endSession())
      ES.reportError(std::move(Err));
  }

protected:
  class OverridableDispatcher : public InPlaceTaskDispatcher {
  public:
    OverridableDispatcher(CoreAPIsBasedStandardTest &Parent) : Parent(Parent) {}
    void dispatch(std::unique_ptr<Task> T) override;

  private:
    CoreAPIsBasedStandardTest &Parent;
  };

  std::unique_ptr<llvm::orc::ExecutorProcessControl>
  makeEPC(std::shared_ptr<SymbolStringPool> SSP);

  std::shared_ptr<SymbolStringPool> SSP = std::make_shared<SymbolStringPool>();
  ExecutionSession ES{makeEPC(SSP)};
  JITDylib &JD = ES.createBareJITDylib("JD");
  SymbolStringPtr Foo = ES.intern("foo");
  SymbolStringPtr Bar = ES.intern("bar");
  SymbolStringPtr Baz = ES.intern("baz");
  SymbolStringPtr Qux = ES.intern("qux");
  static constexpr ExecutorAddr FooAddr{1};
  static constexpr ExecutorAddr BarAddr{2};
  static constexpr ExecutorAddr BazAddr{3};
  static constexpr ExecutorAddr QuxAddr{4};
  ExecutorSymbolDef FooSym{FooAddr, JITSymbolFlags::Exported};
  ExecutorSymbolDef BarSym{BarAddr, JITSymbolFlags::Exported};
  ExecutorSymbolDef BazSym{BazAddr, JITSymbolFlags::Exported};
  ExecutorSymbolDef QuxSym{QuxAddr, JITSymbolFlags::Exported};
  unique_function<void(std::unique_ptr<Task>)> DispatchOverride;
};

/// A ExecutorProcessControl instance that asserts if any of its methods are
/// used. Suitable for use is unit tests, and by ORC clients who haven't moved
/// to ExecutorProcessControl-based APIs yet.
class UnsupportedExecutorProcessControl : public ExecutorProcessControl,
                                          private InProcessMemoryAccess {
public:
  UnsupportedExecutorProcessControl(
      std::shared_ptr<SymbolStringPool> SSP = nullptr,
      std::unique_ptr<TaskDispatcher> D = nullptr, const std::string &TT = "",
      unsigned PageSize = 0)
      : ExecutorProcessControl(
            SSP ? std::move(SSP) : std::make_shared<SymbolStringPool>(),
            D ? std::move(D) : std::make_unique<InPlaceTaskDispatcher>()),
        InProcessMemoryAccess(Triple(TT).isArch64Bit()) {
    this->TargetTriple = Triple(TT);
    this->PageSize = PageSize;
    this->MemAccess = this;
  }

  Expected<int32_t> runAsMain(ExecutorAddr MainFnAddr,
                              ArrayRef<std::string> Args) override {
    llvm_unreachable("Unsupported");
  }

  Expected<int32_t> runAsVoidFunction(ExecutorAddr VoidFnAddr) override {
    llvm_unreachable("Unsupported");
  }

  Expected<int32_t> runAsIntFunction(ExecutorAddr IntFnAddr, int Arg) override {
    llvm_unreachable("Unsupported");
  }

  void callWrapperAsync(ExecutorAddr WrapperFnAddr,
                        IncomingWFRHandler OnComplete,
                        ArrayRef<char> ArgBuffer) override {
    llvm_unreachable("Unsupported");
  }

  Error disconnect() override { return Error::success(); }
};

} // end namespace orc

class OrcNativeTarget {
public:
  static void initialize() {
    if (!NativeTargetInitialized) {
      InitializeNativeTarget();
      InitializeNativeTargetAsmParser();
      InitializeNativeTargetAsmPrinter();
      NativeTargetInitialized = true;
    }
  }

private:
  static bool NativeTargetInitialized;
};

class SimpleMaterializationUnit : public orc::MaterializationUnit {
public:
  using MaterializeFunction =
      std::function<void(std::unique_ptr<orc::MaterializationResponsibility>)>;
  using DiscardFunction =
      std::function<void(const orc::JITDylib &, orc::SymbolStringPtr)>;
  using DestructorFunction = std::function<void()>;

  SimpleMaterializationUnit(
      orc::SymbolFlagsMap SymbolFlags, MaterializeFunction Materialize,
      orc::SymbolStringPtr InitSym = nullptr,
      DiscardFunction Discard = DiscardFunction(),
      DestructorFunction Destructor = DestructorFunction())
      : MaterializationUnit(
            Interface(std::move(SymbolFlags), std::move(InitSym))),
        Materialize(std::move(Materialize)), Discard(std::move(Discard)),
        Destructor(std::move(Destructor)) {}

  ~SimpleMaterializationUnit() override {
    if (Destructor)
      Destructor();
  }

  StringRef getName() const override { return "<Simple>"; }

  void
  materialize(std::unique_ptr<orc::MaterializationResponsibility> R) override {
    Materialize(std::move(R));
  }

  void discard(const orc::JITDylib &JD,
               const orc::SymbolStringPtr &Name) override {
    if (Discard)
      Discard(JD, std::move(Name));
    else
      llvm_unreachable("Discard not supported");
  }

private:
  MaterializeFunction Materialize;
  DiscardFunction Discard;
  DestructorFunction Destructor;
};

class ModuleBuilder {
public:
  ModuleBuilder(LLVMContext &Context, StringRef Triple,
                StringRef Name);

  Function *createFunctionDecl(FunctionType *FTy, StringRef Name) {
    return Function::Create(FTy, GlobalValue::ExternalLinkage, Name, M.get());
  }

  Module* getModule() { return M.get(); }
  const Module* getModule() const { return M.get(); }
  std::unique_ptr<Module> takeModule() { return std::move(M); }

private:
  std::unique_ptr<Module> M;
};

// Dummy struct type.
struct DummyStruct {
  int X[256];
};

inline StructType *getDummyStructTy(LLVMContext &Context) {
  return StructType::get(ArrayType::get(Type::getInt32Ty(Context), 256));
}

} // namespace llvm

#endif