diff options
Diffstat (limited to 'lldb/source')
-rw-r--r-- | lldb/source/Breakpoint/BreakpointResolverName.cpp | 2 | ||||
-rw-r--r-- | lldb/source/Commands/CommandObjectType.cpp | 4 | ||||
-rw-r--r-- | lldb/source/Core/Mangled.cpp | 4 | ||||
-rw-r--r-- | lldb/source/Host/common/JSONTransport.cpp | 26 | ||||
-rw-r--r-- | lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp | 52 | ||||
-rw-r--r-- | lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h | 20 | ||||
-rw-r--r-- | lldb/source/Protocol/MCP/MCPError.cpp | 9 | ||||
-rw-r--r-- | lldb/source/Protocol/MCP/Server.cpp | 209 | ||||
-rw-r--r-- | lldb/source/Symbol/Symtab.cpp | 2 | ||||
-rw-r--r-- | lldb/source/Target/Language.cpp | 15 |
10 files changed, 129 insertions, 214 deletions
diff --git a/lldb/source/Breakpoint/BreakpointResolverName.cpp b/lldb/source/Breakpoint/BreakpointResolverName.cpp index 6372595..4f252f9 100644 --- a/lldb/source/Breakpoint/BreakpointResolverName.cpp +++ b/lldb/source/Breakpoint/BreakpointResolverName.cpp @@ -233,7 +233,7 @@ void BreakpointResolverName::AddNameLookup(ConstString name, m_lookups.emplace_back(variant_lookup); } } - return true; + return IterationAction::Continue; }; if (Language *lang = Language::FindPlugin(m_language)) { diff --git a/lldb/source/Commands/CommandObjectType.cpp b/lldb/source/Commands/CommandObjectType.cpp index 19cd3ff..22ed5b8 100644 --- a/lldb/source/Commands/CommandObjectType.cpp +++ b/lldb/source/Commands/CommandObjectType.cpp @@ -2610,7 +2610,7 @@ public: Language::ForEach([&](Language *lang) { if (const char *help = lang->GetLanguageSpecificTypeLookupHelp()) stream.Printf("%s\n", help); - return true; + return IterationAction::Continue; }); m_cmd_help_long = std::string(stream.GetString()); @@ -2649,7 +2649,7 @@ public: (m_command_options.m_language == eLanguageTypeUnknown))) { Language::ForEach([&](Language *lang) { languages.push_back(lang); - return true; + return IterationAction::Continue; }); } else { languages.push_back(Language::FindPlugin(m_command_options.m_language)); diff --git a/lldb/source/Core/Mangled.cpp b/lldb/source/Core/Mangled.cpp index 0780846..f7683c5 100644 --- a/lldb/source/Core/Mangled.cpp +++ b/lldb/source/Core/Mangled.cpp @@ -428,9 +428,9 @@ lldb::LanguageType Mangled::GuessLanguage() const { Language::ForEach([this, &result](Language *l) { if (l->SymbolNameFitsToLanguage(*this)) { result = l->GetLanguageType(); - return false; + return IterationAction::Stop; } - return true; + return IterationAction::Continue; }); return result; } diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp index c4b42ea..22de7fa 100644 --- a/lldb/source/Host/common/JSONTransport.cpp +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -14,8 +14,7 @@ #include <string> using namespace llvm; -using namespace lldb; -using namespace lldb_private; +using namespace lldb_private::transport; char TransportUnhandledContentsError::ID; @@ -23,10 +22,31 @@ TransportUnhandledContentsError::TransportUnhandledContentsError( std::string unhandled_contents) : m_unhandled_contents(unhandled_contents) {} -void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const { +void TransportUnhandledContentsError::log(raw_ostream &OS) const { OS << "transport EOF with unhandled contents: '" << m_unhandled_contents << "'"; } std::error_code TransportUnhandledContentsError::convertToErrorCode() const { return std::make_error_code(std::errc::bad_message); } + +char InvalidParams::ID; + +void InvalidParams::log(raw_ostream &OS) const { + OS << "invalid parameters for method '" << m_method << "': '" << m_context + << "'"; +} +std::error_code InvalidParams::convertToErrorCode() const { + return std::make_error_code(std::errc::invalid_argument); +} + +char MethodNotFound::ID; + +void MethodNotFound::log(raw_ostream &OS) const { + OS << "method not found: '" << m_method << "'"; +} + +std::error_code MethodNotFound::convertToErrorCode() const { + // JSON-RPC Method not found + return std::error_code(MethodNotFound::kErrorCode, std::generic_category()); +} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index d7293fc..33bdd5e 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -52,11 +52,6 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { } void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { - server.AddNotificationHandler("notifications/initialized", - [](const lldb_protocol::mcp::Notification &) { - LLDB_LOG(GetLog(LLDBLog::Host), - "MCP initialization complete"); - }); server.AddTool( std::make_unique<CommandTool>("command", "Run an lldb command.")); server.AddTool(std::make_unique<DebuggerListTool>( @@ -74,26 +69,9 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) { io_sp, io_sp, [client_name](llvm::StringRef message) { LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message); }); - MCPTransport *transport_ptr = transport_up.get(); - auto instance_up = std::make_unique<lldb_protocol::mcp::Server>( - std::string(kName), std::string(kVersion), *transport_up, - /*log_callback=*/ - [client_name](llvm::StringRef message) { - LLDB_LOG(GetLog(LLDBLog::Host), "{0} Server: {1}", client_name, - message); - }, - /*closed_callback=*/ - [this, transport_ptr]() { m_instances.erase(transport_ptr); }); - Extend(*instance_up); - llvm::Expected<MainLoop::ReadHandleUP> handle = - transport_up->RegisterMessageHandler(m_loop, *instance_up); - if (!handle) { - LLDB_LOG_ERROR(log, handle.takeError(), "Failed to run MCP server: {0}"); - return; - } - m_instances[transport_ptr] = - std::make_tuple<ServerUP, ReadHandleUP, TransportUP>( - std::move(instance_up), std::move(*handle), std::move(transport_up)); + + if (auto error = m_server->Accept(m_loop, std::move(transport_up))) + LLDB_LOG_ERROR(log, std::move(error), "{0}:"); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { @@ -124,14 +102,21 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { llvm::join(m_listener->GetListeningConnectionURI(), ", "); ServerInfo info{listening_uris[0]}; - llvm::Expected<ServerInfoHandle> handle = ServerInfo::Write(info); - if (!handle) - return handle.takeError(); + llvm::Expected<ServerInfoHandle> server_info_handle = ServerInfo::Write(info); + if (!server_info_handle) + return server_info_handle.takeError(); + + m_client_count = 0; + m_server = std::make_unique<lldb_protocol::mcp::Server>( + std::string(kName), std::string(kVersion), [](StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "MCP Server: {0}", message); + }); + Extend(*m_server); m_running = true; - m_server_info_handle = std::move(*handle); - m_listen_handlers = std::move(*handles); - m_loop_thread = std::thread([=] { + m_server_info_handle = std::move(*server_info_handle); + m_accept_handles = std::move(*handles); + m_loop_thread = std::thread([this] { llvm::set_thread_name("protocol-server.mcp"); m_loop.Run(); }); @@ -155,9 +140,10 @@ llvm::Error ProtocolServerMCP::Stop() { if (m_loop_thread.joinable()) m_loop_thread.join(); + m_accept_handles.clear(); + + m_server.reset(nullptr); m_server_info_handle.Remove(); - m_listen_handlers.clear(); - m_instances.clear(); return llvm::Error::success(); } diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index b325a36..e0f2a6c 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -23,16 +23,17 @@ namespace lldb_private::mcp { class ProtocolServerMCP : public ProtocolServer { - using ReadHandleUP = MainLoopBase::ReadHandleUP; - using TransportUP = std::unique_ptr<lldb_protocol::mcp::MCPTransport>; + using ServerUP = std::unique_ptr<lldb_protocol::mcp::Server>; + using ReadHandleUP = MainLoop::ReadHandleUP; + public: ProtocolServerMCP(); - virtual ~ProtocolServerMCP() override; + ~ProtocolServerMCP() override; - virtual llvm::Error Start(ProtocolServer::Connection connection) override; - virtual llvm::Error Stop() override; + llvm::Error Start(ProtocolServer::Connection connection) override; + llvm::Error Stop() override; static void Initialize(); static void Terminate(); @@ -56,19 +57,18 @@ private: bool m_running = false; - lldb_protocol::mcp::ServerInfoHandle m_server_info_handle; lldb_private::MainLoop m_loop; std::thread m_loop_thread; std::mutex m_mutex; size_t m_client_count = 0; std::unique_ptr<Socket> m_listener; + std::vector<ReadHandleUP> m_accept_handles; - std::vector<ReadHandleUP> m_listen_handlers; - std::map<lldb_protocol::mcp::MCPTransport *, - std::tuple<ServerUP, ReadHandleUP, TransportUP>> - m_instances; + ServerUP m_server; + lldb_protocol::mcp::ServerInfoHandle m_server_info_handle; }; + } // namespace lldb_private::mcp #endif diff --git a/lldb/source/Protocol/MCP/MCPError.cpp b/lldb/source/Protocol/MCP/MCPError.cpp index e140d11..cfac055 100644 --- a/lldb/source/Protocol/MCP/MCPError.cpp +++ b/lldb/source/Protocol/MCP/MCPError.cpp @@ -22,14 +22,7 @@ MCPError::MCPError(std::string message, int64_t error_code) void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; } std::error_code MCPError::convertToErrorCode() const { - return llvm::inconvertibleErrorCode(); -} - -lldb_protocol::mcp::Error MCPError::toProtocolError() const { - lldb_protocol::mcp::Error error; - error.code = m_error_code; - error.message = m_message; - return error; + return std::error_code(m_error_code, std::generic_category()); } UnsupportedURI::UnsupportedURI(std::string uri) : m_uri(uri) {} diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index 19030a3..71323ad 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -12,6 +12,7 @@ #include "lldb/Host/HostInfo.h" #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Transport.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" @@ -108,48 +109,9 @@ Expected<std::vector<ServerInfo>> ServerInfo::Load() { return infos; } -Server::Server(std::string name, std::string version, MCPTransport &client, - LogCallback log_callback, ClosedCallback closed_callback) - : m_name(std::move(name)), m_version(std::move(version)), m_client(client), - m_log_callback(std::move(log_callback)), - m_closed_callback(std::move(closed_callback)) { - AddRequestHandlers(); -} - -void Server::AddRequestHandlers() { - AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this, - std::placeholders::_1)); - AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler, - this, std::placeholders::_1)); - AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler, - this, std::placeholders::_1)); -} - -llvm::Expected<Response> Server::Handle(const Request &request) { - auto it = m_request_handlers.find(request.method); - if (it != m_request_handlers.end()) { - llvm::Expected<Response> response = it->second(request); - if (!response) - return response; - response->id = request.id; - return *response; - } - - return llvm::make_error<MCPError>( - llvm::formatv("no handler for request: {0}", request.method).str()); -} - -void Server::Handle(const Notification ¬ification) { - auto it = m_notification_handlers.find(notification.method); - if (it != m_notification_handlers.end()) { - it->second(notification); - return; - } -} +Server::Server(std::string name, std::string version, LogCallback log_callback) + : m_name(std::move(name)), m_version(std::move(version)), + m_log_callback(std::move(log_callback)) {} void Server::AddTool(std::unique_ptr<Tool> tool) { if (!tool) @@ -164,48 +126,64 @@ void Server::AddResourceProvider( m_resource_providers.push_back(std::move(resource_provider)); } -void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { - m_request_handlers[method] = std::move(handler); -} - -void Server::AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler) { - m_notification_handlers[method] = std::move(handler); -} - -llvm::Expected<Response> Server::InitializeHandler(const Request &request) { - Response response; +MCPBinderUP Server::Bind(MCPTransport &transport) { + MCPBinderUP binder_up = std::make_unique<MCPBinder>(transport); + binder_up->Bind<InitializeResult, InitializeParams>( + "initialize", &Server::InitializeHandler, this); + binder_up->Bind<ListToolsResult, void>("tools/list", + &Server::ToolsListHandler, this); + binder_up->Bind<CallToolResult, CallToolParams>( + "tools/call", &Server::ToolsCallHandler, this); + binder_up->Bind<ListResourcesResult, void>( + "resources/list", &Server::ResourcesListHandler, this); + binder_up->Bind<ReadResourceResult, ReadResourceParams>( + "resources/read", &Server::ResourcesReadHandler, this); + binder_up->Bind<void>("notifications/initialized", + [this]() { Log("MCP initialization complete"); }); + return binder_up; +} + +llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) { + MCPBinderUP binder = Bind(*transport); + MCPTransport *transport_ptr = transport.get(); + binder->OnDisconnect([this, transport_ptr]() { + assert(m_instances.find(transport_ptr) != m_instances.end() && + "Client not found in m_instances"); + m_instances.erase(transport_ptr); + }); + binder->OnError([this](llvm::Error err) { + Logv("Transport error: {0}", llvm::toString(std::move(err))); + }); + + auto handle = transport->RegisterMessageHandler(loop, *binder); + if (!handle) + return handle.takeError(); + + m_instances[transport_ptr] = + Client{std::move(*handle), std::move(transport), std::move(binder)}; + return llvm::Error::success(); +} + +Expected<InitializeResult> +Server::InitializeHandler(const InitializeParams &request) { InitializeResult result; result.protocolVersion = mcp::kProtocolVersion; result.capabilities = GetCapabilities(); result.serverInfo.name = m_name; result.serverInfo.version = m_version; - response.result = std::move(result); - return response; + return result; } -llvm::Expected<Response> Server::ToolsListHandler(const Request &request) { - Response response; - +llvm::Expected<ListToolsResult> Server::ToolsListHandler() { ListToolsResult result; for (const auto &tool : m_tools) result.tools.emplace_back(tool.second->GetDefinition()); - response.result = std::move(result); - - return response; + return result; } -llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no tool parameters"); - CallToolParams params; - json::Path::Root root("params"); - if (!fromJSON(request.params, params, root)) - return root.getError(); - +llvm::Expected<CallToolResult> +Server::ToolsCallHandler(const CallToolParams ¶ms) { llvm::StringRef tool_name = params.name; if (tool_name.empty()) return llvm::createStringError("no tool name"); @@ -222,113 +200,50 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { if (!text_result) return text_result.takeError(); - response.result = toJSON(*text_result); - - return response; + return text_result; } -llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) { - Response response; - +llvm::Expected<ListResourcesResult> Server::ResourcesListHandler() { ListResourcesResult result; for (std::unique_ptr<ResourceProvider> &resource_provider_up : m_resource_providers) for (const Resource &resource : resource_provider_up->GetResources()) result.resources.push_back(resource); - response.result = std::move(result); - - return response; + return result; } -llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no resource parameters"); - - ReadResourceParams params; - json::Path::Root root("params"); - if (!fromJSON(request.params, params, root)) - return root.getError(); - - llvm::StringRef uri_str = params.uri; +Expected<ReadResourceResult> +Server::ResourcesReadHandler(const ReadResourceParams ¶ms) { + StringRef uri_str = params.uri; if (uri_str.empty()) - return llvm::createStringError("no resource uri"); + return createStringError("no resource uri"); for (std::unique_ptr<ResourceProvider> &resource_provider_up : m_resource_providers) { - llvm::Expected<ReadResourceResult> result = + Expected<ReadResourceResult> result = resource_provider_up->ReadResource(uri_str); if (result.errorIsA<UnsupportedURI>()) { - llvm::consumeError(result.takeError()); + consumeError(result.takeError()); continue; } if (!result) return result.takeError(); - Response response; - response.result = std::move(*result); - return response; + return *result; } return make_error<MCPError>( - llvm::formatv("no resource handler for uri: {0}", uri_str).str(), + formatv("no resource handler for uri: {0}", uri_str).str(), MCPError::kResourceNotFound); } ServerCapabilities Server::GetCapabilities() { lldb_protocol::mcp::ServerCapabilities capabilities; capabilities.supportsToolsList = true; + capabilities.supportsResourcesList = true; // FIXME: Support sending notifications when a debugger/target are // added/removed. - capabilities.supportsResourcesList = false; + capabilities.supportsResourcesSubscribe = false; return capabilities; } - -void Server::Log(llvm::StringRef message) { - if (m_log_callback) - m_log_callback(message); -} - -void Server::Received(const Request &request) { - auto SendResponse = [this](const Response &response) { - if (llvm::Error error = m_client.Send(response)) - Log(llvm::toString(std::move(error))); - }; - - llvm::Expected<Response> response = Handle(request); - if (response) - return SendResponse(*response); - - lldb_protocol::mcp::Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.code = MCPError::kInternalError; - protocol_error.message = err.message(); - }); - Response error_response; - error_response.id = request.id; - error_response.result = std::move(protocol_error); - SendResponse(error_response); -} - -void Server::Received(const Response &response) { - Log("unexpected MCP message: response"); -} - -void Server::Received(const Notification ¬ification) { - Handle(notification); -} - -void Server::OnError(llvm::Error error) { - Log(llvm::toString(std::move(error))); -} - -void Server::OnClosed() { - Log("EOF"); - if (m_closed_callback) - m_closed_callback(); -} diff --git a/lldb/source/Symbol/Symtab.cpp b/lldb/source/Symbol/Symtab.cpp index 970f6c4..6080703 100644 --- a/lldb/source/Symbol/Symtab.cpp +++ b/lldb/source/Symbol/Symtab.cpp @@ -289,7 +289,7 @@ void Symtab::InitNameIndexes() { std::vector<Language *> languages; Language::ForEach([&languages](Language *l) { languages.push_back(l); - return true; + return IterationAction::Continue; }); auto &name_to_index = GetNameToSymbolIndexMap(lldb::eFunctionNameTypeNone); diff --git a/lldb/source/Target/Language.cpp b/lldb/source/Target/Language.cpp index 484d9ba..d4a9268 100644 --- a/lldb/source/Target/Language.cpp +++ b/lldb/source/Target/Language.cpp @@ -111,9 +111,9 @@ Language *Language::FindPlugin(llvm::StringRef file_path) { ForEach([&result, file_path](Language *language) { if (language->IsSourceFile(file_path)) { result = language; - return false; + return IterationAction::Stop; } - return true; + return IterationAction::Continue; }); return result; } @@ -128,7 +128,8 @@ Language *Language::FindPlugin(LanguageType language, return result; } -void Language::ForEach(std::function<bool(Language *)> callback) { +void Language::ForEach( + llvm::function_ref<IterationAction(Language *)> callback) { // If we want to iterate over all languages, we first have to complete the // LanguagesMap. static llvm::once_flag g_initialize; @@ -153,7 +154,7 @@ void Language::ForEach(std::function<bool(Language *)> callback) { } for (auto *lang : loaded_plugins) { - if (!callback(lang)) + if (callback(lang) == IterationAction::Stop) break; } } @@ -289,9 +290,9 @@ void Language::PrintAllLanguages(Stream &s, const char *prefix, } void Language::ForAllLanguages( - std::function<bool(lldb::LanguageType)> callback) { + llvm::function_ref<IterationAction(lldb::LanguageType)> callback) { for (uint32_t i = 1; i < num_languages; i++) { - if (!callback(language_names[i].type)) + if (callback(language_names[i].type) == IterationAction::Stop) break; } } @@ -416,7 +417,7 @@ std::set<lldb::LanguageType> Language::GetSupportedLanguages() { std::set<lldb::LanguageType> supported_languages; ForEach([&](Language *lang) { supported_languages.emplace(lang->GetLanguageType()); - return true; + return IterationAction::Continue; }); return supported_languages; } |