diff options
Diffstat (limited to 'lldb/source/Protocol')
-rw-r--r-- | lldb/source/Protocol/MCP/MCPError.cpp | 9 | ||||
-rw-r--r-- | lldb/source/Protocol/MCP/Server.cpp | 209 |
2 files changed, 63 insertions, 155 deletions
diff --git a/lldb/source/Protocol/MCP/MCPError.cpp b/lldb/source/Protocol/MCP/MCPError.cpp index e140d11..cfac055 100644 --- a/lldb/source/Protocol/MCP/MCPError.cpp +++ b/lldb/source/Protocol/MCP/MCPError.cpp @@ -22,14 +22,7 @@ MCPError::MCPError(std::string message, int64_t error_code) void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; } std::error_code MCPError::convertToErrorCode() const { - return llvm::inconvertibleErrorCode(); -} - -lldb_protocol::mcp::Error MCPError::toProtocolError() const { - lldb_protocol::mcp::Error error; - error.code = m_error_code; - error.message = m_message; - return error; + return std::error_code(m_error_code, std::generic_category()); } UnsupportedURI::UnsupportedURI(std::string uri) : m_uri(uri) {} diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index 19030a3..71323ad 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -12,6 +12,7 @@ #include "lldb/Host/HostInfo.h" #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Transport.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" @@ -108,48 +109,9 @@ Expected<std::vector<ServerInfo>> ServerInfo::Load() { return infos; } -Server::Server(std::string name, std::string version, MCPTransport &client, - LogCallback log_callback, ClosedCallback closed_callback) - : m_name(std::move(name)), m_version(std::move(version)), m_client(client), - m_log_callback(std::move(log_callback)), - m_closed_callback(std::move(closed_callback)) { - AddRequestHandlers(); -} - -void Server::AddRequestHandlers() { - AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this, - std::placeholders::_1)); - AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler, - this, std::placeholders::_1)); - AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler, - this, std::placeholders::_1)); -} - -llvm::Expected<Response> Server::Handle(const Request &request) { - auto it = m_request_handlers.find(request.method); - if (it != m_request_handlers.end()) { - llvm::Expected<Response> response = it->second(request); - if (!response) - return response; - response->id = request.id; - return *response; - } - - return llvm::make_error<MCPError>( - llvm::formatv("no handler for request: {0}", request.method).str()); -} - -void Server::Handle(const Notification ¬ification) { - auto it = m_notification_handlers.find(notification.method); - if (it != m_notification_handlers.end()) { - it->second(notification); - return; - } -} +Server::Server(std::string name, std::string version, LogCallback log_callback) + : m_name(std::move(name)), m_version(std::move(version)), + m_log_callback(std::move(log_callback)) {} void Server::AddTool(std::unique_ptr<Tool> tool) { if (!tool) @@ -164,48 +126,64 @@ void Server::AddResourceProvider( m_resource_providers.push_back(std::move(resource_provider)); } -void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { - m_request_handlers[method] = std::move(handler); -} - -void Server::AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler) { - m_notification_handlers[method] = std::move(handler); -} - -llvm::Expected<Response> Server::InitializeHandler(const Request &request) { - Response response; +MCPBinderUP Server::Bind(MCPTransport &transport) { + MCPBinderUP binder_up = std::make_unique<MCPBinder>(transport); + binder_up->Bind<InitializeResult, InitializeParams>( + "initialize", &Server::InitializeHandler, this); + binder_up->Bind<ListToolsResult, void>("tools/list", + &Server::ToolsListHandler, this); + binder_up->Bind<CallToolResult, CallToolParams>( + "tools/call", &Server::ToolsCallHandler, this); + binder_up->Bind<ListResourcesResult, void>( + "resources/list", &Server::ResourcesListHandler, this); + binder_up->Bind<ReadResourceResult, ReadResourceParams>( + "resources/read", &Server::ResourcesReadHandler, this); + binder_up->Bind<void>("notifications/initialized", + [this]() { Log("MCP initialization complete"); }); + return binder_up; +} + +llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) { + MCPBinderUP binder = Bind(*transport); + MCPTransport *transport_ptr = transport.get(); + binder->OnDisconnect([this, transport_ptr]() { + assert(m_instances.find(transport_ptr) != m_instances.end() && + "Client not found in m_instances"); + m_instances.erase(transport_ptr); + }); + binder->OnError([this](llvm::Error err) { + Logv("Transport error: {0}", llvm::toString(std::move(err))); + }); + + auto handle = transport->RegisterMessageHandler(loop, *binder); + if (!handle) + return handle.takeError(); + + m_instances[transport_ptr] = + Client{std::move(*handle), std::move(transport), std::move(binder)}; + return llvm::Error::success(); +} + +Expected<InitializeResult> +Server::InitializeHandler(const InitializeParams &request) { InitializeResult result; result.protocolVersion = mcp::kProtocolVersion; result.capabilities = GetCapabilities(); result.serverInfo.name = m_name; result.serverInfo.version = m_version; - response.result = std::move(result); - return response; + return result; } -llvm::Expected<Response> Server::ToolsListHandler(const Request &request) { - Response response; - +llvm::Expected<ListToolsResult> Server::ToolsListHandler() { ListToolsResult result; for (const auto &tool : m_tools) result.tools.emplace_back(tool.second->GetDefinition()); - response.result = std::move(result); - - return response; + return result; } -llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no tool parameters"); - CallToolParams params; - json::Path::Root root("params"); - if (!fromJSON(request.params, params, root)) - return root.getError(); - +llvm::Expected<CallToolResult> +Server::ToolsCallHandler(const CallToolParams ¶ms) { llvm::StringRef tool_name = params.name; if (tool_name.empty()) return llvm::createStringError("no tool name"); @@ -222,113 +200,50 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { if (!text_result) return text_result.takeError(); - response.result = toJSON(*text_result); - - return response; + return text_result; } -llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) { - Response response; - +llvm::Expected<ListResourcesResult> Server::ResourcesListHandler() { ListResourcesResult result; for (std::unique_ptr<ResourceProvider> &resource_provider_up : m_resource_providers) for (const Resource &resource : resource_provider_up->GetResources()) result.resources.push_back(resource); - response.result = std::move(result); - - return response; + return result; } -llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no resource parameters"); - - ReadResourceParams params; - json::Path::Root root("params"); - if (!fromJSON(request.params, params, root)) - return root.getError(); - - llvm::StringRef uri_str = params.uri; +Expected<ReadResourceResult> +Server::ResourcesReadHandler(const ReadResourceParams ¶ms) { + StringRef uri_str = params.uri; if (uri_str.empty()) - return llvm::createStringError("no resource uri"); + return createStringError("no resource uri"); for (std::unique_ptr<ResourceProvider> &resource_provider_up : m_resource_providers) { - llvm::Expected<ReadResourceResult> result = + Expected<ReadResourceResult> result = resource_provider_up->ReadResource(uri_str); if (result.errorIsA<UnsupportedURI>()) { - llvm::consumeError(result.takeError()); + consumeError(result.takeError()); continue; } if (!result) return result.takeError(); - Response response; - response.result = std::move(*result); - return response; + return *result; } return make_error<MCPError>( - llvm::formatv("no resource handler for uri: {0}", uri_str).str(), + formatv("no resource handler for uri: {0}", uri_str).str(), MCPError::kResourceNotFound); } ServerCapabilities Server::GetCapabilities() { lldb_protocol::mcp::ServerCapabilities capabilities; capabilities.supportsToolsList = true; + capabilities.supportsResourcesList = true; // FIXME: Support sending notifications when a debugger/target are // added/removed. - capabilities.supportsResourcesList = false; + capabilities.supportsResourcesSubscribe = false; return capabilities; } - -void Server::Log(llvm::StringRef message) { - if (m_log_callback) - m_log_callback(message); -} - -void Server::Received(const Request &request) { - auto SendResponse = [this](const Response &response) { - if (llvm::Error error = m_client.Send(response)) - Log(llvm::toString(std::move(error))); - }; - - llvm::Expected<Response> response = Handle(request); - if (response) - return SendResponse(*response); - - lldb_protocol::mcp::Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.code = MCPError::kInternalError; - protocol_error.message = err.message(); - }); - Response error_response; - error_response.id = request.id; - error_response.result = std::move(protocol_error); - SendResponse(error_response); -} - -void Server::Received(const Response &response) { - Log("unexpected MCP message: response"); -} - -void Server::Received(const Notification ¬ification) { - Handle(notification); -} - -void Server::OnError(llvm::Error error) { - Log(llvm::toString(std::move(error))); -} - -void Server::OnClosed() { - Log("EOF"); - if (m_closed_callback) - m_closed_callback(); -} |