diff options
Diffstat (limited to 'lldb/unittests/Protocol')
-rw-r--r-- | lldb/unittests/Protocol/ProtocolMCPServerTest.cpp | 312 | ||||
-rw-r--r-- | lldb/unittests/Protocol/ProtocolMCPTest.cpp | 5 |
2 files changed, 171 insertions, 146 deletions
diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index f3ca4cf..45464db 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -6,9 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "ProtocolMCPTestUtilities.h" +#include "ProtocolMCPTestUtilities.h" // IWYU pragma: keep #include "TestingSupport/Host/JSONTransportTestUtilities.h" -#include "TestingSupport/Host/PipeTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" @@ -28,20 +27,25 @@ #include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include <chrono> -#include <condition_variable> +#include <future> +#include <memory> +#include <optional> +#include <system_error> using namespace llvm; using namespace lldb; using namespace lldb_private; +using namespace lldb_private::transport; using namespace lldb_protocol::mcp; +// Flakey, see https://github.com/llvm/llvm-project/issues/152677. +#ifndef _WIN32 + namespace { -class TestServer : public Server { -public: - using Server::Server; -}; +template <typename T> Response make_response(T &&result, Id id = 1) { + return Response{id, std::forward<T>(result)}; +} /// Test tool that returns it argument as text. class TestTool : public Tool { @@ -101,7 +105,9 @@ public: using Tool::Tool; llvm::Expected<CallToolResult> Call(const ToolArguments &args) override { - return llvm::createStringError("error"); + return llvm::createStringError( + std::error_code(eErrorCodeInternalError, std::generic_category()), + "error"); } }; @@ -118,195 +124,209 @@ public: } }; -class ProtocolServerMCPTest : public PipePairTest { +class TestServer : public Server { +public: + using Server::Bind; + using Server::Server; +}; + +using Transport = TestTransport<lldb_protocol::mcp::ProtocolDescriptor>; + +class ProtocolServerMCPTest : public testing::Test { public: SubsystemRAII<FileSystem, HostInfo, Socket> subsystems; MainLoop loop; + lldb_private::MainLoop::ReadHandleUP handles[2]; - std::unique_ptr<lldb_protocol::mcp::Transport> from_client; - std::unique_ptr<lldb_protocol::mcp::Transport> to_client; - MainLoopBase::ReadHandleUP handles[2]; - + std::unique_ptr<Transport> to_server; + MCPBinderUP binder; std::unique_ptr<TestServer> server_up; - MockMessageHandler<Request, Response, Notification> message_handler; - llvm::Error Write(llvm::StringRef message) { - llvm::Expected<json::Value> value = json::parse(message); - if (!value) - return value.takeError(); - return from_client->Write(*value); - } + std::unique_ptr<Transport> to_client; + MockMessageHandler<lldb_protocol::mcp::ProtocolDescriptor> client; - llvm::Error Write(json::Value value) { return from_client->Write(value); } + std::vector<std::string> logged_messages; - /// Run the transport MainLoop and return any messages received. - llvm::Error Run() { - loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, - std::chrono::milliseconds(10)); - return loop.Run().takeError(); + /// Runs the MainLoop a single time, executing any pending callbacks. + void Run() { + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); } void SetUp() override { - PipePairTest::SetUp(); - - from_client = std::make_unique<lldb_protocol::mcp::Transport>( - std::make_shared<NativeFile>(input.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared<NativeFile>(output.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned), - [](StringRef message) { - // Uncomment for debugging - // llvm::errs() << "from_client: " << message << '\n'; - }); - to_client = std::make_unique<lldb_protocol::mcp::Transport>( - std::make_shared<NativeFile>(output.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared<NativeFile>(input.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned), - [](StringRef message) { - // Uncomment for debugging - // llvm::errs() << "to_client: " << message << '\n'; - }); - - server_up = std::make_unique<TestServer>("lldb-mcp", "0.1.0", *to_client, - [](StringRef message) { - // Uncomment for debugging - // llvm::errs() << "server: " << - // message << '\n'; - }); - - auto maybe_from_client_handle = - from_client->RegisterMessageHandler(loop, message_handler); - EXPECT_THAT_EXPECTED(maybe_from_client_handle, Succeeded()); - handles[0] = std::move(*maybe_from_client_handle); - - auto maybe_to_client_handle = - to_client->RegisterMessageHandler(loop, *server_up); - EXPECT_THAT_EXPECTED(maybe_to_client_handle, Succeeded()); - handles[1] = std::move(*maybe_to_client_handle); + std::tie(to_client, to_server) = Transport::createPair(); + + server_up = std::make_unique<TestServer>( + "lldb-mcp", "0.1.0", + [this](StringRef msg) { logged_messages.push_back(msg.str()); }); + binder = server_up->Bind(*to_client); + auto server_handle = to_server->RegisterMessageHandler(loop, *binder); + EXPECT_THAT_EXPECTED(server_handle, Succeeded()); + binder->OnError([](llvm::Error error) { + llvm::errs() << formatv("Server transport error: {0}", error); + }); + handles[0] = std::move(*server_handle); + + auto client_handle = to_client->RegisterMessageHandler(loop, client); + EXPECT_THAT_EXPECTED(client_handle, Succeeded()); + handles[1] = std::move(*client_handle); + } + + template <typename Result, typename Params> + Expected<json::Value> Call(StringRef method, const Params ¶ms) { + std::promise<Response> promised_result; + Request req = + lldb_protocol::mcp::Request{/*id=*/1, method.str(), toJSON(params)}; + EXPECT_THAT_ERROR(to_server->Send(req), Succeeded()); + EXPECT_CALL(client, Received(testing::An<const Response &>())) + .WillOnce( + [&](const Response &resp) { promised_result.set_value(resp); }); + Run(); + Response resp = promised_result.get_future().get(); + return toJSON(resp); + } + + template <typename Result> + Expected<json::Value> + Capture(llvm::unique_function<void(Reply<Result>)> &fn) { + std::promise<llvm::Expected<Result>> promised_result; + fn([&promised_result](llvm::Expected<Result> result) { + promised_result.set_value(std::move(result)); + }); + Run(); + llvm::Expected<Result> result = promised_result.get_future().get(); + if (!result) + return result.takeError(); + return toJSON(*result); + } + + template <typename Result, typename Params> + Expected<json::Value> + Capture(llvm::unique_function<void(const Params &, Reply<Result>)> &fn, + const Params ¶ms) { + std::promise<llvm::Expected<Result>> promised_result; + fn(params, [&promised_result](llvm::Expected<Result> result) { + promised_result.set_value(std::move(result)); + }); + Run(); + llvm::Expected<Result> result = promised_result.get_future().get(); + if (!result) + return result.takeError(); + return toJSON(*result); } }; template <typename T> -Request make_request(StringLiteral method, T &¶ms, Id id = 1) { - return Request{id, method.str(), toJSON(std::forward<T>(params))}; -} - -template <typename T> Response make_response(T &&result, Id id = 1) { - return Response{id, std::forward<T>(result)}; +inline testing::internal::EqMatcher<llvm::json::Value> HasJSON(T x) { + return testing::internal::EqMatcher<llvm::json::Value>(toJSON(x)); } } // namespace TEST_F(ProtocolServerMCPTest, Initialization) { - Request request = make_request( - "initialize", InitializeParams{/*protocolVersion=*/"2024-11-05", - /*capabilities=*/{}, - /*clientInfo=*/{"lldb-unit", "0.1.0"}}); - Response response = make_response( - InitializeResult{/*protocolVersion=*/"2024-11-05", - /*capabilities=*/{/*supportsToolsList=*/true}, - /*serverInfo=*/{"lldb-mcp", "0.1.0"}}); - - ASSERT_THAT_ERROR(Write(request), Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED( + (Call<InitializeResult, InitializeParams>( + "initialize", + InitializeParams{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/{}, + /*clientInfo=*/{"lldb-unit", "0.1.0"}})), + HasValue(make_response( + InitializeResult{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/ + { + /*supportsToolsList=*/true, + /*supportsResourcesList=*/true, + }, + /*serverInfo=*/{"lldb-mcp", "0.1.0"}}))); } TEST_F(ProtocolServerMCPTest, ToolsList) { server_up->AddTool(std::make_unique<TestTool>("test", "test tool")); - Request request = make_request("tools/list", Void{}, /*id=*/"one"); - ToolDefinition test_tool; test_tool.name = "test"; test_tool.description = "test tool"; test_tool.inputSchema = json::Object{{"type", "object"}}; - Response response = make_response(ListToolsResult{{test_tool}}, /*id=*/"one"); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED(Call<ListToolsResult>("tools/list", Void{}), + HasValue(make_response(ListToolsResult{{test_tool}}))); } TEST_F(ProtocolServerMCPTest, ResourcesList) { server_up->AddResourceProvider(std::make_unique<TestResourceProvider>()); - Request request = make_request("resources/list", Void{}); - Response response = make_response(ListResourcesResult{ - {{/*uri=*/"lldb://foo/bar", /*name=*/"name", - /*description=*/"description", /*mimeType=*/"application/json"}}}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED(Call<ListResourcesResult>("resources/list", Void{}), + HasValue(make_response(ListResourcesResult{{ + { + /*uri=*/"lldb://foo/bar", + /*name=*/"name", + /*description=*/"description", + /*mimeType=*/"application/json", + }, + }}))); } TEST_F(ProtocolServerMCPTest, ToolsCall) { server_up->AddTool(std::make_unique<TestTool>("test", "test tool")); - Request request = make_request( - "tools/call", CallToolParams{/*name=*/"test", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = make_response(CallToolResult{{{/*text=*/"foo"}}}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED( + (Call<CallToolResult, CallToolParams>("tools/call", + CallToolParams{ + /*name=*/"test", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(CallToolResult{{{/*text=*/"foo"}}}))); } TEST_F(ProtocolServerMCPTest, ToolsCallError) { server_up->AddTool(std::make_unique<ErrorTool>("error", "error tool")); - Request request = make_request( - "tools/call", CallToolParams{/*name=*/"error", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = - make_response(lldb_protocol::mcp::Error{eErrorCodeInternalError, - /*message=*/"error"}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED((Call<CallToolResult, CallToolParams>( + "tools/call", CallToolParams{ + /*name=*/"error", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(lldb_protocol::mcp::Error{ + eErrorCodeInternalError, "error"}))); } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { server_up->AddTool(std::make_unique<FailTool>("fail", "fail tool")); - Request request = make_request( - "tools/call", CallToolParams{/*name=*/"fail", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = - make_response(CallToolResult{{{/*text=*/"failed"}}, /*isError=*/true}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED((Call<CallToolResult, CallToolParams>( + "tools/call", CallToolParams{ + /*name=*/"fail", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(CallToolResult{ + {{/*text=*/"failed"}}, + /*isError=*/true, + }))); } TEST_F(ProtocolServerMCPTest, NotificationInitialized) { - bool handler_called = false; - std::condition_variable cv; - - server_up->AddNotificationHandler( - "notifications/initialized", - [&](const Notification ¬ification) { handler_called = true; }); - llvm::StringLiteral request = - R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_THAT_ERROR(Run(), Succeeded()); - EXPECT_TRUE(handler_called); + EXPECT_THAT_ERROR(to_server->Send(lldb_protocol::mcp::Notification{ + "notifications/initialized", + std::nullopt, + }), + Succeeded()); + Run(); + EXPECT_THAT(logged_messages, + testing::Contains("MCP initialization complete")); } + +#endif diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp index 396e361..5f7391e 100644 --- a/lldb/unittests/Protocol/ProtocolMCPTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -16,6 +16,9 @@ using namespace lldb; using namespace lldb_private; using namespace lldb_protocol::mcp; +// Flakey, see https://github.com/llvm/llvm-project/issues/152677. +#ifndef _WIN32 + TEST(ProtocolMCPTest, Request) { Request request; request.id = 1; @@ -292,3 +295,5 @@ TEST(ProtocolMCPTest, ReadResourceResultEmpty) { EXPECT_TRUE(deserialized_result->contents.empty()); } + +#endif |