aboutsummaryrefslogtreecommitdiff
path: root/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lldb/unittests/Protocol/ProtocolMCPServerTest.cpp')
-rw-r--r--lldb/unittests/Protocol/ProtocolMCPServerTest.cpp312
1 files changed, 166 insertions, 146 deletions
diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp
index f3ca4cf..45464db 100644
--- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp
+++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp
@@ -6,9 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#include "ProtocolMCPTestUtilities.h"
+#include "ProtocolMCPTestUtilities.h" // IWYU pragma: keep
#include "TestingSupport/Host/JSONTransportTestUtilities.h"
-#include "TestingSupport/Host/PipeTestUtilities.h"
#include "TestingSupport/SubsystemRAII.h"
#include "lldb/Host/FileSystem.h"
#include "lldb/Host/HostInfo.h"
@@ -28,20 +27,25 @@
#include "llvm/Testing/Support/Error.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-#include <chrono>
-#include <condition_variable>
+#include <future>
+#include <memory>
+#include <optional>
+#include <system_error>
using namespace llvm;
using namespace lldb;
using namespace lldb_private;
+using namespace lldb_private::transport;
using namespace lldb_protocol::mcp;
+// Flakey, see https://github.com/llvm/llvm-project/issues/152677.
+#ifndef _WIN32
+
namespace {
-class TestServer : public Server {
-public:
- using Server::Server;
-};
+template <typename T> Response make_response(T &&result, Id id = 1) {
+ return Response{id, std::forward<T>(result)};
+}
/// Test tool that returns it argument as text.
class TestTool : public Tool {
@@ -101,7 +105,9 @@ public:
using Tool::Tool;
llvm::Expected<CallToolResult> Call(const ToolArguments &args) override {
- return llvm::createStringError("error");
+ return llvm::createStringError(
+ std::error_code(eErrorCodeInternalError, std::generic_category()),
+ "error");
}
};
@@ -118,195 +124,209 @@ public:
}
};
-class ProtocolServerMCPTest : public PipePairTest {
+class TestServer : public Server {
+public:
+ using Server::Bind;
+ using Server::Server;
+};
+
+using Transport = TestTransport<lldb_protocol::mcp::ProtocolDescriptor>;
+
+class ProtocolServerMCPTest : public testing::Test {
public:
SubsystemRAII<FileSystem, HostInfo, Socket> subsystems;
MainLoop loop;
+ lldb_private::MainLoop::ReadHandleUP handles[2];
- std::unique_ptr<lldb_protocol::mcp::Transport> from_client;
- std::unique_ptr<lldb_protocol::mcp::Transport> to_client;
- MainLoopBase::ReadHandleUP handles[2];
-
+ std::unique_ptr<Transport> to_server;
+ MCPBinderUP binder;
std::unique_ptr<TestServer> server_up;
- MockMessageHandler<Request, Response, Notification> message_handler;
- llvm::Error Write(llvm::StringRef message) {
- llvm::Expected<json::Value> value = json::parse(message);
- if (!value)
- return value.takeError();
- return from_client->Write(*value);
- }
+ std::unique_ptr<Transport> to_client;
+ MockMessageHandler<lldb_protocol::mcp::ProtocolDescriptor> client;
- llvm::Error Write(json::Value value) { return from_client->Write(value); }
+ std::vector<std::string> logged_messages;
- /// Run the transport MainLoop and return any messages received.
- llvm::Error Run() {
- loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); },
- std::chrono::milliseconds(10));
- return loop.Run().takeError();
+ /// Runs the MainLoop a single time, executing any pending callbacks.
+ void Run() {
+ loop.AddPendingCallback(
+ [](MainLoopBase &loop) { loop.RequestTermination(); });
+ EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded());
}
void SetUp() override {
- PipePairTest::SetUp();
-
- from_client = std::make_unique<lldb_protocol::mcp::Transport>(
- std::make_shared<NativeFile>(input.GetReadFileDescriptor(),
- File::eOpenOptionReadOnly,
- NativeFile::Unowned),
- std::make_shared<NativeFile>(output.GetWriteFileDescriptor(),
- File::eOpenOptionWriteOnly,
- NativeFile::Unowned),
- [](StringRef message) {
- // Uncomment for debugging
- // llvm::errs() << "from_client: " << message << '\n';
- });
- to_client = std::make_unique<lldb_protocol::mcp::Transport>(
- std::make_shared<NativeFile>(output.GetReadFileDescriptor(),
- File::eOpenOptionReadOnly,
- NativeFile::Unowned),
- std::make_shared<NativeFile>(input.GetWriteFileDescriptor(),
- File::eOpenOptionWriteOnly,
- NativeFile::Unowned),
- [](StringRef message) {
- // Uncomment for debugging
- // llvm::errs() << "to_client: " << message << '\n';
- });
-
- server_up = std::make_unique<TestServer>("lldb-mcp", "0.1.0", *to_client,
- [](StringRef message) {
- // Uncomment for debugging
- // llvm::errs() << "server: " <<
- // message << '\n';
- });
-
- auto maybe_from_client_handle =
- from_client->RegisterMessageHandler(loop, message_handler);
- EXPECT_THAT_EXPECTED(maybe_from_client_handle, Succeeded());
- handles[0] = std::move(*maybe_from_client_handle);
-
- auto maybe_to_client_handle =
- to_client->RegisterMessageHandler(loop, *server_up);
- EXPECT_THAT_EXPECTED(maybe_to_client_handle, Succeeded());
- handles[1] = std::move(*maybe_to_client_handle);
+ std::tie(to_client, to_server) = Transport::createPair();
+
+ server_up = std::make_unique<TestServer>(
+ "lldb-mcp", "0.1.0",
+ [this](StringRef msg) { logged_messages.push_back(msg.str()); });
+ binder = server_up->Bind(*to_client);
+ auto server_handle = to_server->RegisterMessageHandler(loop, *binder);
+ EXPECT_THAT_EXPECTED(server_handle, Succeeded());
+ binder->OnError([](llvm::Error error) {
+ llvm::errs() << formatv("Server transport error: {0}", error);
+ });
+ handles[0] = std::move(*server_handle);
+
+ auto client_handle = to_client->RegisterMessageHandler(loop, client);
+ EXPECT_THAT_EXPECTED(client_handle, Succeeded());
+ handles[1] = std::move(*client_handle);
+ }
+
+ template <typename Result, typename Params>
+ Expected<json::Value> Call(StringRef method, const Params &params) {
+ std::promise<Response> promised_result;
+ Request req =
+ lldb_protocol::mcp::Request{/*id=*/1, method.str(), toJSON(params)};
+ EXPECT_THAT_ERROR(to_server->Send(req), Succeeded());
+ EXPECT_CALL(client, Received(testing::An<const Response &>()))
+ .WillOnce(
+ [&](const Response &resp) { promised_result.set_value(resp); });
+ Run();
+ Response resp = promised_result.get_future().get();
+ return toJSON(resp);
+ }
+
+ template <typename Result>
+ Expected<json::Value>
+ Capture(llvm::unique_function<void(Reply<Result>)> &fn) {
+ std::promise<llvm::Expected<Result>> promised_result;
+ fn([&promised_result](llvm::Expected<Result> result) {
+ promised_result.set_value(std::move(result));
+ });
+ Run();
+ llvm::Expected<Result> result = promised_result.get_future().get();
+ if (!result)
+ return result.takeError();
+ return toJSON(*result);
+ }
+
+ template <typename Result, typename Params>
+ Expected<json::Value>
+ Capture(llvm::unique_function<void(const Params &, Reply<Result>)> &fn,
+ const Params &params) {
+ std::promise<llvm::Expected<Result>> promised_result;
+ fn(params, [&promised_result](llvm::Expected<Result> result) {
+ promised_result.set_value(std::move(result));
+ });
+ Run();
+ llvm::Expected<Result> result = promised_result.get_future().get();
+ if (!result)
+ return result.takeError();
+ return toJSON(*result);
}
};
template <typename T>
-Request make_request(StringLiteral method, T &&params, Id id = 1) {
- return Request{id, method.str(), toJSON(std::forward<T>(params))};
-}
-
-template <typename T> Response make_response(T &&result, Id id = 1) {
- return Response{id, std::forward<T>(result)};
+inline testing::internal::EqMatcher<llvm::json::Value> HasJSON(T x) {
+ return testing::internal::EqMatcher<llvm::json::Value>(toJSON(x));
}
} // namespace
TEST_F(ProtocolServerMCPTest, Initialization) {
- Request request = make_request(
- "initialize", InitializeParams{/*protocolVersion=*/"2024-11-05",
- /*capabilities=*/{},
- /*clientInfo=*/{"lldb-unit", "0.1.0"}});
- Response response = make_response(
- InitializeResult{/*protocolVersion=*/"2024-11-05",
- /*capabilities=*/{/*supportsToolsList=*/true},
- /*serverInfo=*/{"lldb-mcp", "0.1.0"}});
-
- ASSERT_THAT_ERROR(Write(request), Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED(
+ (Call<InitializeResult, InitializeParams>(
+ "initialize",
+ InitializeParams{/*protocolVersion=*/"2024-11-05",
+ /*capabilities=*/{},
+ /*clientInfo=*/{"lldb-unit", "0.1.0"}})),
+ HasValue(make_response(
+ InitializeResult{/*protocolVersion=*/"2024-11-05",
+ /*capabilities=*/
+ {
+ /*supportsToolsList=*/true,
+ /*supportsResourcesList=*/true,
+ },
+ /*serverInfo=*/{"lldb-mcp", "0.1.0"}})));
}
TEST_F(ProtocolServerMCPTest, ToolsList) {
server_up->AddTool(std::make_unique<TestTool>("test", "test tool"));
- Request request = make_request("tools/list", Void{}, /*id=*/"one");
-
ToolDefinition test_tool;
test_tool.name = "test";
test_tool.description = "test tool";
test_tool.inputSchema = json::Object{{"type", "object"}};
- Response response = make_response(ListToolsResult{{test_tool}}, /*id=*/"one");
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED(Call<ListToolsResult>("tools/list", Void{}),
+ HasValue(make_response(ListToolsResult{{test_tool}})));
}
TEST_F(ProtocolServerMCPTest, ResourcesList) {
server_up->AddResourceProvider(std::make_unique<TestResourceProvider>());
- Request request = make_request("resources/list", Void{});
- Response response = make_response(ListResourcesResult{
- {{/*uri=*/"lldb://foo/bar", /*name=*/"name",
- /*description=*/"description", /*mimeType=*/"application/json"}}});
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED(Call<ListResourcesResult>("resources/list", Void{}),
+ HasValue(make_response(ListResourcesResult{{
+ {
+ /*uri=*/"lldb://foo/bar",
+ /*name=*/"name",
+ /*description=*/"description",
+ /*mimeType=*/"application/json",
+ },
+ }})));
}
TEST_F(ProtocolServerMCPTest, ToolsCall) {
server_up->AddTool(std::make_unique<TestTool>("test", "test tool"));
- Request request = make_request(
- "tools/call", CallToolParams{/*name=*/"test", /*arguments=*/json::Object{
- {"arguments", "foo"},
- {"debugger_id", 0},
- }});
- Response response = make_response(CallToolResult{{{/*text=*/"foo"}}});
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED(
+ (Call<CallToolResult, CallToolParams>("tools/call",
+ CallToolParams{
+ /*name=*/"test",
+ /*arguments=*/
+ json::Object{
+ {"arguments", "foo"},
+ {"debugger_id", 0},
+ },
+ })),
+ HasValue(make_response(CallToolResult{{{/*text=*/"foo"}}})));
}
TEST_F(ProtocolServerMCPTest, ToolsCallError) {
server_up->AddTool(std::make_unique<ErrorTool>("error", "error tool"));
- Request request = make_request(
- "tools/call", CallToolParams{/*name=*/"error", /*arguments=*/json::Object{
- {"arguments", "foo"},
- {"debugger_id", 0},
- }});
- Response response =
- make_response(lldb_protocol::mcp::Error{eErrorCodeInternalError,
- /*message=*/"error"});
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED((Call<CallToolResult, CallToolParams>(
+ "tools/call", CallToolParams{
+ /*name=*/"error",
+ /*arguments=*/
+ json::Object{
+ {"arguments", "foo"},
+ {"debugger_id", 0},
+ },
+ })),
+ HasValue(make_response(lldb_protocol::mcp::Error{
+ eErrorCodeInternalError, "error"})));
}
TEST_F(ProtocolServerMCPTest, ToolsCallFail) {
server_up->AddTool(std::make_unique<FailTool>("fail", "fail tool"));
- Request request = make_request(
- "tools/call", CallToolParams{/*name=*/"fail", /*arguments=*/json::Object{
- {"arguments", "foo"},
- {"debugger_id", 0},
- }});
- Response response =
- make_response(CallToolResult{{{/*text=*/"failed"}}, /*isError=*/true});
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_CALL(message_handler, Received(response));
- EXPECT_THAT_ERROR(Run(), Succeeded());
+ EXPECT_THAT_EXPECTED((Call<CallToolResult, CallToolParams>(
+ "tools/call", CallToolParams{
+ /*name=*/"fail",
+ /*arguments=*/
+ json::Object{
+ {"arguments", "foo"},
+ {"debugger_id", 0},
+ },
+ })),
+ HasValue(make_response(CallToolResult{
+ {{/*text=*/"failed"}},
+ /*isError=*/true,
+ })));
}
TEST_F(ProtocolServerMCPTest, NotificationInitialized) {
- bool handler_called = false;
- std::condition_variable cv;
-
- server_up->AddNotificationHandler(
- "notifications/initialized",
- [&](const Notification &notification) { handler_called = true; });
- llvm::StringLiteral request =
- R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json";
-
- ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
- EXPECT_THAT_ERROR(Run(), Succeeded());
- EXPECT_TRUE(handler_called);
+ EXPECT_THAT_ERROR(to_server->Send(lldb_protocol::mcp::Notification{
+ "notifications/initialized",
+ std::nullopt,
+ }),
+ Succeeded());
+ Run();
+ EXPECT_THAT(logged_messages,
+ testing::Contains("MCP initialization complete"));
}
+
+#endif