diff options
Diffstat (limited to 'lldb/source/Protocol/MCP')
-rw-r--r-- | lldb/source/Protocol/MCP/CMakeLists.txt | 1 | ||||
-rw-r--r-- | lldb/source/Protocol/MCP/MCPError.cpp | 6 | ||||
-rw-r--r-- | lldb/source/Protocol/MCP/Protocol.cpp | 335 | ||||
-rw-r--r-- | lldb/source/Protocol/MCP/Server.cpp | 210 |
4 files changed, 379 insertions, 173 deletions
diff --git a/lldb/source/Protocol/MCP/CMakeLists.txt b/lldb/source/Protocol/MCP/CMakeLists.txt index a73e7e6..a4f270a 100644 --- a/lldb/source/Protocol/MCP/CMakeLists.txt +++ b/lldb/source/Protocol/MCP/CMakeLists.txt @@ -7,6 +7,7 @@ add_lldb_library(lldbProtocolMCP NO_PLUGIN_DEPENDENCIES LINK_COMPONENTS Support LINK_LIBS + lldbHost lldbUtility ) diff --git a/lldb/source/Protocol/MCP/MCPError.cpp b/lldb/source/Protocol/MCP/MCPError.cpp index c610e88..e140d11 100644 --- a/lldb/source/Protocol/MCP/MCPError.cpp +++ b/lldb/source/Protocol/MCP/MCPError.cpp @@ -25,10 +25,10 @@ std::error_code MCPError::convertToErrorCode() const { return llvm::inconvertibleErrorCode(); } -lldb_protocol::mcp::Error MCPError::toProtcolError() const { +lldb_protocol::mcp::Error MCPError::toProtocolError() const { lldb_protocol::mcp::Error error; - error.error.code = m_error_code; - error.error.message = m_message; + error.code = m_error_code; + error.message = m_message; return error; } diff --git a/lldb/source/Protocol/MCP/Protocol.cpp b/lldb/source/Protocol/MCP/Protocol.cpp index d579b88..0988f45 100644 --- a/lldb/source/Protocol/MCP/Protocol.cpp +++ b/lldb/source/Protocol/MCP/Protocol.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/JSON.h" using namespace llvm; @@ -26,8 +27,45 @@ static bool mapRaw(const json::Value &Params, StringLiteral Prop, return true; } +static llvm::json::Value toJSON(const Id &Id) { + if (const int64_t *I = std::get_if<int64_t>(&Id)) + return json::Value(*I); + if (const std::string *S = std::get_if<std::string>(&Id)) + return json::Value(*S); + llvm_unreachable("unexpected type in protocol::Id"); +} + +static bool mapId(const llvm::json::Value &V, StringLiteral Prop, Id &Id, + llvm::json::Path P) { + const auto *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + + const auto *E = O->get(Prop); + if (!E) { + P.field(Prop).report("not found"); + return false; + } + + if (auto S = E->getAsString()) { + Id = S->str(); + return true; + } + + if (auto I = E->getAsInteger()) { + Id = *I; + return true; + } + + P.report("expected string or number"); + return false; +} + llvm::json::Value toJSON(const Request &R) { - json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}, {"method", R.method}}; + json::Object Result{ + {"jsonrpc", "2.0"}, {"id", toJSON(R.id)}, {"method", R.method}}; if (R.params) Result.insert({"params", R.params}); return Result; @@ -35,47 +73,75 @@ llvm::json::Value toJSON(const Request &R) { bool fromJSON(const llvm::json::Value &V, Request &R, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); - if (!O || !O.map("id", R.id) || !O.map("method", R.method)) - return false; - return mapRaw(V, "params", R.params, P); + return O && mapId(V, "id", R.id, P) && O.map("method", R.method) && + mapRaw(V, "params", R.params, P); } -llvm::json::Value toJSON(const ErrorInfo &EI) { - llvm::json::Object Result{{"code", EI.code}, {"message", EI.message}}; - if (!EI.data.empty()) - Result.insert({"data", EI.data}); - return Result; -} - -bool fromJSON(const llvm::json::Value &V, ErrorInfo &EI, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("code", EI.code) && O.map("message", EI.message) && - O.mapOptional("data", EI.data); +bool operator==(const Request &a, const Request &b) { + return a.id == b.id && a.method == b.method && a.params == b.params; } llvm::json::Value toJSON(const Error &E) { - return json::Object{{"jsonrpc", "2.0"}, {"id", E.id}, {"error", E.error}}; + llvm::json::Object Result{{"code", E.code}, {"message", E.message}}; + if (E.data) + Result.insert({"data", *E.data}); + return Result; } bool fromJSON(const llvm::json::Value &V, Error &E, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); - return O && O.map("id", E.id) && O.map("error", E.error); + return O && O.map("code", E.code) && O.map("message", E.message) && + mapRaw(V, "data", E.data, P); +} + +bool operator==(const Error &a, const Error &b) { + return a.code == b.code && a.message == b.message && a.data == b.data; } llvm::json::Value toJSON(const Response &R) { - llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}}; - if (R.result) - Result.insert({"result", R.result}); - if (R.error) - Result.insert({"error", R.error}); + llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", toJSON(R.id)}}; + + if (const Error *error = std::get_if<Error>(&R.result)) + Result.insert({"error", *error}); + if (const json::Value *result = std::get_if<json::Value>(&R.result)) + Result.insert({"result", *result}); return Result; } bool fromJSON(const llvm::json::Value &V, Response &R, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - if (!O || !O.map("id", R.id) || !O.map("error", R.error)) + const json::Object *E = V.getAsObject(); + if (!E) { + P.report("expected object"); + return false; + } + + const json::Value *result = E->get("result"); + const json::Value *raw_error = E->get("error"); + + if (result && raw_error) { + P.report("'result' and 'error' fields are mutually exclusive"); return false; - return mapRaw(V, "result", R.result, P); + } + + if (!result && !raw_error) { + P.report("'result' or 'error' fields are required'"); + return false; + } + + if (result) { + R.result = std::move(*result); + } else { + Error error; + if (!fromJSON(*raw_error, error, P)) + return false; + R.result = std::move(error); + } + + return mapId(V, "id", R.id, P); +} + +bool operator==(const Response &a, const Response &b) { + return a.id == b.id && a.result == b.result; } llvm::json::Value toJSON(const Notification &N) { @@ -97,30 +163,8 @@ bool fromJSON(const llvm::json::Value &V, Notification &N, llvm::json::Path P) { return true; } -llvm::json::Value toJSON(const ToolCapability &TC) { - return llvm::json::Object{{"listChanged", TC.listChanged}}; -} - -bool fromJSON(const llvm::json::Value &V, ToolCapability &TC, - llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("listChanged", TC.listChanged); -} - -llvm::json::Value toJSON(const ResourceCapability &RC) { - return llvm::json::Object{{"listChanged", RC.listChanged}, - {"subscribe", RC.subscribe}}; -} - -bool fromJSON(const llvm::json::Value &V, ResourceCapability &RC, - llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("listChanged", RC.listChanged) && - O.map("subscribe", RC.subscribe); -} - -llvm::json::Value toJSON(const Capabilities &C) { - return llvm::json::Object{{"tools", C.tools}, {"resources", C.resources}}; +bool operator==(const Notification &a, const Notification &b) { + return a.method == b.method && a.params == b.params; } bool fromJSON(const llvm::json::Value &V, Resource &R, llvm::json::Path P) { @@ -139,30 +183,25 @@ llvm::json::Value toJSON(const Resource &R) { return Result; } -bool fromJSON(const llvm::json::Value &V, Capabilities &C, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("tools", C.tools); -} - -llvm::json::Value toJSON(const ResourceContents &RC) { +llvm::json::Value toJSON(const TextResourceContents &RC) { llvm::json::Object Result{{"uri", RC.uri}, {"text", RC.text}}; if (!RC.mimeType.empty()) Result.insert({"mimeType", RC.mimeType}); return Result; } -bool fromJSON(const llvm::json::Value &V, ResourceContents &RC, +bool fromJSON(const llvm::json::Value &V, TextResourceContents &RC, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); return O && O.map("uri", RC.uri) && O.map("text", RC.text) && O.mapOptional("mimeType", RC.mimeType); } -llvm::json::Value toJSON(const ResourceResult &RR) { +llvm::json::Value toJSON(const ReadResourceResult &RR) { return llvm::json::Object{{"contents", RR.contents}}; } -bool fromJSON(const llvm::json::Value &V, ResourceResult &RR, +bool fromJSON(const llvm::json::Value &V, ReadResourceResult &RR, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); return O && O.map("contents", RR.contents); @@ -177,15 +216,6 @@ bool fromJSON(const llvm::json::Value &V, TextContent &TC, llvm::json::Path P) { return O && O.map("text", TC.text); } -llvm::json::Value toJSON(const TextResult &TR) { - return llvm::json::Object{{"content", TR.content}, {"isError", TR.isError}}; -} - -bool fromJSON(const llvm::json::Value &V, TextResult &TR, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("content", TR.content) && O.map("isError", TR.isError); -} - llvm::json::Value toJSON(const ToolDefinition &TD) { llvm::json::Object Result{{"name", TD.name}}; if (!TD.description.empty()) @@ -235,24 +265,16 @@ bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) { return true; } - if (O->get("error")) { - Error E; - if (!fromJSON(V, E, P)) - return false; - M = std::move(E); - return true; - } - - if (O->get("result")) { - Response R; + if (O->get("method")) { + Request R; if (!fromJSON(V, R, P)) return false; M = std::move(R); return true; } - if (O->get("method")) { - Request R; + if (O->get("result") || O->get("error")) { + Response R; if (!fromJSON(V, R, P)) return false; M = std::move(R); @@ -263,4 +285,159 @@ bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) { return false; } +json::Value toJSON(const Implementation &I) { + json::Object result{{"name", I.name}, {"version", I.version}}; + + if (!I.title.empty()) + result.insert({"title", I.title}); + + return result; +} + +bool fromJSON(const json::Value &V, Implementation &I, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("name", I.name) && O.mapOptional("title", I.title) && + O.mapOptional("version", I.version); +} + +json::Value toJSON(const ClientCapabilities &C) { return json::Object{}; } + +bool fromJSON(const json::Value &, ClientCapabilities &, json::Path) { + return true; +} + +json::Value toJSON(const ServerCapabilities &C) { + json::Object result{}; + + if (C.supportsToolsList) + result.insert({"tools", json::Object{{"listChanged", true}}}); + + if (C.supportsResourcesList || C.supportsResourcesSubscribe) { + json::Object resources; + if (C.supportsResourcesList) + resources.insert({"listChanged", true}); + if (C.supportsResourcesSubscribe) + resources.insert({"subscribe", true}); + result.insert({"resources", std::move(resources)}); + } + + if (C.supportsCompletions) + result.insert({"completions", json::Object{}}); + + if (C.supportsLogging) + result.insert({"logging", json::Object{}}); + + return result; +} + +bool fromJSON(const json::Value &V, ServerCapabilities &C, json::Path P) { + const json::Object *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + + if (O->find("tools") != O->end()) + C.supportsToolsList = true; + + return true; +} + +json::Value toJSON(const InitializeParams &P) { + return json::Object{ + {"protocolVersion", P.protocolVersion}, + {"capabilities", P.capabilities}, + {"clientInfo", P.clientInfo}, + }; +} + +bool fromJSON(const json::Value &V, InitializeParams &I, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("protocolVersion", I.protocolVersion) && + O.map("capabilities", I.capabilities) && + O.map("clientInfo", I.clientInfo); +} + +json::Value toJSON(const InitializeResult &R) { + json::Object result{{"protocolVersion", R.protocolVersion}, + {"capabilities", R.capabilities}, + {"serverInfo", R.serverInfo}}; + + if (!R.instructions.empty()) + result.insert({"instructions", R.instructions}); + + return result; +} + +bool fromJSON(const json::Value &V, InitializeResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("protocolVersion", R.protocolVersion) && + O.map("capabilities", R.capabilities) && + O.map("serverInfo", R.serverInfo) && + O.mapOptional("instructions", R.instructions); +} + +json::Value toJSON(const ListToolsResult &R) { + return json::Object{{"tools", R.tools}}; +} + +bool fromJSON(const json::Value &V, ListToolsResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("tools", R.tools); +} + +json::Value toJSON(const CallToolResult &R) { + json::Object result{{"content", R.content}}; + + if (R.isError) + result.insert({"isError", R.isError}); + if (R.structuredContent) + result.insert({"structuredContent", *R.structuredContent}); + + return result; +} + +bool fromJSON(const json::Value &V, CallToolResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("content", R.content) && + O.mapOptional("isError", R.isError) && + mapRaw(V, "structuredContent", R.structuredContent, P); +} + +json::Value toJSON(const CallToolParams &R) { + json::Object result{{"name", R.name}}; + + if (R.arguments) + result.insert({"arguments", *R.arguments}); + + return result; +} + +bool fromJSON(const json::Value &V, CallToolParams &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("name", R.name) && mapRaw(V, "arguments", R.arguments, P); +} + +json::Value toJSON(const ReadResourceParams &R) { + return json::Object{{"uri", R.uri}}; +} + +bool fromJSON(const json::Value &V, ReadResourceParams &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("uri", R.uri); +} + +json::Value toJSON(const ListResourcesResult &R) { + return json::Object{{"resources", R.resources}}; +} + +bool fromJSON(const json::Value &V, ListResourcesResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("resources", R.resources); +} + +json::Value toJSON(const Void &R) { return json::Object{}; } + +bool fromJSON(const json::Value &V, Void &R, json::Path P) { return true; } + } // namespace lldb_protocol::mcp diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index 4ec127fe..0381b7f 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -8,12 +8,29 @@ #include "lldb/Protocol/MCP/Server.h" #include "lldb/Protocol/MCP/MCPError.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/JSON.h" using namespace lldb_protocol::mcp; using namespace llvm; -Server::Server(std::string name, std::string version) - : m_name(std::move(name)), m_version(std::move(version)) { +llvm::json::Value lldb_protocol::mcp::toJSON(const ServerInfo &SM) { + return llvm::json::Object{{"connection_uri", SM.connection_uri}, + {"pid", SM.pid}}; +} + +bool lldb_protocol::mcp::fromJSON(const llvm::json::Value &V, ServerInfo &SM, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("connection_uri", SM.connection_uri) && + O.map("pid", SM.pid); +} + +Server::Server(std::string name, std::string version, + std::unique_ptr<MCPTransport> transport_up, + lldb_private::MainLoop &loop) + : m_name(std::move(name)), m_version(std::move(version)), + m_transport_up(std::move(transport_up)), m_loop(loop) { AddRequestHandlers(); } @@ -30,7 +47,7 @@ void Server::AddRequestHandlers() { this, std::placeholders::_1)); } -llvm::Expected<Response> Server::Handle(Request request) { +llvm::Expected<Response> Server::Handle(const Request &request) { auto it = m_request_handlers.find(request.method); if (it != m_request_handlers.end()) { llvm::Expected<Response> response = it->second(request); @@ -44,7 +61,7 @@ llvm::Expected<Response> Server::Handle(Request request) { llvm::formatv("no handler for request: {0}", request.method).str()); } -void Server::Handle(Notification notification) { +void Server::Handle(const Notification ¬ification) { auto it = m_notification_handlers.find(notification.method); if (it != m_notification_handlers.end()) { it->second(notification); @@ -52,50 +69,7 @@ void Server::Handle(Notification notification) { } } -llvm::Expected<std::optional<Message>> -Server::HandleData(llvm::StringRef data) { - auto message = llvm::json::parse<Message>(/*JSON=*/data); - if (!message) - return message.takeError(); - - if (const Request *request = std::get_if<Request>(&(*message))) { - llvm::Expected<Response> response = Handle(*request); - - // Handle failures by converting them into an Error message. - if (!response) { - Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtcolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.error.code = MCPError::kInternalError; - protocol_error.error.message = err.message(); - }); - protocol_error.id = request->id; - return protocol_error; - } - - return *response; - } - - if (const Notification *notification = - std::get_if<Notification>(&(*message))) { - Handle(*notification); - return std::nullopt; - } - - if (std::get_if<Error>(&(*message))) - return llvm::createStringError("unexpected MCP message: error"); - - if (std::get_if<Response>(&(*message))) - return llvm::createStringError("unexpected MCP message: response"); - - llvm_unreachable("all message types handled"); -} - void Server::AddTool(std::unique_ptr<Tool> tool) { - std::lock_guard<std::mutex> guard(m_mutex); - if (!tool) return; m_tools[tool->GetName()] = std::move(tool); @@ -103,42 +77,39 @@ void Server::AddTool(std::unique_ptr<Tool> tool) { void Server::AddResourceProvider( std::unique_ptr<ResourceProvider> resource_provider) { - std::lock_guard<std::mutex> guard(m_mutex); - if (!resource_provider) return; m_resource_providers.push_back(std::move(resource_provider)); } void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { - std::lock_guard<std::mutex> guard(m_mutex); m_request_handlers[method] = std::move(handler); } void Server::AddNotificationHandler(llvm::StringRef method, NotificationHandler handler) { - std::lock_guard<std::mutex> guard(m_mutex); m_notification_handlers[method] = std::move(handler); } llvm::Expected<Response> Server::InitializeHandler(const Request &request) { Response response; - response.result.emplace(llvm::json::Object{ - {"protocolVersion", mcp::kProtocolVersion}, - {"capabilities", GetCapabilities()}, - {"serverInfo", - llvm::json::Object{{"name", m_name}, {"version", m_version}}}}); + InitializeResult result; + result.protocolVersion = mcp::kProtocolVersion; + result.capabilities = GetCapabilities(); + result.serverInfo.name = m_name; + result.serverInfo.version = m_version; + response.result = std::move(result); return response; } llvm::Expected<Response> Server::ToolsListHandler(const Request &request) { Response response; - llvm::json::Array tools; + ListToolsResult result; for (const auto &tool : m_tools) - tools.emplace_back(toJSON(tool.second->GetDefinition())); + result.tools.emplace_back(tool.second->GetDefinition()); - response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}}); + response.result = std::move(result); return response; } @@ -148,16 +119,12 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { if (!request.params) return llvm::createStringError("no tool parameters"); + CallToolParams params; + json::Path::Root root("params"); + if (!fromJSON(request.params, params, root)) + return root.getError(); - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no tool parameters"); - - const json::Value *name = param_obj->get("name"); - if (!name) - return llvm::createStringError("no tool name"); - - llvm::StringRef tool_name = name->getAsString().value_or(""); + llvm::StringRef tool_name = params.name; if (tool_name.empty()) return llvm::createStringError("no tool name"); @@ -166,14 +133,14 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); ToolArguments tool_args; - if (const json::Value *args = param_obj->get("arguments")) - tool_args = *args; + if (params.arguments) + tool_args = *params.arguments; - llvm::Expected<TextResult> text_result = it->second->Call(tool_args); + llvm::Expected<CallToolResult> text_result = it->second->Call(tool_args); if (!text_result) return text_result.takeError(); - response.result.emplace(toJSON(*text_result)); + response.result = toJSON(*text_result); return response; } @@ -181,16 +148,13 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) { Response response; - llvm::json::Array resources; - - std::lock_guard<std::mutex> guard(m_mutex); + ListResourcesResult result; for (std::unique_ptr<ResourceProvider> &resource_provider_up : - m_resource_providers) { + m_resource_providers) for (const Resource &resource : resource_provider_up->GetResources()) - resources.push_back(resource); - } - response.result.emplace( - llvm::json::Object{{"resources", std::move(resources)}}); + result.resources.push_back(resource); + + response.result = std::move(result); return response; } @@ -201,22 +165,18 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { if (!request.params) return llvm::createStringError("no resource parameters"); - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no resource parameters"); - - const json::Value *uri = param_obj->get("uri"); - if (!uri) - return llvm::createStringError("no resource uri"); + ReadResourceParams params; + json::Path::Root root("params"); + if (!fromJSON(request.params, params, root)) + return root.getError(); - llvm::StringRef uri_str = uri->getAsString().value_or(""); + llvm::StringRef uri_str = params.uri; if (uri_str.empty()) return llvm::createStringError("no resource uri"); - std::lock_guard<std::mutex> guard(m_mutex); for (std::unique_ptr<ResourceProvider> &resource_provider_up : m_resource_providers) { - llvm::Expected<ResourceResult> result = + llvm::Expected<ReadResourceResult> result = resource_provider_up->ReadResource(uri_str); if (result.errorIsA<UnsupportedURI>()) { llvm::consumeError(result.takeError()); @@ -226,7 +186,7 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { return result.takeError(); Response response; - response.result.emplace(std::move(*result)); + response.result = std::move(*result); return response; } @@ -234,3 +194,71 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { llvm::formatv("no resource handler for uri: {0}", uri_str).str(), MCPError::kResourceNotFound); } + +ServerCapabilities Server::GetCapabilities() { + lldb_protocol::mcp::ServerCapabilities capabilities; + capabilities.supportsToolsList = true; + // FIXME: Support sending notifications when a debugger/target are + // added/removed. + capabilities.supportsResourcesList = false; + return capabilities; +} + +llvm::Error Server::Run() { + auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this); + if (!handle) + return handle.takeError(); + + lldb_private::Status status = m_loop.Run(); + if (status.Fail()) + return status.takeError(); + + return llvm::Error::success(); +} + +void Server::Received(const Request &request) { + auto SendResponse = [this](const Response &response) { + if (llvm::Error error = m_transport_up->Send(response)) + m_transport_up->Log(llvm::toString(std::move(error))); + }; + + llvm::Expected<Response> response = Handle(request); + if (response) + return SendResponse(*response); + + lldb_protocol::mcp::Error protocol_error; + llvm::handleAllErrors( + response.takeError(), + [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.code = MCPError::kInternalError; + protocol_error.message = err.message(); + }); + Response error_response; + error_response.id = request.id; + error_response.result = std::move(protocol_error); + SendResponse(error_response); +} + +void Server::Received(const Response &response) { + m_transport_up->Log("unexpected MCP message: response"); +} + +void Server::Received(const Notification ¬ification) { + Handle(notification); +} + +void Server::OnError(llvm::Error error) { + m_transport_up->Log(llvm::toString(std::move(error))); + TerminateLoop(); +} + +void Server::OnClosed() { + m_transport_up->Log("EOF"); + TerminateLoop(); +} + +void Server::TerminateLoop() { + m_loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); +} |