//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "ProtocolMCPTestUtilities.h" // IWYU pragma: keep #include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" #include "lldb/Host/JSONTransport.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" #include "lldb/Host/Socket.h" #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" #include "lldb/Protocol/MCP/Server.h" #include "lldb/Protocol/MCP/Tool.h" #include "lldb/Protocol/MCP/Transport.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" #include "llvm/Support/JSON.h" #include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include #include #include #include using namespace llvm; using namespace lldb; using namespace lldb_private; using namespace lldb_private::transport; using namespace lldb_protocol::mcp; namespace { template Response make_response(T &&result, Id id = 1) { return Response{id, std::forward(result)}; } /// Test tool that returns it argument as text. class TestTool : public Tool { public: using Tool::Tool; llvm::Expected Call(const ToolArguments &args) override { std::string argument; if (const json::Object *args_obj = std::get(args).getAsObject()) { if (const json::Value *s = args_obj->get("arguments")) { argument = s->getAsString().value_or(""); } } CallToolResult text_result; text_result.content.emplace_back(TextContent{{argument}}); return text_result; } }; class TestResourceProvider : public ResourceProvider { using ResourceProvider::ResourceProvider; std::vector GetResources() const override { std::vector resources; Resource resource; resource.uri = "lldb://foo/bar"; resource.name = "name"; resource.description = "description"; resource.mimeType = "application/json"; resources.push_back(resource); return resources; } llvm::Expected ReadResource(llvm::StringRef uri) const override { if (uri != "lldb://foo/bar") return llvm::make_error(uri.str()); TextResourceContents contents; contents.uri = "lldb://foo/bar"; contents.mimeType = "application/json"; contents.text = "foobar"; ReadResourceResult result; result.contents.push_back(contents); return result; } }; /// Test tool that returns an error. class ErrorTool : public Tool { public: using Tool::Tool; llvm::Expected Call(const ToolArguments &args) override { return llvm::createStringError( std::error_code(eErrorCodeInternalError, std::generic_category()), "error"); } }; /// Test tool that fails but doesn't return an error. class FailTool : public Tool { public: using Tool::Tool; llvm::Expected Call(const ToolArguments &args) override { CallToolResult text_result; text_result.content.emplace_back(TextContent{{"failed"}}); text_result.isError = true; return text_result; } }; class TestServer : public Server { public: using Server::Bind; using Server::Server; }; using Transport = TestTransport; class ProtocolServerMCPTest : public testing::Test { public: SubsystemRAII subsystems; MainLoop loop; lldb_private::MainLoop::ReadHandleUP handles[2]; std::unique_ptr to_server; MCPBinderUP binder; std::unique_ptr server_up; std::unique_ptr to_client; MockMessageHandler client; std::vector logged_messages; /// Runs the MainLoop a single time, executing any pending callbacks. void Run() { loop.AddPendingCallback( [](MainLoopBase &loop) { loop.RequestTermination(); }); EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); } void SetUp() override { std::tie(to_client, to_server) = Transport::createPair(); server_up = std::make_unique( "lldb-mcp", "0.1.0", [this](StringRef msg) { logged_messages.push_back(msg.str()); }); binder = server_up->Bind(*to_client); auto server_handle = to_server->RegisterMessageHandler(loop, *binder); EXPECT_THAT_EXPECTED(server_handle, Succeeded()); binder->OnError([](llvm::Error error) { llvm::errs() << formatv("Server transport error: {0}", error); }); handles[0] = std::move(*server_handle); auto client_handle = to_client->RegisterMessageHandler(loop, client); EXPECT_THAT_EXPECTED(client_handle, Succeeded()); handles[1] = std::move(*client_handle); } template Expected Call(StringRef method, const Params ¶ms) { std::promise promised_result; Request req = lldb_protocol::mcp::Request{/*id=*/1, method.str(), toJSON(params)}; EXPECT_THAT_ERROR(to_server->Send(req), Succeeded()); EXPECT_CALL(client, Received(testing::An())) .WillOnce( [&](const Response &resp) { promised_result.set_value(resp); }); Run(); Response resp = promised_result.get_future().get(); return toJSON(resp); } template Expected Capture(llvm::unique_function)> &fn) { std::promise> promised_result; fn([&promised_result](llvm::Expected result) { promised_result.set_value(std::move(result)); }); Run(); llvm::Expected result = promised_result.get_future().get(); if (!result) return result.takeError(); return toJSON(*result); } template Expected Capture(llvm::unique_function)> &fn, const Params ¶ms) { std::promise> promised_result; fn(params, [&promised_result](llvm::Expected result) { promised_result.set_value(std::move(result)); }); Run(); llvm::Expected result = promised_result.get_future().get(); if (!result) return result.takeError(); return toJSON(*result); } }; template inline testing::internal::EqMatcher HasJSON(T x) { return testing::internal::EqMatcher(toJSON(x)); } } // namespace TEST_F(ProtocolServerMCPTest, Initialization) { EXPECT_THAT_EXPECTED( (Call( "initialize", InitializeParams{/*protocolVersion=*/"2024-11-05", /*capabilities=*/{}, /*clientInfo=*/{"lldb-unit", "0.1.0"}})), HasValue(make_response( InitializeResult{/*protocolVersion=*/"2024-11-05", /*capabilities=*/ { /*supportsToolsList=*/true, /*supportsResourcesList=*/true, }, /*serverInfo=*/{"lldb-mcp", "0.1.0"}}))); } TEST_F(ProtocolServerMCPTest, ToolsList) { server_up->AddTool(std::make_unique("test", "test tool")); ToolDefinition test_tool; test_tool.name = "test"; test_tool.description = "test tool"; test_tool.inputSchema = json::Object{{"type", "object"}}; EXPECT_THAT_EXPECTED(Call("tools/list", Void{}), HasValue(make_response(ListToolsResult{{test_tool}}))); } TEST_F(ProtocolServerMCPTest, ResourcesList) { server_up->AddResourceProvider(std::make_unique()); EXPECT_THAT_EXPECTED(Call("resources/list", Void{}), HasValue(make_response(ListResourcesResult{{ { /*uri=*/"lldb://foo/bar", /*name=*/"name", /*description=*/"description", /*mimeType=*/"application/json", }, }}))); } TEST_F(ProtocolServerMCPTest, ToolsCall) { server_up->AddTool(std::make_unique("test", "test tool")); EXPECT_THAT_EXPECTED( (Call("tools/call", CallToolParams{ /*name=*/"test", /*arguments=*/ json::Object{ {"arguments", "foo"}, {"debugger_id", 0}, }, })), HasValue(make_response(CallToolResult{{{/*text=*/"foo"}}}))); } TEST_F(ProtocolServerMCPTest, ToolsCallError) { server_up->AddTool(std::make_unique("error", "error tool")); EXPECT_THAT_EXPECTED((Call( "tools/call", CallToolParams{ /*name=*/"error", /*arguments=*/ json::Object{ {"arguments", "foo"}, {"debugger_id", 0}, }, })), HasValue(make_response(lldb_protocol::mcp::Error{ eErrorCodeInternalError, "error"}))); } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { server_up->AddTool(std::make_unique("fail", "fail tool")); EXPECT_THAT_EXPECTED((Call( "tools/call", CallToolParams{ /*name=*/"fail", /*arguments=*/ json::Object{ {"arguments", "foo"}, {"debugger_id", 0}, }, })), HasValue(make_response(CallToolResult{ {{/*text=*/"failed"}}, /*isError=*/true, }))); } TEST_F(ProtocolServerMCPTest, NotificationInitialized) { EXPECT_THAT_ERROR(to_server->Send(lldb_protocol::mcp::Notification{ "notifications/initialized", std::nullopt, }), Succeeded()); Run(); EXPECT_THAT(logged_messages, testing::Contains("MCP initialization complete")); }