| author | John Harrison <harjohn@google.com> | 2025-08-21 15:26:52 -0700 |
| --- | --- | --- |
| committer | GitHub <noreply@github.com> | 2025-08-21 15:26:52 -0700 |
| commit | 36d07ad83b1e537e976f9ae3da5b618d3ccf951c (patch) | |
| tree | 3ec208c5e6e72cfbefe045ce51d876ac25980849 /lldb/packages/Python/lldbsuite/test | |
| parent | 33f6b10c179f3636904415b73b7d71556583061b (diff) | |
Reapply "[lldb-dap] Re-land refactor of DebugCommunication. (#147787)" (#154832)
This reverts commit 0f33b90b6117bcfa6ca3779c641c1ee8d03590fd and
includes a fix for a test that was added between my last update to
this change and its merge.
Diffstat (limited to 'lldb/packages/Python/lldbsuite/test')
| -rw-r--r-- | lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py | 806 |
| -rw-r--r-- | lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py | 87 |
2 files changed, 510 insertions, 383 deletions
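
The largest change in dap_server.py replaces the `Source` and `Breakpoint` wrapper classes with plain `TypedDict` shapes plus small builder/helper functions. The sketch below is not the test suite's own code: it mirrors those shapes (the `build_source` helper, the path `/tmp/main.c`, and the breakpoint payload are invented for illustration) to show how a test now builds a source and reads a breakpoint's verified state with ordinary dict operations instead of method calls on a wrapper object.

```python
import os
from typing import Optional, TypedDict


class Source(TypedDict, total=False):
    name: str
    path: str
    sourceReference: int


class Breakpoint(TypedDict, total=False):
    id: int
    verified: bool
    source: Source


def build_source(path: Optional[str] = None,
                 source_reference: Optional[int] = None) -> Source:
    # Mirrors the intent of Source.build in the patch: a path wins,
    # otherwise fall back to a DAP sourceReference.
    src: Source = {}
    if path:
        src["name"] = os.path.basename(path)
        src["path"] = path
    elif source_reference is not None:
        src["sourceReference"] = source_reference
    else:
        raise ValueError("need a path or a source reference")
    return src


# "/tmp/main.c" and the breakpoint payload are invented illustration data.
bp: Breakpoint = {"id": 1, "verified": True,
                  "source": build_source("/tmp/main.c")}
# Verified state is now a plain dict lookup rather than bp.is_verified().
assert bp.get("verified", False)
assert bp["source"]["name"] == "main.c"
```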
diff --git a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py index 7acb9c8..0608ac3 100644 --- a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py +++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py @@ -12,15 +12,91 @@ import signal import sys import threading import time -from typing import Any, Optional, Union, BinaryIO, TextIO +from typing import ( + Any, + Optional, + Dict, + cast, + List, + Callable, + IO, + Union, + BinaryIO, + TextIO, + TypedDict, + Literal, +) ## DAP type references -Event = dict[str, Any] -Request = dict[str, Any] -Response = dict[str, Any] + + +class Event(TypedDict): + type: Literal["event"] + seq: int + event: str + body: Any + + +class Request(TypedDict, total=False): + type: Literal["request"] + seq: int + command: str + arguments: Any + + +class Response(TypedDict): + type: Literal["response"] + seq: int + request_seq: int + success: bool + command: str + message: Optional[str] + body: Any + + ProtocolMessage = Union[Event, Request, Response] +class Source(TypedDict, total=False): + name: str + path: str + sourceReference: int + + @staticmethod + def build( + *, + name: Optional[str] = None, + path: Optional[str] = None, + source_reference: Optional[int] = None, + ) -> "Source": + """Builds a source from the given name, path or source_reference.""" + if not name and not path and not source_reference: + raise ValueError( + "Source.build requires either name, path, or source_reference" + ) + + s = Source() + if name: + s["name"] = name + if path: + if not name: + s["name"] = os.path.basename(path) + s["path"] = path + if source_reference is not None: + s["sourceReference"] = source_reference + return s + + +class Breakpoint(TypedDict, total=False): + id: int + verified: bool + source: Source + + @staticmethod + def is_verified(src: "Breakpoint") -> bool: + return src.get("verified", False) + + def dump_memory(base_addr, data, num_per_line, outfile): data_len = len(data) hex_string = binascii.hexlify(data) @@ -58,7 +134,9 @@ def dump_memory(base_addr, data, num_per_line, outfile): outfile.write("\n") -def read_packet(f, verbose=False, trace_file=None): +def read_packet( + f: IO[bytes], trace_file: Optional[IO[str]] = None +) -> Optional[ProtocolMessage]: """Decode a JSON packet that starts with the content length and is followed by the JSON bytes from a file 'f'. Returns None on EOF. 
""" @@ -70,19 +148,13 @@ def read_packet(f, verbose=False, trace_file=None): prefix = "Content-Length: " if line.startswith(prefix): # Decode length of JSON bytes - if verbose: - print('content: "%s"' % (line)) length = int(line[len(prefix) :]) - if verbose: - print('length: "%u"' % (length)) # Skip empty line - line = f.readline() - if verbose: - print('empty: "%s"' % (line)) + separator = f.readline().decode() + if separator != "": + Exception("malformed DAP content header, unexpected line: " + separator) # Read JSON bytes - json_str = f.read(length) - if verbose: - print('json: "%s"' % (json_str)) + json_str = f.read(length).decode() if trace_file: trace_file.write("from adapter:\n%s\n" % (json_str)) # Decode the JSON bytes into a python dictionary @@ -95,7 +167,7 @@ def packet_type_is(packet, packet_type): return "type" in packet and packet["type"] == packet_type -def dump_dap_log(log_file): +def dump_dap_log(log_file: Optional[str]) -> None: print("========= DEBUG ADAPTER PROTOCOL LOGS =========", file=sys.stderr) if log_file is None: print("no log file available", file=sys.stderr) @@ -105,58 +177,6 @@ def dump_dap_log(log_file): print("========= END =========", file=sys.stderr) -class Source(object): - def __init__( - self, - path: Optional[str] = None, - source_reference: Optional[int] = None, - raw_dict: Optional[dict[str, Any]] = None, - ): - self._name = None - self._path = None - self._source_reference = None - self._raw_dict = None - - if path is not None: - self._name = os.path.basename(path) - self._path = path - elif source_reference is not None: - self._source_reference = source_reference - elif raw_dict is not None: - self._raw_dict = raw_dict - else: - raise ValueError("Either path or source_reference must be provided") - - def __str__(self): - return f"Source(name={self.name}, path={self.path}), source_reference={self.source_reference})" - - def as_dict(self): - if self._raw_dict is not None: - return self._raw_dict - - source_dict = {} - if self._name is not None: - source_dict["name"] = self._name - if self._path is not None: - source_dict["path"] = self._path - if self._source_reference is not None: - source_dict["sourceReference"] = self._source_reference - return source_dict - - -class Breakpoint(object): - def __init__(self, obj): - self._breakpoint = obj - - def is_verified(self): - """Check if the breakpoint is verified.""" - return self._breakpoint.get("verified", False) - - def source(self): - """Get the source of the breakpoint.""" - return self._breakpoint.get("source", {}) - - class NotSupportedError(KeyError): """Raised if a feature is not supported due to its capabilities.""" @@ -174,26 +194,42 @@ class DebugCommunication(object): self.log_file = log_file self.send = send self.recv = recv - self.recv_packets: list[Optional[ProtocolMessage]] = [] - self.recv_condition = threading.Condition() - self.recv_thread = threading.Thread(target=self._read_packet_thread) - self.process_event_body = None - self.exit_status: Optional[int] = None - self.capabilities: dict[str, Any] = {} - self.progress_events: list[Event] = [] - self.reverse_requests = [] - self.sequence = 1 - self.threads = None - self.thread_stop_reasons = {} - self.recv_thread.start() - self.output_condition = threading.Condition() - self.output: dict[str, list[str]] = {} - self.configuration_done_sent = False - self.initialized = False - self.frame_scopes = {} + + # Packets that have been received and processed but have not yet been + # requested by a test case. 
+ self._pending_packets: List[Optional[ProtocolMessage]] = [] + # Received packets that have not yet been processed. + self._recv_packets: List[Optional[ProtocolMessage]] = [] + # Used as a mutex for _recv_packets and for notify when _recv_packets + # changes. + self._recv_condition = threading.Condition() + self._recv_thread = threading.Thread(target=self._read_packet_thread) + + # session state self.init_commands = init_commands + self.exit_status: Optional[int] = None + self.capabilities: Dict = {} + self.initialized: bool = False + self.configuration_done_sent: bool = False + self.process_event_body: Optional[Dict] = None + self.terminated: bool = False + self.events: List[Event] = [] + self.progress_events: List[Event] = [] + self.reverse_requests: List[Request] = [] + self.module_events: List[Dict] = [] + self.sequence: int = 1 + self.output: Dict[str, str] = {} + + # debuggee state + self.threads: Optional[dict] = None + self.thread_stop_reasons: Dict[str, Any] = {} + self.frame_scopes: Dict[str, Any] = {} + # keyed by breakpoint id self.resolved_breakpoints: dict[str, Breakpoint] = {} + # trigger enqueue thread + self._recv_thread.start() + @classmethod def encode_content(cls, s: str) -> bytes: return ("Content-Length: %u\r\n\r\n%s" % (len(s), s)).encode("utf-8") @@ -210,267 +246,324 @@ class DebugCommunication(object): ) def _read_packet_thread(self): - done = False try: - while not done: + while True: packet = read_packet(self.recv, trace_file=self.trace_file) # `packet` will be `None` on EOF. We want to pass it down to # handle_recv_packet anyway so the main thread can handle unexpected # termination of lldb-dap and stop waiting for new packets. - done = not self._handle_recv_packet(packet) + if not self._handle_recv_packet(packet): + break finally: dump_dap_log(self.log_file) - def get_modules(self, startModule: int = 0, moduleCount: int = 0): - module_list = self.request_modules(startModule, moduleCount)["body"]["modules"] + def get_modules( + self, start_module: Optional[int] = None, module_count: Optional[int] = None + ) -> Dict: + resp = self.request_modules(start_module, module_count) + if not resp["success"]: + raise ValueError(f"request_modules failed: {resp!r}") modules = {} + module_list = resp["body"]["modules"] for module in module_list: modules[module["name"]] = module return modules - def get_output(self, category, timeout=0.0, clear=True): - self.output_condition.acquire() - output = None + def get_output(self, category: str, clear=True) -> str: + output = "" if category in self.output: - output = self.output[category] + output = self.output.get(category, "") if clear: del self.output[category] - elif timeout != 0.0: - self.output_condition.wait(timeout) - if category in self.output: - output = self.output[category] - if clear: - del self.output[category] - self.output_condition.release() return output - def collect_output(self, category, timeout_secs, pattern, clear=True): - end_time = time.time() + timeout_secs - collected_output = "" - while end_time > time.time(): - output = self.get_output(category, timeout=0.25, clear=clear) - if output: - collected_output += output - if pattern is not None and pattern in output: - break - return collected_output if collected_output else None + def collect_output( + self, + category: str, + timeout: float, + pattern: Optional[str] = None, + clear=True, + ) -> str: + """Collect output from 'output' events. + Args: + category: The category to collect. + timeout: The max duration for collecting output. 
+ pattern: + Optional, if set, return once this pattern is detected in the + collected output. + Returns: + The collected output. + """ + deadline = time.monotonic() + timeout + output = self.get_output(category, clear) + while deadline >= time.monotonic() and ( + pattern is None or pattern not in output + ): + event = self.wait_for_event(["output"], timeout=deadline - time.monotonic()) + if not event: # Timeout or EOF + break + output += self.get_output(category, clear=clear) + return output def _enqueue_recv_packet(self, packet: Optional[ProtocolMessage]): - self.recv_condition.acquire() - self.recv_packets.append(packet) - self.recv_condition.notify() - self.recv_condition.release() + with self.recv_condition: + self.recv_packets.append(packet) + self.recv_condition.notify() def _handle_recv_packet(self, packet: Optional[ProtocolMessage]) -> bool: - """Called by the read thread that is waiting for all incoming packets - to store the incoming packet in "self.recv_packets" in a thread safe - way. This function will then signal the "self.recv_condition" to - indicate a new packet is available. Returns True if the caller - should keep calling this function for more packets. + """Handles an incoming packet. + + Called by the read thread that is waiting for all incoming packets + to store the incoming packet in "self._recv_packets" in a thread safe + way. This function will then signal the "self._recv_condition" to + indicate a new packet is available. + + Args: + packet: A new packet to store. + + Returns: + True if the caller should keep calling this function for more + packets. """ - # If EOF, notify the read thread by enqueuing a None. - if not packet: - self._enqueue_recv_packet(None) - return False - - # Check the packet to see if is an event packet - keepGoing = True - packet_type = packet["type"] - if packet_type == "event": - event = packet["event"] - body = None - if "body" in packet: - body = packet["body"] - # Handle the event packet and cache information from these packets - # as they come in - if event == "output": - # Store any output we receive so clients can retrieve it later. - category = body["category"] - output = body["output"] - self.output_condition.acquire() - if category in self.output: - self.output[category] += output - else: - self.output[category] = output - self.output_condition.notify() - self.output_condition.release() - # no need to add 'output' event packets to our packets list - return keepGoing - elif event == "initialized": - self.initialized = True - elif event == "process": - # When a new process is attached or launched, remember the - # details that are available in the body of the event - self.process_event_body = body - elif event == "exited": - # Process exited, mark the status to indicate the process is not - # alive. - self.exit_status = body["exitCode"] - elif event == "continued": - # When the process continues, clear the known threads and - # thread_stop_reasons. - all_threads_continued = body.get("allThreadsContinued", True) - tid = body["threadId"] - if tid in self.thread_stop_reasons: - del self.thread_stop_reasons[tid] - self._process_continued(all_threads_continued) - elif event == "stopped": - # Each thread that stops with a reason will send a - # 'stopped' event. We need to remember the thread stop - # reasons since the 'threads' command doesn't return - # that information. 
- self._process_stopped() - tid = body["threadId"] - self.thread_stop_reasons[tid] = body - elif event.startswith("progress"): - # Progress events come in as 'progressStart', 'progressUpdate', - # and 'progressEnd' events. Keep these around in case test - # cases want to verify them. - self.progress_events.append(packet) - elif event == "breakpoint": - # Breakpoint events are sent when a breakpoint is resolved - self._update_verified_breakpoints([body["breakpoint"]]) - elif event == "capabilities": - # Update the capabilities with new ones from the event. - self.capabilities.update(body["capabilities"]) - - elif packet_type == "response": - if packet["command"] == "disconnect": - keepGoing = False - self._enqueue_recv_packet(packet) - return keepGoing + with self._recv_condition: + self._recv_packets.append(packet) + self._recv_condition.notify() + # packet is None on EOF + return packet is not None and not ( + packet["type"] == "response" and packet["command"] == "disconnect" + ) + + def _recv_packet( + self, + *, + predicate: Optional[Callable[[ProtocolMessage], bool]] = None, + timeout: Optional[float] = None, + ) -> Optional[ProtocolMessage]: + """Processes received packets from the adapter. + Updates the DebugCommunication stateful properties based on the received + packets in the order they are received. + NOTE: The only time the session state properties should be updated is + during this call to ensure consistency during tests. + Args: + predicate: + Optional, if specified, returns the first packet that matches + the given predicate. + timeout: + Optional, if specified, processes packets until either the + timeout occurs or the predicate matches a packet, whichever + occurs first. + Returns: + The first matching packet for the given predicate, if specified, + otherwise None. + """ + assert ( + threading.current_thread != self._recv_thread + ), "Must not be called from the _recv_thread" + + def process_until_match(): + self._process_recv_packets() + for i, packet in enumerate(self._pending_packets): + if packet is None: + # We need to return a truthy value to break out of the + # wait_for, use `EOFError` as an indicator of EOF. + return EOFError() + if predicate and predicate(packet): + self._pending_packets.pop(i) + return packet + + with self._recv_condition: + packet = self._recv_condition.wait_for(process_until_match, timeout) + return None if isinstance(packet, EOFError) else packet + + def _process_recv_packets(self) -> None: + """Process received packets, updating the session state.""" + with self._recv_condition: + for packet in self._recv_packets: + # Handle events that may modify any stateful properties of + # the DAP session. + if packet and packet["type"] == "event": + self._handle_event(packet) + elif packet and packet["type"] == "request": + # Handle reverse requests and keep processing. + self._handle_reverse_request(packet) + # Move the packet to the pending queue. + self._pending_packets.append(packet) + self._recv_packets.clear() + + def _handle_event(self, packet: Event) -> None: + """Handle any events that modify debug session state we track.""" + event = packet["event"] + body: Optional[Dict] = packet.get("body", None) + + if event == "output" and body: + # Store any output we receive so clients can retrieve it later. 
+ category = body["category"] + output = body["output"] + if category in self.output: + self.output[category] += output + else: + self.output[category] = output + elif event == "initialized": + self.initialized = True + elif event == "process": + # When a new process is attached or launched, remember the + # details that are available in the body of the event + self.process_event_body = body + elif event == "exited" and body: + # Process exited, mark the status to indicate the process is not + # alive. + self.exit_status = body["exitCode"] + elif event == "continued" and body: + # When the process continues, clear the known threads and + # thread_stop_reasons. + all_threads_continued = body.get("allThreadsContinued", True) + tid = body["threadId"] + if tid in self.thread_stop_reasons: + del self.thread_stop_reasons[tid] + self._process_continued(all_threads_continued) + elif event == "stopped" and body: + # Each thread that stops with a reason will send a + # 'stopped' event. We need to remember the thread stop + # reasons since the 'threads' command doesn't return + # that information. + self._process_stopped() + tid = body["threadId"] + self.thread_stop_reasons[tid] = body + elif event.startswith("progress"): + # Progress events come in as 'progressStart', 'progressUpdate', + # and 'progressEnd' events. Keep these around in case test + # cases want to verify them. + self.progress_events.append(packet) + elif event == "breakpoint" and body: + # Breakpoint events are sent when a breakpoint is resolved + self._update_verified_breakpoints([body["breakpoint"]]) + elif event == "capabilities" and body: + # Update the capabilities with new ones from the event. + self.capabilities.update(body["capabilities"]) + + def _handle_reverse_request(self, request: Request) -> None: + if request in self.reverse_requests: + return + self.reverse_requests.append(request) + arguments = request.get("arguments") + if request["command"] == "runInTerminal" and arguments is not None: + in_shell = arguments.get("argsCanBeInterpretedByShell", False) + print("spawning...", arguments["args"]) + proc = subprocess.Popen( + arguments["args"], + env=arguments.get("env", {}), + cwd=arguments.get("cwd", None), + stdin=subprocess.DEVNULL, + stdout=sys.stderr, + stderr=sys.stderr, + shell=in_shell, + ) + body = {} + if in_shell: + body["shellProcessId"] = proc.pid + else: + body["processId"] = proc.pid + self.send_packet( + { + "type": "response", + "seq": 0, + "request_seq": request["seq"], + "success": True, + "command": "runInTerminal", + "body": body, + } + ) + elif request["command"] == "startDebugging": + self.send_packet( + { + "type": "response", + "seq": 0, + "request_seq": request["seq"], + "success": True, + "message": None, + "command": "startDebugging", + "body": {}, + } + ) + else: + desc = 'unknown reverse request "%s"' % (request["command"]) + raise ValueError(desc) def _process_continued(self, all_threads_continued: bool): self.frame_scopes = {} if all_threads_continued: self.thread_stop_reasons = {} - def _update_verified_breakpoints(self, breakpoints: list[Event]): - for breakpoint in breakpoints: - if "id" in breakpoint: - self.resolved_breakpoints[str(breakpoint["id"])] = Breakpoint( - breakpoint - ) + def _update_verified_breakpoints(self, breakpoints: list[Breakpoint]): + for bp in breakpoints: + # If no id is set, we cannot correlate the given breakpoint across + # requests, ignore it. 
+ if "id" not in bp: + continue + + self.resolved_breakpoints[str(bp["id"])] = bp - def send_packet(self, command_dict: Request, set_sequence=True): - """Take the "command_dict" python dictionary and encode it as a JSON - string and send the contents as a packet to the VSCode debug - adapter""" - # Set the sequence ID for this command automatically - if set_sequence: - command_dict["seq"] = self.sequence + def send_packet(self, packet: ProtocolMessage) -> int: + """Takes a dictionary representation of a DAP request and send the request to the debug adapter. + + Returns the seq number of the request. + """ + # Set the seq for requests. + if packet["type"] == "request": + packet["seq"] = self.sequence self.sequence += 1 + else: + packet["seq"] = 0 + # Encode our command dictionary as a JSON string - json_str = json.dumps(command_dict, separators=(",", ":")) + json_str = json.dumps(packet, separators=(",", ":")) + if self.trace_file: self.trace_file.write("to adapter:\n%s\n" % (json_str)) + length = len(json_str) if length > 0: # Send the encoded JSON packet and flush the 'send' file self.send.write(self.encode_content(json_str)) self.send.flush() - def recv_packet( - self, - filter_type: Optional[str] = None, - filter_event: Optional[Union[str, list[str]]] = None, - timeout: Optional[float] = None, - ) -> Optional[ProtocolMessage]: - """Get a JSON packet from the VSCode debug adapter. This function - assumes a thread that reads packets is running and will deliver - any received packets by calling handle_recv_packet(...). This - function will wait for the packet to arrive and return it when - it does.""" - while True: - try: - self.recv_condition.acquire() - packet = None - while True: - for i, curr_packet in enumerate(self.recv_packets): - if not curr_packet: - raise EOFError - packet_type = curr_packet["type"] - if filter_type is None or packet_type in filter_type: - if filter_event is None or ( - packet_type == "event" - and curr_packet["event"] in filter_event - ): - packet = self.recv_packets.pop(i) - break - if packet: - break - # Sleep until packet is received - len_before = len(self.recv_packets) - self.recv_condition.wait(timeout) - len_after = len(self.recv_packets) - if len_before == len_after: - return None # Timed out - return packet - except EOFError: - return None - finally: - self.recv_condition.release() - - def send_recv(self, command): + return packet["seq"] + + def _send_recv(self, request: Request) -> Optional[Response]: """Send a command python dictionary as JSON and receive the JSON response. Validates that the response is the correct sequence and command in the reply. 
Any events that are received are added to the events list in this object""" - self.send_packet(command) - done = False - while not done: - response_or_request = self.recv_packet(filter_type=["response", "request"]) - if response_or_request is None: - desc = 'no response for "%s"' % (command["command"]) - raise ValueError(desc) - if response_or_request["type"] == "response": - self.validate_response(command, response_or_request) - return response_or_request - else: - self.reverse_requests.append(response_or_request) - if response_or_request["command"] == "runInTerminal": - subprocess.Popen( - response_or_request["arguments"].get("args"), - env=response_or_request["arguments"].get("env", {}), - ) - self.send_packet( - { - "type": "response", - "request_seq": response_or_request["seq"], - "success": True, - "command": "runInTerminal", - "body": {}, - }, - ) - elif response_or_request["command"] == "startDebugging": - self.send_packet( - { - "type": "response", - "request_seq": response_or_request["seq"], - "success": True, - "command": "startDebugging", - "body": {}, - }, - ) - else: - desc = 'unknown reverse request "%s"' % ( - response_or_request["command"] - ) - raise ValueError(desc) + seq = self.send_packet(request) + response = self.receive_response(seq) + if response is None: + raise ValueError(f"no response for {request!r}") + self.validate_response(request, response) + return response - return None + def receive_response(self, seq: int) -> Optional[Response]: + """Waits for a response with the associated request_sec.""" + + def predicate(p: ProtocolMessage): + return p["type"] == "response" and p["request_seq"] == seq + + return cast(Optional[Response], self._recv_packet(predicate=predicate)) def wait_for_event( - self, filter: Union[str, list[str]], timeout: Optional[float] = None + self, filter: List[str] = [], timeout: Optional[float] = None ) -> Optional[Event]: """Wait for the first event that matches the filter.""" - return self.recv_packet( - filter_type="event", filter_event=filter, timeout=timeout + + def predicate(p: ProtocolMessage): + return p["type"] == "event" and p["event"] in filter + + return cast( + Optional[Event], self._recv_packet(predicate=predicate, timeout=timeout) ) def wait_for_stopped( self, timeout: Optional[float] = None - ) -> Optional[list[Event]]: + ) -> Optional[List[Event]]: stopped_events = [] stopped_event = self.wait_for_event( filter=["stopped", "exited"], timeout=timeout @@ -491,7 +584,7 @@ class DebugCommunication(object): def wait_for_breakpoint_events(self, timeout: Optional[float] = None): breakpoint_events: list[Event] = [] while True: - event = self.wait_for_event("breakpoint", timeout=timeout) + event = self.wait_for_event(["breakpoint"], timeout=timeout) if not event: break breakpoint_events.append(event) @@ -502,7 +595,7 @@ class DebugCommunication(object): ): """Wait for all breakpoints to be verified. 
Return all unverified breakpoints.""" while any(id not in self.resolved_breakpoints for id in breakpoint_ids): - breakpoint_event = self.wait_for_event("breakpoint", timeout=timeout) + breakpoint_event = self.wait_for_event(["breakpoint"], timeout=timeout) if breakpoint_event is None: break @@ -511,18 +604,18 @@ class DebugCommunication(object): for id in breakpoint_ids if ( id not in self.resolved_breakpoints - or not self.resolved_breakpoints[id].is_verified() + or not Breakpoint.is_verified(self.resolved_breakpoints[id]) ) ] def wait_for_exited(self, timeout: Optional[float] = None): - event_dict = self.wait_for_event("exited", timeout=timeout) + event_dict = self.wait_for_event(["exited"], timeout=timeout) if event_dict is None: raise ValueError("didn't get exited event") return event_dict def wait_for_terminated(self, timeout: Optional[float] = None): - event_dict = self.wait_for_event("terminated", timeout) + event_dict = self.wait_for_event(["terminated"], timeout) if event_dict is None: raise ValueError("didn't get terminated event") return event_dict @@ -733,7 +826,7 @@ class DebugCommunication(object): if gdbRemoteHostname is not None: args_dict["gdb-remote-hostname"] = gdbRemoteHostname command_dict = {"command": "attach", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_breakpointLocations( self, file_path, line, end_line=None, column=None, end_column=None @@ -755,7 +848,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_configurationDone(self): command_dict = { @@ -763,7 +856,7 @@ class DebugCommunication(object): "type": "request", "arguments": {}, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if response: self.configuration_done_sent = True self.request_threads() @@ -792,7 +885,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if response["success"]: self._process_continued(response["body"]["allThreadsContinued"]) # Caller must still call wait_for_stopped. @@ -809,7 +902,7 @@ class DebugCommunication(object): if restartArguments: command_dict["arguments"] = restartArguments - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) # Caller must still call wait_for_stopped. 
return response @@ -825,7 +918,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_disassemble( self, @@ -845,7 +938,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict)["body"]["instructions"] + return self._send_recv(command_dict)["body"]["instructions"] def request_readMemory(self, memoryReference, offset, count): args_dict = { @@ -858,7 +951,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_writeMemory(self, memoryReference, data, offset=0, allowPartial=False): args_dict = { @@ -876,7 +969,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_evaluate(self, expression, frameIndex=0, threadId=None, context=None): stackFrame = self.get_stackFrame(frameIndex=frameIndex, threadId=threadId) @@ -892,7 +985,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_exceptionInfo(self, threadId=None): if threadId is None: @@ -903,7 +996,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_initialize(self, sourceInitFile=False): command_dict = { @@ -924,10 +1017,10 @@ class DebugCommunication(object): "$__lldb_sourceInitFile": sourceInitFile, }, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if response: if "body" in response: - self.capabilities = response["body"] + self.capabilities.update(response.get("body", {})) return response def request_launch( @@ -1007,14 +1100,14 @@ class DebugCommunication(object): if commandEscapePrefix is not None: args_dict["commandEscapePrefix"] = commandEscapePrefix command_dict = {"command": "launch", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_next(self, threadId, granularity="statement"): if self.exit_status is not None: raise ValueError("request_continue called after process exited") args_dict = {"threadId": threadId, "granularity": granularity} command_dict = {"command": "next", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_stepIn(self, threadId, targetId, granularity="statement"): if self.exit_status is not None: @@ -1027,7 +1120,7 @@ class DebugCommunication(object): "granularity": granularity, } command_dict = {"command": "stepIn", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_stepInTargets(self, frameId): if self.exit_status is not None: @@ -1039,14 +1132,14 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_stepOut(self, threadId): if self.exit_status is not None: raise ValueError("request_stepOut called after process exited") args_dict = {"threadId": threadId} command_dict = {"command": "stepOut", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return 
self._send_recv(command_dict) def request_pause(self, threadId=None): if self.exit_status is not None: @@ -1055,12 +1148,12 @@ class DebugCommunication(object): threadId = self.get_thread_id() args_dict = {"threadId": threadId} command_dict = {"command": "pause", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_scopes(self, frameId): args_dict = {"frameId": frameId} command_dict = {"command": "scopes", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_setBreakpoints(self, source: Source, line_array, data=None): """data is array of parameters for breakpoints in line_array. @@ -1068,7 +1161,7 @@ class DebugCommunication(object): It contains optional location/hitCondition/logMessage parameters. """ args_dict = { - "source": source.as_dict(), + "source": source, "sourceModified": False, } if line_array is not None: @@ -1096,7 +1189,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if response["success"]: self._update_verified_breakpoints(response["body"]["breakpoints"]) return response @@ -1112,7 +1205,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_setFunctionBreakpoints(self, names, condition=None, hitCondition=None): breakpoints = [] @@ -1129,7 +1222,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if response["success"]: self._update_verified_breakpoints(response["body"]["breakpoints"]) return response @@ -1150,7 +1243,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_setDataBreakpoint(self, dataBreakpoints): """dataBreakpoints is a list of dictionary with following fields: @@ -1167,7 +1260,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_compileUnits(self, moduleId): args_dict = {"moduleId": moduleId} @@ -1176,7 +1269,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) return response def request_completions(self, text, frameId=None): @@ -1188,15 +1281,22 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) - - def request_modules(self, startModule: int, moduleCount: int): - return self.send_recv( - { - "command": "modules", - "type": "request", - "arguments": {"startModule": startModule, "moduleCount": moduleCount}, - } + return self._send_recv(command_dict) + + def request_modules( + self, + start_module: Optional[int] = None, + module_count: Optional[int] = None, + ): + args_dict = {} + + if start_module is not None: + args_dict["startModule"] = start_module + if module_count is not None: + args_dict["moduleCount"] = module_count + + return self._send_recv( + {"command": "modules", "type": "request", "arguments": args_dict} ) def request_moduleSymbols( @@ -1216,7 +1316,7 @@ class DebugCommunication(object): "count": count, }, } - return 
self.send_recv(command_dict) + return self._send_recv(command_dict) def request_stackTrace( self, threadId=None, startFrame=None, levels=None, format=None, dump=False @@ -1235,7 +1335,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if dump: for idx, frame in enumerate(response["body"]["stackFrames"]): name = frame["name"] @@ -1250,18 +1350,30 @@ class DebugCommunication(object): print("[%3u] %s" % (idx, name)) return response - def request_source(self, sourceReference): + def request_source( + self, *, source: Optional[Source] = None, sourceReference: Optional[int] = None + ): """Request a source from a 'Source' reference.""" + if source is None and sourceReference is None: + raise ValueError("request_source requires either source or sourceReference") + elif source is not None: + sourceReference = source["sourceReference"] + elif sourceReference is not None: + source = {"sourceReference": sourceReference} + else: + raise ValueError( + "request_source requires either source or sourceReference not both" + ) command_dict = { "command": "source", "type": "request", "arguments": { - "source": {"sourceReference": sourceReference}, + "source": source, # legacy version of the request "sourceReference": sourceReference, }, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_threads(self): """Request a list of all threads and combine any information from any @@ -1269,7 +1381,7 @@ class DebugCommunication(object): thread actually stopped. Returns an array of thread dictionaries with information about all threads""" command_dict = {"command": "threads", "type": "request", "arguments": {}} - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if not response["success"]: self.threads = None return response @@ -1309,7 +1421,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_setVariable(self, containingVarRef, name, value, id=None): args_dict = { @@ -1324,7 +1436,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_locations(self, locationReference): args_dict = { @@ -1335,7 +1447,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_testGetTargetBreakpoints(self): """A request packet used in the LLDB test suite to get all currently @@ -1347,12 +1459,12 @@ class DebugCommunication(object): "type": "request", "arguments": {}, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def terminate(self): self.send.close() - if self.recv_thread.is_alive(): - self.recv_thread.join() + if self._recv_thread.is_alive(): + self._recv_thread.join() def request_setInstructionBreakpoints(self, memory_reference=[]): breakpoints = [] @@ -1367,7 +1479,7 @@ class DebugCommunication(object): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) class DebugAdapterServer(DebugCommunication): diff --git a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py index c51b4b1..c23b2e7 100644 --- 
a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py +++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py @@ -1,6 +1,6 @@ import os import time -from typing import Optional +from typing import Optional, Callable, Any, List, Union import uuid import dap_server @@ -67,7 +67,10 @@ class DAPTestCaseBase(TestBase): self, source_reference, lines, data=None, wait_for_resolve=True ): return self.set_source_breakpoints_from_source( - Source(source_reference=source_reference), lines, data, wait_for_resolve + Source.build(source_reference=source_reference), + lines, + data, + wait_for_resolve, ) def set_source_breakpoints_from_source( @@ -120,11 +123,19 @@ class DAPTestCaseBase(TestBase): f"Expected to resolve all breakpoints. Unresolved breakpoint ids: {unresolved_breakpoints}", ) - def waitUntil(self, condition_callback): - for _ in range(20): - if condition_callback(): + def wait_until( + self, + predicate: Callable[[], bool], + delay: float = 0.5, + timeout: float = DEFAULT_TIMEOUT, + ) -> bool: + """Repeatedly run the predicate until either the predicate returns True + or a timeout has occurred.""" + deadline = time.monotonic() + timeout + while deadline > time.monotonic(): + if predicate(): return True - time.sleep(0.5) + time.sleep(delay) return False def assertCapabilityIsSet(self, key: str, msg: Optional[str] = None) -> None: @@ -137,13 +148,16 @@ class DAPTestCaseBase(TestBase): if key in self.dap_server.capabilities: self.assertEqual(self.dap_server.capabilities[key], False, msg) - def verify_breakpoint_hit(self, breakpoint_ids, timeout=DEFAULT_TIMEOUT): + def verify_breakpoint_hit( + self, breakpoint_ids: List[Union[int, str]], timeout: float = DEFAULT_TIMEOUT + ): """Wait for the process we are debugging to stop, and verify we hit any breakpoint location in the "breakpoint_ids" array. "breakpoint_ids" should be a list of breakpoint ID strings (["1", "2"]). The return value from self.set_source_breakpoints() or self.set_function_breakpoints() can be passed to this function""" stopped_events = self.dap_server.wait_for_stopped(timeout) + normalized_bp_ids = [str(b) for b in breakpoint_ids] for stopped_event in stopped_events: if "body" in stopped_event: body = stopped_event["body"] @@ -154,22 +168,16 @@ class DAPTestCaseBase(TestBase): and body["reason"] != "instruction breakpoint" ): continue - if "description" not in body: + if "hitBreakpointIds" not in body: continue - # Descriptions for breakpoints will be in the form - # "breakpoint 1.1", so look for any description that matches - # ("breakpoint 1.") in the description field as verification - # that one of the breakpoint locations was hit. DAP doesn't - # allow breakpoints to have multiple locations, but LLDB does. - # So when looking at the description we just want to make sure - # the right breakpoint matches and not worry about the actual - # location. - description = body["description"] - for breakpoint_id in breakpoint_ids: - match_desc = f"breakpoint {breakpoint_id}." 
- if match_desc in description: + hit_breakpoint_ids = body["hitBreakpointIds"] + for bp in hit_breakpoint_ids: + if str(bp) in normalized_bp_ids: return - self.assertTrue(False, f"breakpoint not hit, stopped_events={stopped_events}") + self.assertTrue( + False, + f"breakpoint not hit, wanted breakpoint_ids {breakpoint_ids} in stopped_events {stopped_events}", + ) def verify_all_breakpoints_hit(self, breakpoint_ids, timeout=DEFAULT_TIMEOUT): """Wait for the process we are debugging to stop, and verify we hit @@ -213,7 +221,7 @@ class DAPTestCaseBase(TestBase): return True return False - def verify_commands(self, flavor, output, commands): + def verify_commands(self, flavor: str, output: str, commands: list[str]): self.assertTrue(output and len(output) > 0, "expect console output") lines = output.splitlines() prefix = "(lldb) " @@ -226,10 +234,11 @@ class DAPTestCaseBase(TestBase): found = True break self.assertTrue( - found, "verify '%s' found in console output for '%s'" % (cmd, flavor) + found, + f"Command '{flavor}' - '{cmd}' not found in output: {output}", ) - def get_dict_value(self, d, key_path): + def get_dict_value(self, d: dict, key_path: list[str]) -> Any: """Verify each key in the key_path array is in contained in each dictionary within "d". Assert if any key isn't in the corresponding dictionary. This is handy for grabbing values from VS @@ -298,28 +307,34 @@ class DAPTestCaseBase(TestBase): return (source["path"], stackFrame["line"]) return ("", 0) - def get_stdout(self, timeout=0.0): - return self.dap_server.get_output("stdout", timeout=timeout) + def get_stdout(self): + return self.dap_server.get_output("stdout") - def get_console(self, timeout=0.0): - return self.dap_server.get_output("console", timeout=timeout) + def get_console(self): + return self.dap_server.get_output("console") - def get_important(self, timeout=0.0): - return self.dap_server.get_output("important", timeout=timeout) + def get_important(self): + return self.dap_server.get_output("important") - def collect_stdout(self, timeout_secs, pattern=None): + def collect_stdout( + self, timeout: float = DEFAULT_TIMEOUT, pattern: Optional[str] = None + ) -> str: return self.dap_server.collect_output( - "stdout", timeout_secs=timeout_secs, pattern=pattern + "stdout", timeout=timeout, pattern=pattern ) - def collect_console(self, timeout_secs, pattern=None): + def collect_console( + self, timeout: float = DEFAULT_TIMEOUT, pattern: Optional[str] = None + ) -> str: return self.dap_server.collect_output( - "console", timeout_secs=timeout_secs, pattern=pattern + "console", timeout=timeout, pattern=pattern ) - def collect_important(self, timeout_secs, pattern=None): + def collect_important( + self, timeout: float = DEFAULT_TIMEOUT, pattern: Optional[str] = None + ) -> str: return self.dap_server.collect_output( - "important", timeout_secs=timeout_secs, pattern=pattern + "important", timeout=timeout, pattern=pattern ) def get_local_as_int(self, name, threadId=None): |
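
The receive path after this refactor has a single reader thread enqueue packets while consumers block on a `threading.Condition` with a predicate (see `_recv_packet` above). The following is a minimal, self-contained sketch of that pattern, not the adapter's own code; the helper names and the packet contents are invented for illustration, and the pending/processed bookkeeping of the real class is omitted.

```python
import threading
from typing import Callable, List, Optional

Packet = dict

_packets: List[Optional[Packet]] = []
_cond = threading.Condition()


def enqueue(packet: Optional[Packet]) -> None:
    # Called from the reader thread; None signals EOF, mirroring
    # what _handle_recv_packet does in the patch.
    with _cond:
        _packets.append(packet)
        _cond.notify()


def wait_for_packet(predicate: Callable[[Packet], bool],
                    timeout: Optional[float] = None) -> Optional[Packet]:
    # Return the first queued packet matching the predicate, or None on
    # timeout or EOF (compare _recv_packet above).
    def match():
        for i, packet in enumerate(_packets):
            if packet is None:
                return EOFError()  # truthy sentinel: stop waiting on EOF
            if predicate(packet):
                return _packets.pop(i)
        return None

    with _cond:
        found = _cond.wait_for(match, timeout)
    return found if isinstance(found, dict) else None


# Invented packet contents, just to exercise the helpers.
enqueue({"type": "event", "event": "stopped", "body": {"threadId": 1}})
evt = wait_for_packet(lambda p: p.get("event") == "stopped", timeout=1.0)
assert evt is not None and evt["body"]["threadId"] == 1
```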