// RUN: %libomptarget-compilexx-run-and-check-generic

// REQUIRES: libc
// REQUIRES: gpu

// Exercises LLVM libc's generic RPC dispatch helpers. Each GPU-side variant
// below forwards its call with rpc::dispatch under a per-function opcode, and
// the host callback registered with the offload runtime runs the matching
// host implementation through rpc::invoke.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

// CHECK: PASS

// If the RPC headers are not present, just pass the test.
#if !__has_include(<../../libc/shared/rpc.h>)
int main() { printf("PASS\n"); }
#else

#include <../../libc/shared/rpc.h>
#include <../../libc/shared/rpc_dispatch.h>

// Only the real test body calls the OpenMP runtime API, so this include is
// kept out of the fallback path above.
#include <omp.h>

// Device-side RPC client, declared weak and bound to the fixed symbol name
// "__llvm_rpc_client". NOTE(review): presumably the offload runtime locates and
// initializes this symbol on the device — confirm.
[[gnu::weak]] rpc::Client client asm("__llvm_rpc_client");
#pragma omp declare target to(client) device_type(nohost)

//===------------------------------------------------------------------------===
// Opcodes.
//===------------------------------------------------------------------------===

// One opcode per remote function. The client-side rpc::dispatch calls and the
// server-side switch in handleOpcodesImpl must use the same values.
constexpr uint32_t FOO_OPCODE = 1;
constexpr uint32_t VOID_OPCODE = 2;
constexpr uint32_t WRITEBACK_OPCODE = 3;
constexpr uint32_t CONST_PTR_OPCODE = 4;
constexpr uint32_t STRING_OPCODE = 5;
constexpr uint32_t EMPTY_OPCODE = 6;
constexpr uint32_t DIVERGENT_OPCODE = 7;

//===------------------------------------------------------------------------===
// Server-side implementations.
//===------------------------------------------------------------------------===

// Aggregate passed by const pointer in case 4.
struct S {
  int arr[4];
};

// 1. Non-pointer arguments, non-void return.
int foo(int x, double d, char c) {
  assert(x == 42);
  assert(d == 0.0);
  assert(c == 'c');
  return -1;
}

// 2. Void return type.
void void_fn(int x) { assert(x == 7); }

// 3. Write-back pointer: the new pointee value must be visible to the caller
// once the call returns.
void writeback_fn(int *out) {
  assert(out != nullptr && *out == 42);
  *out = 99;
}

// 4. Const pointer: sums the four elements of *p.
int sum_const(const S *p) {
  int s = 0;
  for (int i = 0; i < 4; ++i)
    s += p->arr[i];
  return s;
}

// 5. const char * string: checks the contents and returns the length.
int c_string(const char *s) {
  assert(s != nullptr);
  assert(strcmp(s, "hello") == 0);
  return static_cast<int>(strlen(s));
}

// 6. Empty function.
int empty() { return 42; }

// 7. Divergent values: only some threads make this call. The self-assignment
// writes the value back unchanged, so each caller must get its own value back.
void divergent(int *p) {
  assert(p);
  *p = *p;
}

//===------------------------------------------------------------------------===
// RPC client dispatch.
//===------------------------------------------------------------------------===

// On GPU targets these variants replace the implementations above. Each one
// forwards its arguments over RPC under its opcode, so the host-side
// implementation registered in handleOpcodesImpl does the actual work.
#pragma omp begin declare variant match(device = {kind(gpu)})
int foo(int x, double d, char c) {
  return rpc::dispatch<FOO_OPCODE>(client, foo, x, d, c);
}
void void_fn(int x) { rpc::dispatch<VOID_OPCODE>(client, void_fn, x); }
void writeback_fn(int *out) {
  rpc::dispatch<WRITEBACK_OPCODE>(client, writeback_fn, out);
}
int sum_const(const S *p) {
  return rpc::dispatch<CONST_PTR_OPCODE>(client, sum_const, p);
}
int c_string(const char *s) {
  return rpc::dispatch<STRING_OPCODE>(client, c_string, s);
}
int empty() { return rpc::dispatch<EMPTY_OPCODE>(client, empty); }
void divergent(int *p) {
  rpc::dispatch<DIVERGENT_OPCODE>(client, divergent, p);
}
#pragma omp end declare variant

