aboutsummaryrefslogtreecommitdiff
path: root/orc-rt
diff options
context:
space:
mode:
Diffstat (limited to 'orc-rt')
-rw-r--r--orc-rt/include/orc-rt/SPSWrapperFunction.h59
-rw-r--r--orc-rt/include/orc-rt/WrapperFunction.h3
-rw-r--r--orc-rt/unittests/SPSWrapperFunctionTest.cpp74
3 files changed, 129 insertions, 7 deletions
diff --git a/orc-rt/include/orc-rt/SPSWrapperFunction.h b/orc-rt/include/orc-rt/SPSWrapperFunction.h
index 3ea6406..14a3d8e 100644
--- a/orc-rt/include/orc-rt/SPSWrapperFunction.h
+++ b/orc-rt/include/orc-rt/SPSWrapperFunction.h
@@ -21,8 +21,10 @@ namespace orc_rt {
namespace detail {
template <typename... SPSArgTs> struct WFSPSHelper {
- template <typename... ArgTs>
- std::optional<WrapperFunctionBuffer> serialize(const ArgTs &...Args) {
+private:
+ template <typename... SerializableArgTs>
+ std::optional<WrapperFunctionBuffer>
+ serializeImpl(const SerializableArgTs &...Args) {
auto R =
WrapperFunctionBuffer::allocate(SPSArgList<SPSArgTs...>::size(Args...));
SPSOutputBuffer OB(R.data(), R.size());
@@ -31,16 +33,61 @@ template <typename... SPSArgTs> struct WFSPSHelper {
return std::move(R);
}
+ template <typename T> static const T &toSerializable(const T &Arg) noexcept {
+ return Arg;
+ }
+
+ static SPSSerializableError toSerializable(Error Err) noexcept {
+ return SPSSerializableError(std::move(Err));
+ }
+
+ template <typename T>
+ static SPSSerializableExpected<T> toSerializable(Expected<T> Arg) noexcept {
+ return SPSSerializableExpected<T>(std::move(Arg));
+ }
+
+ template <typename... Ts> struct DeserializableTuple;
+
+ template <typename... Ts> struct DeserializableTuple<std::tuple<Ts...>> {
+ typedef std::tuple<
+ std::decay_t<decltype(toSerializable(std::declval<Ts>()))>...>
+ type;
+ };
+
+ template <typename... Ts>
+ using DeserializableTuple_t = typename DeserializableTuple<Ts...>::type;
+
+ template <typename T> static T fromSerializable(T &&Arg) noexcept {
+ return Arg;
+ }
+
+ static Error fromSerializable(SPSSerializableError Err) noexcept {
+ return Err.toError();
+ }
+
+ template <typename T>
+ static Expected<T> fromSerializable(SPSSerializableExpected<T> Val) noexcept {
+ return Val.toExpected();
+ }
+
+public:
+ template <typename... ArgTs>
+ std::optional<WrapperFunctionBuffer> serialize(ArgTs &&...Args) {
+ return serializeImpl(toSerializable(std::forward<ArgTs>(Args))...);
+ }
+
template <typename ArgTuple>
std::optional<ArgTuple> deserialize(WrapperFunctionBuffer ArgBytes) {
assert(!ArgBytes.getOutOfBandError() &&
"Should not attempt to deserialize out-of-band error");
SPSInputBuffer IB(ArgBytes.data(), ArgBytes.size());
- ArgTuple Args;
- if (!SPSSerializationTraits<SPSTuple<SPSArgTs...>, ArgTuple>::deserialize(
- IB, Args))
+ DeserializableTuple_t<ArgTuple> Args;
+ if (!SPSSerializationTraits<SPSTuple<SPSArgTs...>,
+ decltype(Args)>::deserialize(IB, Args))
return std::nullopt;
- return Args;
+ return std::apply(
+ [](auto &&...A) { return ArgTuple(fromSerializable(A)...); },
+ std::move(Args));
}
};
diff --git a/orc-rt/include/orc-rt/WrapperFunction.h b/orc-rt/include/orc-rt/WrapperFunction.h
index 233c3b2..ca165db 100644
--- a/orc-rt/include/orc-rt/WrapperFunction.h
+++ b/orc-rt/include/orc-rt/WrapperFunction.h
@@ -168,7 +168,8 @@ struct ResultDeserializer<std::tuple<Expected<T>>, Serializer> {
Serializer &S) {
if (auto Val = S.result().template deserialize<std::tuple<T>>(
std::move(ResultBytes)))
- return std::move(std::get<0>(*Val));
+ return Expected<T>(std::move(std::get<0>(*Val)),
+ ForceExpectedSuccessValue());
else
return make_error<StringError>("Could not deserialize result");
}
diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
index 0b65515..c0c86ff 100644
--- a/orc-rt/unittests/SPSWrapperFunctionTest.cpp
+++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
@@ -144,3 +144,77 @@ TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaFunctionPointer) {
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
EXPECT_EQ(Result, 42);
}
+
+static void improbable_feat_sps_wrapper(orc_rt_SessionRef Session,
+ void *CallCtx,
+ orc_rt_WrapperFunctionReturn Return,
+ orc_rt_WrapperFunctionBuffer ArgBytes) {
+ SPSWrapperFunction<SPSError(bool)>::handle(
+ Session, CallCtx, Return, ArgBytes,
+ [](move_only_function<void(Error)> Return, bool LuckyHat) {
+ if (LuckyHat)
+ Return(Error::success());
+ else
+ Return(make_error<StringError>("crushed by boulder"));
+ });
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorSuccessCase) {
+ bool DidRun = false;
+ SPSWrapperFunction<SPSError(bool)>::call(
+ DirectCaller(nullptr, improbable_feat_sps_wrapper),
+ [&](Expected<Error> E) {
+ DidRun = true;
+ cantFail(cantFail(std::move(E)));
+ },
+ true);
+
+ EXPECT_TRUE(DidRun);
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorFailureCase) {
+ std::string ErrMsg;
+ SPSWrapperFunction<SPSError(bool)>::call(
+ DirectCaller(nullptr, improbable_feat_sps_wrapper),
+ [&](Expected<Error> E) { ErrMsg = toString(cantFail(std::move(E))); },
+ false);
+
+ EXPECT_EQ(ErrMsg, "crushed by boulder");
+}
+
+static void halve_number_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
+ orc_rt_WrapperFunctionReturn Return,
+ orc_rt_WrapperFunctionBuffer ArgBytes) {
+ SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::handle(
+ Session, CallCtx, Return, ArgBytes,
+ [](move_only_function<void(Expected<int32_t>)> Return, int N) {
+ if (N % 2 == 0)
+ Return(N >> 1);
+ else
+ Return(make_error<StringError>("N is not a multiple of 2"));
+ });
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedSuccessCase) {
+ int32_t Result = 0;
+ SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::call(
+ DirectCaller(nullptr, halve_number_sps_wrapper),
+ [&](Expected<Expected<int32_t>> R) {
+ Result = cantFail(cantFail(std::move(R)));
+ },
+ 2);
+
+ EXPECT_EQ(Result, 1);
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) {
+ std::string ErrMsg;
+ SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::call(
+ DirectCaller(nullptr, halve_number_sps_wrapper),
+ [&](Expected<Expected<int32_t>> R) {
+ ErrMsg = toString(cantFail(std::move(R)).takeError());
+ },
+ 3);
+
+ EXPECT_EQ(ErrMsg, "N is not a multiple of 2");
+}