//===-- SPSWrapperFunctionTest.cpp ----------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Test SPSWrapperFunction and associated utilities. // //===----------------------------------------------------------------------===// #include "CommonTestUtils.h" #include "orc-rt/SPSWrapperFunction.h" #include "orc-rt/WrapperFunction.h" #include "orc-rt/move_only_function.h" #include "gtest/gtest.h" using namespace orc_rt; /// Make calls and call result handlers directly on the current thread. class DirectCaller { private: class DirectResultSender { public: virtual ~DirectResultSender() {} virtual void send(orc_rt_SessionRef Session, WrapperFunctionBuffer ResultBytes) = 0; static void send(orc_rt_SessionRef Session, void *CallCtx, orc_rt_WrapperFunctionBuffer ResultBytes) { std::unique_ptr( reinterpret_cast(CallCtx)) ->send(Session, ResultBytes); } }; template class DirectResultSenderImpl : public DirectResultSender { public: DirectResultSenderImpl(ImplFn &&Fn) : Fn(std::forward(Fn)) {} void send(orc_rt_SessionRef Session, WrapperFunctionBuffer ResultBytes) override { Fn(Session, std::move(ResultBytes)); } private: std::decay_t Fn; }; template static std::unique_ptr makeDirectResultSender(ImplFn &&Fn) { return std::make_unique>( std::forward(Fn)); } public: DirectCaller(orc_rt_SessionRef Session, orc_rt_WrapperFunction Fn) : Session(Session), Fn(Fn) {} template void operator()(HandleResultFn &&HandleResult, WrapperFunctionBuffer ArgBytes) { auto DR = makeDirectResultSender(std::forward(HandleResult)); Fn(Session, reinterpret_cast(DR.release()), DirectResultSender::send, ArgBytes.release()); } private: orc_rt_SessionRef Session; orc_rt_WrapperFunction Fn; }; static void void_noop_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( Session, CallCtx, Return, ArgBytes, [](move_only_function Return) { Return(); }); } TEST(SPSWrapperFunctionUtilsTest, TestVoidNoop) { bool Ran = false; SPSWrapperFunction::call(DirectCaller(nullptr, void_noop_sps_wrapper), [&](Error Err) { cantFail(std::move(Err)); Ran = true; }); EXPECT_TRUE(Ran); } static void add_via_lambda_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( Session, CallCtx, Return, ArgBytes, [](move_only_function Return, int32_t X, int32_t Y) { Return(X + Y); }); } TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaLambda) { int32_t Result = 0; SPSWrapperFunction::call( DirectCaller(nullptr, add_via_lambda_sps_wrapper), [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); EXPECT_EQ(Result, 42); } static void add_via_function(move_only_function Return, int32_t X, int32_t Y) { Return(X + Y); } static void add_via_function_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( Session, CallCtx, Return, ArgBytes, add_via_function); } TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaFunction) { int32_t Result = 0; SPSWrapperFunction::call( DirectCaller(nullptr, add_via_function_sps_wrapper), [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); EXPECT_EQ(Result, 42); } static void add_via_function_pointer_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( Session, CallCtx, Return, ArgBytes, &add_via_function); } TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaFunctionPointer) { int32_t Result = 0; SPSWrapperFunction::call( DirectCaller(nullptr, add_via_function_pointer_sps_wrapper), [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); EXPECT_EQ(Result, 42); } static void improbable_feat_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( Session, CallCtx, Return, ArgBytes, [](move_only_function Return, bool LuckyHat) { if (LuckyHat) Return(Error::success()); else Return(make_error("crushed by boulder")); }); } TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorSuccessCase) { bool DidRun = false; SPSWrapperFunction::call( DirectCaller(nullptr, improbable_feat_sps_wrapper), [&](Expected E) { DidRun = true; cantFail(cantFail(std::move(E))); }, true); EXPECT_TRUE(DidRun); } TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorFailureCase) { std::string ErrMsg; SPSWrapperFunction::call( DirectCaller(nullptr, improbable_feat_sps_wrapper), [&](Expected E) { ErrMsg = toString(cantFail(std::move(E))); }, false); EXPECT_EQ(ErrMsg, "crushed by boulder"); } static void halve_number_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction(int32_t)>::handle( Session, CallCtx, Return, ArgBytes, [](move_only_function)> Return, int N) { if (N % 2 == 0) Return(N >> 1); else Return(make_error("N is not a multiple of 2")); }); } TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedSuccessCase) { int32_t Result = 0; SPSWrapperFunction(int32_t)>::call( DirectCaller(nullptr, halve_number_sps_wrapper), [&](Expected> R) { Result = cantFail(cantFail(std::move(R))); }, 2); EXPECT_EQ(Result, 1); } TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) { std::string ErrMsg; SPSWrapperFunction(int32_t)>::call( DirectCaller(nullptr, halve_number_sps_wrapper), [&](Expected> R) { ErrMsg = toString(cantFail(std::move(R)).takeError()); }, 3); EXPECT_EQ(ErrMsg, "N is not a multiple of 2"); } template struct SPSOpCounter {}; namespace orc_rt { template class SPSSerializationTraits, OpCounter> { public: static size_t size(const OpCounter &O) { return 0; } static bool serialize(SPSOutputBuffer &OB, const OpCounter &O) { return true; } static bool deserialize(SPSInputBuffer &OB, OpCounter &O) { return true; } }; } // namespace orc_rt static void handle_with_reference_types_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction, SPSOpCounter<1>, SPSOpCounter<2>, SPSOpCounter<3>)>::handle(Session, CallCtx, Return, ArgBytes, [](move_only_function Return, OpCounter<0>, OpCounter<1> &, const OpCounter<2> &, OpCounter<3> &&) { Return(); }); } TEST(SPSWrapperFunctionUtilsTest, TestHandlerWithReferences) { // Test that we can handle by-value, by-ref, by-const-ref, and by-rvalue-ref // arguments, and that we generate the expected number of moves. OpCounter<0>::reset(); OpCounter<1>::reset(); OpCounter<2>::reset(); OpCounter<3>::reset(); bool DidRun = false; SPSWrapperFunction, SPSOpCounter<1>, SPSOpCounter<2>, SPSOpCounter<3>)>:: call( DirectCaller(nullptr, handle_with_reference_types_sps_wrapper), [&](Error R) { cantFail(std::move(R)); DidRun = true; }, OpCounter<0>(), OpCounter<1>(), OpCounter<2>(), OpCounter<3>()); EXPECT_TRUE(DidRun); // We expect two default constructions for each parameter: one for the // argument to call, and one for the object to deserialize into. EXPECT_EQ(OpCounter<0>::defaultConstructions(), 2U); EXPECT_EQ(OpCounter<1>::defaultConstructions(), 2U); EXPECT_EQ(OpCounter<2>::defaultConstructions(), 2U); EXPECT_EQ(OpCounter<3>::defaultConstructions(), 2U); // Pass-by-value: we expect two moves (one for SPS transparent conversion, // one to copy the value to the parameter), and no copies. EXPECT_EQ(OpCounter<0>::moves(), 2U); EXPECT_EQ(OpCounter<0>::copies(), 0U); // Pass-by-lvalue-reference: we expect one move (for SPS transparent // conversion), no copies. EXPECT_EQ(OpCounter<1>::moves(), 1U); EXPECT_EQ(OpCounter<1>::copies(), 0U); // Pass-by-const-lvalue-reference: we expect one move (for SPS transparent // conversion), no copies. EXPECT_EQ(OpCounter<2>::moves(), 1U); EXPECT_EQ(OpCounter<2>::copies(), 0U); // Pass-by-rvalue-reference: we expect one move (for SPS transparent // conversion), no copies. EXPECT_EQ(OpCounter<3>::moves(), 1U); EXPECT_EQ(OpCounter<3>::copies(), 0U); }