aboutsummaryrefslogtreecommitdiff
path: root/lldb/source/Protocol/MCP
diff options
context:
space:
mode:
Diffstat (limited to 'lldb/source/Protocol/MCP')
-rw-r--r--lldb/source/Protocol/MCP/MCPError.cpp9
-rw-r--r--lldb/source/Protocol/MCP/Server.cpp209
2 files changed, 63 insertions, 155 deletions
diff --git a/lldb/source/Protocol/MCP/MCPError.cpp b/lldb/source/Protocol/MCP/MCPError.cpp
index e140d11..cfac055 100644
--- a/lldb/source/Protocol/MCP/MCPError.cpp
+++ b/lldb/source/Protocol/MCP/MCPError.cpp
@@ -22,14 +22,7 @@ MCPError::MCPError(std::string message, int64_t error_code)
void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; }
std::error_code MCPError::convertToErrorCode() const {
- return llvm::inconvertibleErrorCode();
-}
-
-lldb_protocol::mcp::Error MCPError::toProtocolError() const {
- lldb_protocol::mcp::Error error;
- error.code = m_error_code;
- error.message = m_message;
- return error;
+ return std::error_code(m_error_code, std::generic_category());
}
UnsupportedURI::UnsupportedURI(std::string uri) : m_uri(uri) {}
diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp
index 19030a3..71323ad 100644
--- a/lldb/source/Protocol/MCP/Server.cpp
+++ b/lldb/source/Protocol/MCP/Server.cpp
@@ -12,6 +12,7 @@
#include "lldb/Host/HostInfo.h"
#include "lldb/Protocol/MCP/MCPError.h"
#include "lldb/Protocol/MCP/Protocol.h"
+#include "lldb/Protocol/MCP/Transport.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/JSON.h"
@@ -108,48 +109,9 @@ Expected<std::vector<ServerInfo>> ServerInfo::Load() {
return infos;
}
-Server::Server(std::string name, std::string version, MCPTransport &client,
- LogCallback log_callback, ClosedCallback closed_callback)
- : m_name(std::move(name)), m_version(std::move(version)), m_client(client),
- m_log_callback(std::move(log_callback)),
- m_closed_callback(std::move(closed_callback)) {
- AddRequestHandlers();
-}
-
-void Server::AddRequestHandlers() {
- AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this,
- std::placeholders::_1));
- AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this,
- std::placeholders::_1));
- AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this,
- std::placeholders::_1));
- AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler,
- this, std::placeholders::_1));
- AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler,
- this, std::placeholders::_1));
-}
-
-llvm::Expected<Response> Server::Handle(const Request &request) {
- auto it = m_request_handlers.find(request.method);
- if (it != m_request_handlers.end()) {
- llvm::Expected<Response> response = it->second(request);
- if (!response)
- return response;
- response->id = request.id;
- return *response;
- }
-
- return llvm::make_error<MCPError>(
- llvm::formatv("no handler for request: {0}", request.method).str());
-}
-
-void Server::Handle(const Notification &notification) {
- auto it = m_notification_handlers.find(notification.method);
- if (it != m_notification_handlers.end()) {
- it->second(notification);
- return;
- }
-}
+Server::Server(std::string name, std::string version, LogCallback log_callback)
+ : m_name(std::move(name)), m_version(std::move(version)),
+ m_log_callback(std::move(log_callback)) {}
void Server::AddTool(std::unique_ptr<Tool> tool) {
if (!tool)
@@ -164,48 +126,64 @@ void Server::AddResourceProvider(
m_resource_providers.push_back(std::move(resource_provider));
}
-void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) {
- m_request_handlers[method] = std::move(handler);
-}
-
-void Server::AddNotificationHandler(llvm::StringRef method,
- NotificationHandler handler) {
- m_notification_handlers[method] = std::move(handler);
-}
-
-llvm::Expected<Response> Server::InitializeHandler(const Request &request) {
- Response response;
+MCPBinderUP Server::Bind(MCPTransport &transport) {
+ MCPBinderUP binder_up = std::make_unique<MCPBinder>(transport);
+ binder_up->Bind<InitializeResult, InitializeParams>(
+ "initialize", &Server::InitializeHandler, this);
+ binder_up->Bind<ListToolsResult, void>("tools/list",
+ &Server::ToolsListHandler, this);
+ binder_up->Bind<CallToolResult, CallToolParams>(
+ "tools/call", &Server::ToolsCallHandler, this);
+ binder_up->Bind<ListResourcesResult, void>(
+ "resources/list", &Server::ResourcesListHandler, this);
+ binder_up->Bind<ReadResourceResult, ReadResourceParams>(
+ "resources/read", &Server::ResourcesReadHandler, this);
+ binder_up->Bind<void>("notifications/initialized",
+ [this]() { Log("MCP initialization complete"); });
+ return binder_up;
+}
+
+llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) {
+ MCPBinderUP binder = Bind(*transport);
+ MCPTransport *transport_ptr = transport.get();
+ binder->OnDisconnect([this, transport_ptr]() {
+ assert(m_instances.find(transport_ptr) != m_instances.end() &&
+ "Client not found in m_instances");
+ m_instances.erase(transport_ptr);
+ });
+ binder->OnError([this](llvm::Error err) {
+ Logv("Transport error: {0}", llvm::toString(std::move(err)));
+ });
+
+ auto handle = transport->RegisterMessageHandler(loop, *binder);
+ if (!handle)
+ return handle.takeError();
+
+ m_instances[transport_ptr] =
+ Client{std::move(*handle), std::move(transport), std::move(binder)};
+ return llvm::Error::success();
+}
+
+Expected<InitializeResult>
+Server::InitializeHandler(const InitializeParams &request) {
InitializeResult result;
result.protocolVersion = mcp::kProtocolVersion;
result.capabilities = GetCapabilities();
result.serverInfo.name = m_name;
result.serverInfo.version = m_version;
- response.result = std::move(result);
- return response;
+ return result;
}
-llvm::Expected<Response> Server::ToolsListHandler(const Request &request) {
- Response response;
-
+llvm::Expected<ListToolsResult> Server::ToolsListHandler() {
ListToolsResult result;
for (const auto &tool : m_tools)
result.tools.emplace_back(tool.second->GetDefinition());
- response.result = std::move(result);
-
- return response;
+ return result;
}
-llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
- Response response;
-
- if (!request.params)
- return llvm::createStringError("no tool parameters");
- CallToolParams params;
- json::Path::Root root("params");
- if (!fromJSON(request.params, params, root))
- return root.getError();
-
+llvm::Expected<CallToolResult>
+Server::ToolsCallHandler(const CallToolParams &params) {
llvm::StringRef tool_name = params.name;
if (tool_name.empty())
return llvm::createStringError("no tool name");
@@ -222,113 +200,50 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
if (!text_result)
return text_result.takeError();
- response.result = toJSON(*text_result);
-
- return response;
+ return text_result;
}
-llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) {
- Response response;
-
+llvm::Expected<ListResourcesResult> Server::ResourcesListHandler() {
ListResourcesResult result;
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
m_resource_providers)
for (const Resource &resource : resource_provider_up->GetResources())
result.resources.push_back(resource);
- response.result = std::move(result);
-
- return response;
+ return result;
}
-llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
- Response response;
-
- if (!request.params)
- return llvm::createStringError("no resource parameters");
-
- ReadResourceParams params;
- json::Path::Root root("params");
- if (!fromJSON(request.params, params, root))
- return root.getError();
-
- llvm::StringRef uri_str = params.uri;
+Expected<ReadResourceResult>
+Server::ResourcesReadHandler(const ReadResourceParams &params) {
+ StringRef uri_str = params.uri;
if (uri_str.empty())
- return llvm::createStringError("no resource uri");
+ return createStringError("no resource uri");
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
m_resource_providers) {
- llvm::Expected<ReadResourceResult> result =
+ Expected<ReadResourceResult> result =
resource_provider_up->ReadResource(uri_str);
if (result.errorIsA<UnsupportedURI>()) {
- llvm::consumeError(result.takeError());
+ consumeError(result.takeError());
continue;
}
if (!result)
return result.takeError();
- Response response;
- response.result = std::move(*result);
- return response;
+ return *result;
}
return make_error<MCPError>(
- llvm::formatv("no resource handler for uri: {0}", uri_str).str(),
+ formatv("no resource handler for uri: {0}", uri_str).str(),
MCPError::kResourceNotFound);
}
ServerCapabilities Server::GetCapabilities() {
lldb_protocol::mcp::ServerCapabilities capabilities;
capabilities.supportsToolsList = true;
+ capabilities.supportsResourcesList = true;
// FIXME: Support sending notifications when a debugger/target are
// added/removed.
- capabilities.supportsResourcesList = false;
+ capabilities.supportsResourcesSubscribe = false;
return capabilities;
}
-
-void Server::Log(llvm::StringRef message) {
- if (m_log_callback)
- m_log_callback(message);
-}
-
-void Server::Received(const Request &request) {
- auto SendResponse = [this](const Response &response) {
- if (llvm::Error error = m_client.Send(response))
- Log(llvm::toString(std::move(error)));
- };
-
- llvm::Expected<Response> response = Handle(request);
- if (response)
- return SendResponse(*response);
-
- lldb_protocol::mcp::Error protocol_error;
- llvm::handleAllErrors(
- response.takeError(),
- [&](const MCPError &err) { protocol_error = err.toProtocolError(); },
- [&](const llvm::ErrorInfoBase &err) {
- protocol_error.code = MCPError::kInternalError;
- protocol_error.message = err.message();
- });
- Response error_response;
- error_response.id = request.id;
- error_response.result = std::move(protocol_error);
- SendResponse(error_response);
-}
-
-void Server::Received(const Response &response) {
- Log("unexpected MCP message: response");
-}
-
-void Server::Received(const Notification &notification) {
- Handle(notification);
-}
-
-void Server::OnError(llvm::Error error) {
- Log(llvm::toString(std::move(error)));
-}
-
-void Server::OnClosed() {
- Log("EOF");
- if (m_closed_callback)
- m_closed_callback();
-}