diff options
Diffstat (limited to 'orc-rt/unittests/SPSWrapperFunctionTest.cpp')
-rw-r--r-- | orc-rt/unittests/SPSWrapperFunctionTest.cpp | 80 |
1 files changed, 26 insertions, 54 deletions
diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp index 7f88ce0..d9f34e6 100644 --- a/orc-rt/unittests/SPSWrapperFunctionTest.cpp +++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp @@ -16,64 +16,12 @@ #include "orc-rt/WrapperFunction.h" #include "orc-rt/move_only_function.h" +#include "DirectCaller.h" + #include "gtest/gtest.h" using namespace orc_rt; -/// Make calls and call result handlers directly on the current thread. -class DirectCaller { -private: - class DirectResultSender { - public: - virtual ~DirectResultSender() {} - virtual void send(orc_rt_SessionRef Session, - WrapperFunctionBuffer ResultBytes) = 0; - static void send(orc_rt_SessionRef Session, void *CallCtx, - orc_rt_WrapperFunctionBuffer ResultBytes) { - std::unique_ptr<DirectResultSender>( - reinterpret_cast<DirectResultSender *>(CallCtx)) - ->send(Session, ResultBytes); - } - }; - - template <typename ImplFn> - class DirectResultSenderImpl : public DirectResultSender { - public: - DirectResultSenderImpl(ImplFn &&Fn) : Fn(std::forward<ImplFn>(Fn)) {} - void send(orc_rt_SessionRef Session, - WrapperFunctionBuffer ResultBytes) override { - Fn(Session, std::move(ResultBytes)); - } - - private: - std::decay_t<ImplFn> Fn; - }; - - template <typename ImplFn> - static std::unique_ptr<DirectResultSender> - makeDirectResultSender(ImplFn &&Fn) { - return std::make_unique<DirectResultSenderImpl<ImplFn>>( - std::forward<ImplFn>(Fn)); - } - -public: - DirectCaller(orc_rt_SessionRef Session, orc_rt_WrapperFunction Fn) - : Session(Session), Fn(Fn) {} - - template <typename HandleResultFn> - void operator()(HandleResultFn &&HandleResult, - WrapperFunctionBuffer ArgBytes) { - auto DR = - makeDirectResultSender(std::forward<HandleResultFn>(HandleResult)); - Fn(Session, reinterpret_cast<void *>(DR.release()), - DirectResultSender::send, ArgBytes.release()); - } - -private: - orc_rt_SessionRef Session; - orc_rt_WrapperFunction Fn; -}; - static void void_noop_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { @@ -242,6 +190,30 @@ TEST(SPSWrapperFunctionUtilsTest, TransparentSerializationPointers) { EXPECT_EQ(P, &X); } +static void +expected_int_pointer_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, + orc_rt_WrapperFunctionReturn Return, + orc_rt_WrapperFunctionBuffer ArgBytes) { + SPSWrapperFunction<SPSExpected<SPSExecutorAddr>(SPSExecutorAddr)>::handle( + Session, CallCtx, Return, ArgBytes, + [](move_only_function<void(Expected<int32_t *>)> Return, int32_t *P) { + Return(P); + }); +} + +TEST(SPSWrapperFunctionUtilsTest, TransparentSerializationExpectedPointers) { + int X = 42; + int *P = nullptr; + SPSWrapperFunction<SPSExpected<SPSExecutorAddr>(SPSExecutorAddr)>::call( + DirectCaller(nullptr, expected_int_pointer_sps_wrapper), + [&](Expected<Expected<int32_t *>> R) { + P = cantFail(cantFail(std::move(R))); + }, + &X); + + EXPECT_EQ(P, &X); +} + template <size_t N> struct SPSOpCounter {}; namespace orc_rt { |