//===------------------------------------------------------------------------===
// RPC server dispatch.
//===------------------------------------------------------------------------===

// Runs the host implementation that matches the port's opcode. NumLanes is the
// lane count selected in handleOpcodes. Unknown opcodes are reported back as
// RPC_UNHANDLED_OPCODE.
template <uint32_t NumLanes>
rpc::Status handleOpcodesImpl(rpc::Server::Port &Port) {
  switch (Port.get_opcode()) {
  case FOO_OPCODE:
    rpc::invoke<NumLanes>(Port, foo);
    break;
  case VOID_OPCODE:
    rpc::invoke<NumLanes>(Port, void_fn);
    break;
  case WRITEBACK_OPCODE:
    rpc::invoke<NumLanes>(Port, writeback_fn);
    break;
  case CONST_PTR_OPCODE:
    rpc::invoke<NumLanes>(Port, sum_const);
    break;
  case STRING_OPCODE:
    rpc::invoke<NumLanes>(Port, c_string);
    break;
  case EMPTY_OPCODE:
    rpc::invoke<NumLanes>(Port, empty);
    break;
  case DIVERGENT_OPCODE:
    // Same behavior as `divergent`, but written as a lambda. NOTE(review):
    // presumably this is to cover invoke with a lambda callable — confirm.
    rpc::invoke<NumLanes>(Port, [](int *p) {
      assert(p);
      *p = *p;
    });
    break;
  default:
    return rpc::RPC_UNHANDLED_OPCODE;
  }
  return rpc::RPC_SUCCESS;
}

// Callback handed to the offload runtime. `raw` is reinterpreted as an
// rpc::Server::Port, and `numLanes` selects the matching template
// instantiation. Only 1, 32 and 64 lanes are supported; any other count
// yields RPC_ERROR.
static uint32_t handleOpcodes(void *raw, uint32_t numLanes) {
  rpc::Server::Port &Port = *reinterpret_cast<rpc::Server::Port *>(raw);
  switch (numLanes) {
  case 1:
    return handleOpcodesImpl<1>(Port);
  case 32:
    return handleOpcodesImpl<32>(Port);
  case 64:
    return handleOpcodesImpl<64>(Port);
  default:
    return rpc::RPC_ERROR;
  }
}

extern "C" void
__tgt_register_rpc_callback(unsigned (*callback)(void *, unsigned));

// Registers the handler before main() runs, so it is in place before the
// target region below issues any RPC calls.
[[gnu::constructor]] void register_callback() {
  __tgt_register_rpc_callback(&handleOpcodes);
}

int main() {
  // Every case is run concurrently by 32 OpenMP threads on the device.
#pragma omp target
#pragma omp parallel num_threads(32)
  {
    // 1. Non-pointer return. The call is kept outside assert() so that it
    // still runs when assertions are compiled out.
    int foo_ret = foo(42, 0.0, 'c');
    assert(foo_ret == -1);

    // 2. Void return.
    void_fn(7);

    // 3. Write-back pointer.
    int value = 42;
    writeback_fn(&value);
    assert(value == 99);

    // 4. Const pointer.
    S s{1, 2, 3, 4};
    int sum = sum_const(&s);
    assert(sum == 10);

    // 5. const char * string.
    const char *msg = "hello";
    int len = c_string(msg);
    assert(len == 5);

    // 6. No arguments.
    int ret = empty();
    assert(ret == 42);

    // 7. Divergent values: only odd-numbered threads make the call, and each
    // must see its own thread number come back unchanged.
    int id = omp_get_thread_num();
    if (id % 2)
      divergent(&id);
    assert(id == omp_get_thread_num());
  }

  printf("PASS\n");
}
#endif