aboutsummaryrefslogtreecommitdiff
path: root/orc-rt/unittests/SPSWrapperFunctionTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'orc-rt/unittests/SPSWrapperFunctionTest.cpp')
-rw-r--r--orc-rt/unittests/SPSWrapperFunctionTest.cpp80
1 files changed, 26 insertions, 54 deletions
diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
index 7f88ce0..d9f34e6 100644
--- a/orc-rt/unittests/SPSWrapperFunctionTest.cpp
+++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
@@ -16,64 +16,12 @@
#include "orc-rt/WrapperFunction.h"
#include "orc-rt/move_only_function.h"
+#include "DirectCaller.h"
+
#include "gtest/gtest.h"
using namespace orc_rt;
-/// Make calls and call result handlers directly on the current thread.
-class DirectCaller {
-private:
- class DirectResultSender {
- public:
- virtual ~DirectResultSender() {}
- virtual void send(orc_rt_SessionRef Session,
- WrapperFunctionBuffer ResultBytes) = 0;
- static void send(orc_rt_SessionRef Session, void *CallCtx,
- orc_rt_WrapperFunctionBuffer ResultBytes) {
- std::unique_ptr<DirectResultSender>(
- reinterpret_cast<DirectResultSender *>(CallCtx))
- ->send(Session, ResultBytes);
- }
- };
-
- template <typename ImplFn>
- class DirectResultSenderImpl : public DirectResultSender {
- public:
- DirectResultSenderImpl(ImplFn &&Fn) : Fn(std::forward<ImplFn>(Fn)) {}
- void send(orc_rt_SessionRef Session,
- WrapperFunctionBuffer ResultBytes) override {
- Fn(Session, std::move(ResultBytes));
- }
-
- private:
- std::decay_t<ImplFn> Fn;
- };
-
- template <typename ImplFn>
- static std::unique_ptr<DirectResultSender>
- makeDirectResultSender(ImplFn &&Fn) {
- return std::make_unique<DirectResultSenderImpl<ImplFn>>(
- std::forward<ImplFn>(Fn));
- }
-
-public:
- DirectCaller(orc_rt_SessionRef Session, orc_rt_WrapperFunction Fn)
- : Session(Session), Fn(Fn) {}
-
- template <typename HandleResultFn>
- void operator()(HandleResultFn &&HandleResult,
- WrapperFunctionBuffer ArgBytes) {
- auto DR =
- makeDirectResultSender(std::forward<HandleResultFn>(HandleResult));
- Fn(Session, reinterpret_cast<void *>(DR.release()),
- DirectResultSender::send, ArgBytes.release());
- }
-
-private:
- orc_rt_SessionRef Session;
- orc_rt_WrapperFunction Fn;
-};
-
static void void_noop_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
@@ -242,6 +190,30 @@ TEST(SPSWrapperFunctionUtilsTest, TransparentSerializationPointers) {
EXPECT_EQ(P, &X);
}
+static void
+expected_int_pointer_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
+ orc_rt_WrapperFunctionReturn Return,
+ orc_rt_WrapperFunctionBuffer ArgBytes) {
+ SPSWrapperFunction<SPSExpected<SPSExecutorAddr>(SPSExecutorAddr)>::handle(
+ Session, CallCtx, Return, ArgBytes,
+ [](move_only_function<void(Expected<int32_t *>)> Return, int32_t *P) {
+ Return(P);
+ });
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TransparentSerializationExpectedPointers) {
+ int X = 42;
+ int *P = nullptr;
+ SPSWrapperFunction<SPSExpected<SPSExecutorAddr>(SPSExecutorAddr)>::call(
+ DirectCaller(nullptr, expected_int_pointer_sps_wrapper),
+ [&](Expected<Expected<int32_t *>> R) {
+ P = cantFail(cantFail(std::move(R)));
+ },
+ &X);
+
+ EXPECT_EQ(P, &X);
+}
+
template <size_t N> struct SPSOpCounter {};
namespace orc_rt {