aboutsummaryrefslogtreecommitdiff
path: root/lldb/source
diff options
context:
space:
mode:
Diffstat (limited to 'lldb/source')
-rw-r--r--lldb/source/Breakpoint/BreakpointResolverName.cpp2
-rw-r--r--lldb/source/Commands/CommandObjectType.cpp4
-rw-r--r--lldb/source/Core/Mangled.cpp4
-rw-r--r--lldb/source/Host/common/JSONTransport.cpp26
-rw-r--r--lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp52
-rw-r--r--lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h20
-rw-r--r--lldb/source/Protocol/MCP/MCPError.cpp9
-rw-r--r--lldb/source/Protocol/MCP/Server.cpp209
-rw-r--r--lldb/source/Symbol/Symtab.cpp2
-rw-r--r--lldb/source/Target/Language.cpp15
10 files changed, 129 insertions, 214 deletions
diff --git a/lldb/source/Breakpoint/BreakpointResolverName.cpp b/lldb/source/Breakpoint/BreakpointResolverName.cpp
index 6372595..4f252f9 100644
--- a/lldb/source/Breakpoint/BreakpointResolverName.cpp
+++ b/lldb/source/Breakpoint/BreakpointResolverName.cpp
@@ -233,7 +233,7 @@ void BreakpointResolverName::AddNameLookup(ConstString name,
m_lookups.emplace_back(variant_lookup);
}
}
- return true;
+ return IterationAction::Continue;
};
if (Language *lang = Language::FindPlugin(m_language)) {
diff --git a/lldb/source/Commands/CommandObjectType.cpp b/lldb/source/Commands/CommandObjectType.cpp
index 19cd3ff..22ed5b8 100644
--- a/lldb/source/Commands/CommandObjectType.cpp
+++ b/lldb/source/Commands/CommandObjectType.cpp
@@ -2610,7 +2610,7 @@ public:
Language::ForEach([&](Language *lang) {
if (const char *help = lang->GetLanguageSpecificTypeLookupHelp())
stream.Printf("%s\n", help);
- return true;
+ return IterationAction::Continue;
});
m_cmd_help_long = std::string(stream.GetString());
@@ -2649,7 +2649,7 @@ public:
(m_command_options.m_language == eLanguageTypeUnknown))) {
Language::ForEach([&](Language *lang) {
languages.push_back(lang);
- return true;
+ return IterationAction::Continue;
});
} else {
languages.push_back(Language::FindPlugin(m_command_options.m_language));
diff --git a/lldb/source/Core/Mangled.cpp b/lldb/source/Core/Mangled.cpp
index 0780846..f7683c5 100644
--- a/lldb/source/Core/Mangled.cpp
+++ b/lldb/source/Core/Mangled.cpp
@@ -428,9 +428,9 @@ lldb::LanguageType Mangled::GuessLanguage() const {
Language::ForEach([this, &result](Language *l) {
if (l->SymbolNameFitsToLanguage(*this)) {
result = l->GetLanguageType();
- return false;
+ return IterationAction::Stop;
}
- return true;
+ return IterationAction::Continue;
});
return result;
}
diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp
index c4b42ea..22de7fa 100644
--- a/lldb/source/Host/common/JSONTransport.cpp
+++ b/lldb/source/Host/common/JSONTransport.cpp
@@ -14,8 +14,7 @@
#include <string>
using namespace llvm;
-using namespace lldb;
-using namespace lldb_private;
+using namespace lldb_private::transport;
char TransportUnhandledContentsError::ID;
@@ -23,10 +22,31 @@ TransportUnhandledContentsError::TransportUnhandledContentsError(
std::string unhandled_contents)
: m_unhandled_contents(unhandled_contents) {}
-void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const {
+void TransportUnhandledContentsError::log(raw_ostream &OS) const {
OS << "transport EOF with unhandled contents: '" << m_unhandled_contents
<< "'";
}
std::error_code TransportUnhandledContentsError::convertToErrorCode() const {
return std::make_error_code(std::errc::bad_message);
}
+
+char InvalidParams::ID;
+
+void InvalidParams::log(raw_ostream &OS) const {
+ OS << "invalid parameters for method '" << m_method << "': '" << m_context
+ << "'";
+}
+std::error_code InvalidParams::convertToErrorCode() const {
+ return std::make_error_code(std::errc::invalid_argument);
+}
+
+char MethodNotFound::ID;
+
+void MethodNotFound::log(raw_ostream &OS) const {
+ OS << "method not found: '" << m_method << "'";
+}
+
+std::error_code MethodNotFound::convertToErrorCode() const {
+ // JSON-RPC Method not found
+ return std::error_code(MethodNotFound::kErrorCode, std::generic_category());
+}
diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
index d7293fc..33bdd5e 100644
--- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
+++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
@@ -52,11 +52,6 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
}
void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const {
- server.AddNotificationHandler("notifications/initialized",
- [](const lldb_protocol::mcp::Notification &) {
- LLDB_LOG(GetLog(LLDBLog::Host),
- "MCP initialization complete");
- });
server.AddTool(
std::make_unique<CommandTool>("command", "Run an lldb command."));
server.AddTool(std::make_unique<DebuggerListTool>(
@@ -74,26 +69,9 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) {
io_sp, io_sp, [client_name](llvm::StringRef message) {
LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message);
});
- MCPTransport *transport_ptr = transport_up.get();
- auto instance_up = std::make_unique<lldb_protocol::mcp::Server>(
- std::string(kName), std::string(kVersion), *transport_up,
- /*log_callback=*/
- [client_name](llvm::StringRef message) {
- LLDB_LOG(GetLog(LLDBLog::Host), "{0} Server: {1}", client_name,
- message);
- },
- /*closed_callback=*/
- [this, transport_ptr]() { m_instances.erase(transport_ptr); });
- Extend(*instance_up);
- llvm::Expected<MainLoop::ReadHandleUP> handle =
- transport_up->RegisterMessageHandler(m_loop, *instance_up);
- if (!handle) {
- LLDB_LOG_ERROR(log, handle.takeError(), "Failed to run MCP server: {0}");
- return;
- }
- m_instances[transport_ptr] =
- std::make_tuple<ServerUP, ReadHandleUP, TransportUP>(
- std::move(instance_up), std::move(*handle), std::move(transport_up));
+
+ if (auto error = m_server->Accept(m_loop, std::move(transport_up)))
+ LLDB_LOG_ERROR(log, std::move(error), "{0}:");
}
llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
@@ -124,14 +102,21 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
llvm::join(m_listener->GetListeningConnectionURI(), ", ");
ServerInfo info{listening_uris[0]};
- llvm::Expected<ServerInfoHandle> handle = ServerInfo::Write(info);
- if (!handle)
- return handle.takeError();
+ llvm::Expected<ServerInfoHandle> server_info_handle = ServerInfo::Write(info);
+ if (!server_info_handle)
+ return server_info_handle.takeError();
+
+ m_client_count = 0;
+ m_server = std::make_unique<lldb_protocol::mcp::Server>(
+ std::string(kName), std::string(kVersion), [](StringRef message) {
+ LLDB_LOG(GetLog(LLDBLog::Host), "MCP Server: {0}", message);
+ });
+ Extend(*m_server);
m_running = true;
- m_server_info_handle = std::move(*handle);
- m_listen_handlers = std::move(*handles);
- m_loop_thread = std::thread([=] {
+ m_server_info_handle = std::move(*server_info_handle);
+ m_accept_handles = std::move(*handles);
+ m_loop_thread = std::thread([this] {
llvm::set_thread_name("protocol-server.mcp");
m_loop.Run();
});
@@ -155,9 +140,10 @@ llvm::Error ProtocolServerMCP::Stop() {
if (m_loop_thread.joinable())
m_loop_thread.join();
+ m_accept_handles.clear();
+
+ m_server.reset(nullptr);
m_server_info_handle.Remove();
- m_listen_handlers.clear();
- m_instances.clear();
return llvm::Error::success();
}
diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
index b325a36..e0f2a6c 100644
--- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
+++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
@@ -23,16 +23,17 @@
namespace lldb_private::mcp {
class ProtocolServerMCP : public ProtocolServer {
- using ReadHandleUP = MainLoopBase::ReadHandleUP;
- using TransportUP = std::unique_ptr<lldb_protocol::mcp::MCPTransport>;
+
using ServerUP = std::unique_ptr<lldb_protocol::mcp::Server>;
+ using ReadHandleUP = MainLoop::ReadHandleUP;
+
public:
ProtocolServerMCP();
- virtual ~ProtocolServerMCP() override;
+ ~ProtocolServerMCP() override;
- virtual llvm::Error Start(ProtocolServer::Connection connection) override;
- virtual llvm::Error Stop() override;
+ llvm::Error Start(ProtocolServer::Connection connection) override;
+ llvm::Error Stop() override;
static void Initialize();
static void Terminate();
@@ -56,19 +57,18 @@ private:
bool m_running = false;
- lldb_protocol::mcp::ServerInfoHandle m_server_info_handle;
lldb_private::MainLoop m_loop;
std::thread m_loop_thread;
std::mutex m_mutex;
size_t m_client_count = 0;
std::unique_ptr<Socket> m_listener;
+ std::vector<ReadHandleUP> m_accept_handles;
- std::vector<ReadHandleUP> m_listen_handlers;
- std::map<lldb_protocol::mcp::MCPTransport *,
- std::tuple<ServerUP, ReadHandleUP, TransportUP>>
- m_instances;
+ ServerUP m_server;
+ lldb_protocol::mcp::ServerInfoHandle m_server_info_handle;
};
+
} // namespace lldb_private::mcp
#endif
diff --git a/lldb/source/Protocol/MCP/MCPError.cpp b/lldb/source/Protocol/MCP/MCPError.cpp
index e140d11..cfac055 100644
--- a/lldb/source/Protocol/MCP/MCPError.cpp
+++ b/lldb/source/Protocol/MCP/MCPError.cpp
@@ -22,14 +22,7 @@ MCPError::MCPError(std::string message, int64_t error_code)
void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; }
std::error_code MCPError::convertToErrorCode() const {
- return llvm::inconvertibleErrorCode();
-}
-
-lldb_protocol::mcp::Error MCPError::toProtocolError() const {
- lldb_protocol::mcp::Error error;
- error.code = m_error_code;
- error.message = m_message;
- return error;
+ return std::error_code(m_error_code, std::generic_category());
}
UnsupportedURI::UnsupportedURI(std::string uri) : m_uri(uri) {}
diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp
index 19030a3..71323ad 100644
--- a/lldb/source/Protocol/MCP/Server.cpp
+++ b/lldb/source/Protocol/MCP/Server.cpp
@@ -12,6 +12,7 @@
#include "lldb/Host/HostInfo.h"
#include "lldb/Protocol/MCP/MCPError.h"
#include "lldb/Protocol/MCP/Protocol.h"
+#include "lldb/Protocol/MCP/Transport.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/JSON.h"
@@ -108,48 +109,9 @@ Expected<std::vector<ServerInfo>> ServerInfo::Load() {
return infos;
}
-Server::Server(std::string name, std::string version, MCPTransport &client,
- LogCallback log_callback, ClosedCallback closed_callback)
- : m_name(std::move(name)), m_version(std::move(version)), m_client(client),
- m_log_callback(std::move(log_callback)),
- m_closed_callback(std::move(closed_callback)) {
- AddRequestHandlers();
-}
-
-void Server::AddRequestHandlers() {
- AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this,
- std::placeholders::_1));
- AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this,
- std::placeholders::_1));
- AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this,
- std::placeholders::_1));
- AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler,
- this, std::placeholders::_1));
- AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler,
- this, std::placeholders::_1));
-}
-
-llvm::Expected<Response> Server::Handle(const Request &request) {
- auto it = m_request_handlers.find(request.method);
- if (it != m_request_handlers.end()) {
- llvm::Expected<Response> response = it->second(request);
- if (!response)
- return response;
- response->id = request.id;
- return *response;
- }
-
- return llvm::make_error<MCPError>(
- llvm::formatv("no handler for request: {0}", request.method).str());
-}
-
-void Server::Handle(const Notification &notification) {
- auto it = m_notification_handlers.find(notification.method);
- if (it != m_notification_handlers.end()) {
- it->second(notification);
- return;
- }
-}
+Server::Server(std::string name, std::string version, LogCallback log_callback)
+ : m_name(std::move(name)), m_version(std::move(version)),
+ m_log_callback(std::move(log_callback)) {}
void Server::AddTool(std::unique_ptr<Tool> tool) {
if (!tool)
@@ -164,48 +126,64 @@ void Server::AddResourceProvider(
m_resource_providers.push_back(std::move(resource_provider));
}
-void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) {
- m_request_handlers[method] = std::move(handler);
-}
-
-void Server::AddNotificationHandler(llvm::StringRef method,
- NotificationHandler handler) {
- m_notification_handlers[method] = std::move(handler);
-}
-
-llvm::Expected<Response> Server::InitializeHandler(const Request &request) {
- Response response;
+MCPBinderUP Server::Bind(MCPTransport &transport) {
+ MCPBinderUP binder_up = std::make_unique<MCPBinder>(transport);
+ binder_up->Bind<InitializeResult, InitializeParams>(
+ "initialize", &Server::InitializeHandler, this);
+ binder_up->Bind<ListToolsResult, void>("tools/list",
+ &Server::ToolsListHandler, this);
+ binder_up->Bind<CallToolResult, CallToolParams>(
+ "tools/call", &Server::ToolsCallHandler, this);
+ binder_up->Bind<ListResourcesResult, void>(
+ "resources/list", &Server::ResourcesListHandler, this);
+ binder_up->Bind<ReadResourceResult, ReadResourceParams>(
+ "resources/read", &Server::ResourcesReadHandler, this);
+ binder_up->Bind<void>("notifications/initialized",
+ [this]() { Log("MCP initialization complete"); });
+ return binder_up;
+}
+
+llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) {
+ MCPBinderUP binder = Bind(*transport);
+ MCPTransport *transport_ptr = transport.get();
+ binder->OnDisconnect([this, transport_ptr]() {
+ assert(m_instances.find(transport_ptr) != m_instances.end() &&
+ "Client not found in m_instances");
+ m_instances.erase(transport_ptr);
+ });
+ binder->OnError([this](llvm::Error err) {
+ Logv("Transport error: {0}", llvm::toString(std::move(err)));
+ });
+
+ auto handle = transport->RegisterMessageHandler(loop, *binder);
+ if (!handle)
+ return handle.takeError();
+
+ m_instances[transport_ptr] =
+ Client{std::move(*handle), std::move(transport), std::move(binder)};
+ return llvm::Error::success();
+}
+
+Expected<InitializeResult>
+Server::InitializeHandler(const InitializeParams &request) {
InitializeResult result;
result.protocolVersion = mcp::kProtocolVersion;
result.capabilities = GetCapabilities();
result.serverInfo.name = m_name;
result.serverInfo.version = m_version;
- response.result = std::move(result);
- return response;
+ return result;
}
-llvm::Expected<Response> Server::ToolsListHandler(const Request &request) {
- Response response;
-
+llvm::Expected<ListToolsResult> Server::ToolsListHandler() {
ListToolsResult result;
for (const auto &tool : m_tools)
result.tools.emplace_back(tool.second->GetDefinition());
- response.result = std::move(result);
-
- return response;
+ return result;
}
-llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
- Response response;
-
- if (!request.params)
- return llvm::createStringError("no tool parameters");
- CallToolParams params;
- json::Path::Root root("params");
- if (!fromJSON(request.params, params, root))
- return root.getError();
-
+llvm::Expected<CallToolResult>
+Server::ToolsCallHandler(const CallToolParams &params) {
llvm::StringRef tool_name = params.name;
if (tool_name.empty())
return llvm::createStringError("no tool name");
@@ -222,113 +200,50 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
if (!text_result)
return text_result.takeError();
- response.result = toJSON(*text_result);
-
- return response;
+ return text_result;
}
-llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) {
- Response response;
-
+llvm::Expected<ListResourcesResult> Server::ResourcesListHandler() {
ListResourcesResult result;
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
m_resource_providers)
for (const Resource &resource : resource_provider_up->GetResources())
result.resources.push_back(resource);
- response.result = std::move(result);
-
- return response;
+ return result;
}
-llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
- Response response;
-
- if (!request.params)
- return llvm::createStringError("no resource parameters");
-
- ReadResourceParams params;
- json::Path::Root root("params");
- if (!fromJSON(request.params, params, root))
- return root.getError();
-
- llvm::StringRef uri_str = params.uri;
+Expected<ReadResourceResult>
+Server::ResourcesReadHandler(const ReadResourceParams &params) {
+ StringRef uri_str = params.uri;
if (uri_str.empty())
- return llvm::createStringError("no resource uri");
+ return createStringError("no resource uri");
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
m_resource_providers) {
- llvm::Expected<ReadResourceResult> result =
+ Expected<ReadResourceResult> result =
resource_provider_up->ReadResource(uri_str);
if (result.errorIsA<UnsupportedURI>()) {
- llvm::consumeError(result.takeError());
+ consumeError(result.takeError());
continue;
}
if (!result)
return result.takeError();
- Response response;
- response.result = std::move(*result);
- return response;
+ return *result;
}
return make_error<MCPError>(
- llvm::formatv("no resource handler for uri: {0}", uri_str).str(),
+ formatv("no resource handler for uri: {0}", uri_str).str(),
MCPError::kResourceNotFound);
}
ServerCapabilities Server::GetCapabilities() {
lldb_protocol::mcp::ServerCapabilities capabilities;
capabilities.supportsToolsList = true;
+ capabilities.supportsResourcesList = true;
// FIXME: Support sending notifications when a debugger/target are
// added/removed.
- capabilities.supportsResourcesList = false;
+ capabilities.supportsResourcesSubscribe = false;
return capabilities;
}
-
-void Server::Log(llvm::StringRef message) {
- if (m_log_callback)
- m_log_callback(message);
-}
-
-void Server::Received(const Request &request) {
- auto SendResponse = [this](const Response &response) {
- if (llvm::Error error = m_client.Send(response))
- Log(llvm::toString(std::move(error)));
- };
-
- llvm::Expected<Response> response = Handle(request);
- if (response)
- return SendResponse(*response);
-
- lldb_protocol::mcp::Error protocol_error;
- llvm::handleAllErrors(
- response.takeError(),
- [&](const MCPError &err) { protocol_error = err.toProtocolError(); },
- [&](const llvm::ErrorInfoBase &err) {
- protocol_error.code = MCPError::kInternalError;
- protocol_error.message = err.message();
- });
- Response error_response;
- error_response.id = request.id;
- error_response.result = std::move(protocol_error);
- SendResponse(error_response);
-}
-
-void Server::Received(const Response &response) {
- Log("unexpected MCP message: response");
-}
-
-void Server::Received(const Notification &notification) {
- Handle(notification);
-}
-
-void Server::OnError(llvm::Error error) {
- Log(llvm::toString(std::move(error)));
-}
-
-void Server::OnClosed() {
- Log("EOF");
- if (m_closed_callback)
- m_closed_callback();
-}
diff --git a/lldb/source/Symbol/Symtab.cpp b/lldb/source/Symbol/Symtab.cpp
index 970f6c4..6080703 100644
--- a/lldb/source/Symbol/Symtab.cpp
+++ b/lldb/source/Symbol/Symtab.cpp
@@ -289,7 +289,7 @@ void Symtab::InitNameIndexes() {
std::vector<Language *> languages;
Language::ForEach([&languages](Language *l) {
languages.push_back(l);
- return true;
+ return IterationAction::Continue;
});
auto &name_to_index = GetNameToSymbolIndexMap(lldb::eFunctionNameTypeNone);
diff --git a/lldb/source/Target/Language.cpp b/lldb/source/Target/Language.cpp
index 484d9ba..d4a9268 100644
--- a/lldb/source/Target/Language.cpp
+++ b/lldb/source/Target/Language.cpp
@@ -111,9 +111,9 @@ Language *Language::FindPlugin(llvm::StringRef file_path) {
ForEach([&result, file_path](Language *language) {
if (language->IsSourceFile(file_path)) {
result = language;
- return false;
+ return IterationAction::Stop;
}
- return true;
+ return IterationAction::Continue;
});
return result;
}
@@ -128,7 +128,8 @@ Language *Language::FindPlugin(LanguageType language,
return result;
}
-void Language::ForEach(std::function<bool(Language *)> callback) {
+void Language::ForEach(
+ llvm::function_ref<IterationAction(Language *)> callback) {
// If we want to iterate over all languages, we first have to complete the
// LanguagesMap.
static llvm::once_flag g_initialize;
@@ -153,7 +154,7 @@ void Language::ForEach(std::function<bool(Language *)> callback) {
}
for (auto *lang : loaded_plugins) {
- if (!callback(lang))
+ if (callback(lang) == IterationAction::Stop)
break;
}
}
@@ -289,9 +290,9 @@ void Language::PrintAllLanguages(Stream &s, const char *prefix,
}
void Language::ForAllLanguages(
- std::function<bool(lldb::LanguageType)> callback) {
+ llvm::function_ref<IterationAction(lldb::LanguageType)> callback) {
for (uint32_t i = 1; i < num_languages; i++) {
- if (!callback(language_names[i].type))
+ if (callback(language_names[i].type) == IterationAction::Stop)
break;
}
}
@@ -416,7 +417,7 @@ std::set<lldb::LanguageType> Language::GetSupportedLanguages() {
std::set<lldb::LanguageType> supported_languages;
ForEach([&](Language *lang) {
supported_languages.emplace(lang->GetLanguageType());
- return true;
+ return IterationAction::Continue;
});
return supported_languages;
}