diff options
Diffstat (limited to 'lldb/packages/Python')
7 files changed, 105 insertions, 114 deletions
diff --git a/lldb/packages/Python/lldbsuite/support/seven.py b/lldb/packages/Python/lldbsuite/support/seven.py index 1b96658..8e621ba 100644 --- a/lldb/packages/Python/lldbsuite/support/seven.py +++ b/lldb/packages/Python/lldbsuite/support/seven.py @@ -1,5 +1,4 @@ import binascii -import shlex import subprocess @@ -38,8 +37,3 @@ def unhexlify(hexstr): def hexlify(data): """Hex-encode string data. The result if always a string.""" return bitcast_to_string(binascii.hexlify(bitcast_to_bytes(data))) - - -# TODO: Replace this with `shlex.join` when minimum Python version is >= 3.8 -def join_for_shell(split_command): - return " ".join([shlex.quote(part) for part in split_command]) diff --git a/lldb/packages/Python/lldbsuite/test/gdbclientutils.py b/lldb/packages/Python/lldbsuite/test/gdbclientutils.py index 53e991a..1a2860a 100644 --- a/lldb/packages/Python/lldbsuite/test/gdbclientutils.py +++ b/lldb/packages/Python/lldbsuite/test/gdbclientutils.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod import ctypes import errno import io @@ -5,6 +6,7 @@ import threading import socket import traceback from lldbsuite.support import seven +from typing import Optional, List, Tuple def checksum(message): @@ -86,7 +88,7 @@ class MockGDBServerResponder: handles any packet not recognized in the common packet handling code. """ - registerCount = 40 + registerCount: int = 40 class RESPONSE_DISCONNECT: pass @@ -95,7 +97,7 @@ class MockGDBServerResponder: pass def __init__(self): - self.packetLog = [] + self.packetLog: List[str] = [] def respond(self, packet): """ @@ -241,7 +243,7 @@ class MockGDBServerResponder: def qHostInfo(self): return "ptrsize:8;endian:little;" - def qEcho(self): + def qEcho(self, num: int): return "E04" def qQueryGDBServer(self): @@ -262,10 +264,10 @@ class MockGDBServerResponder: def D(self, packet): return "OK" - def readRegisters(self): + def readRegisters(self) -> str: return "00000000" * self.registerCount - def readRegister(self, register): + def readRegister(self, register: int) -> str: return "00000000" def writeRegisters(self, registers_hex): @@ -305,7 +307,9 @@ class MockGDBServerResponder: # SIGINT is 2, return type is 2 digit hex string return "S02" - def qXferRead(self, obj, annex, offset, length): + def qXferRead( + self, obj: str, annex: str, offset: int, length: int + ) -> Tuple[Optional[str], bool]: return None, False def _qXferResponse(self, data, has_more): @@ -373,15 +377,17 @@ class MockGDBServerResponder: pass -class ServerChannel: +class ServerChannel(ABC): """ A wrapper class for TCP or pty-based server. """ - def get_connect_address(self): + @abstractmethod + def get_connect_address(self) -> str: """Get address for the client to connect to.""" - def get_connect_url(self): + @abstractmethod + def get_connect_url(self) -> str: """Get URL suitable for process connect command.""" def close_server(self): @@ -393,10 +399,12 @@ class ServerChannel: def close_connection(self): """Close all resources used by the accepted connection.""" - def recv(self): + @abstractmethod + def recv(self) -> bytes: """Receive a data packet from the connected client.""" - def sendall(self, data): + @abstractmethod + def sendall(self, data: bytes) -> None: """Send the data to the connected client.""" @@ -427,11 +435,11 @@ class ServerSocket(ServerChannel): self._connection.close() self._connection = None - def recv(self): + def recv(self) -> bytes: assert self._connection is not None return self._connection.recv(4096) - def sendall(self, data): + def sendall(self, data: bytes) -> None: assert self._connection is not None return self._connection.sendall(data) @@ -443,10 +451,10 @@ class TCPServerSocket(ServerSocket): )[0] super().__init__(family, type, proto, addr) - def get_connect_address(self): + def get_connect_address(self) -> str: return "[{}]:{}".format(*self._server_socket.getsockname()) - def get_connect_url(self): + def get_connect_url(self) -> str: return "connect://" + self.get_connect_address() @@ -454,10 +462,10 @@ class UnixServerSocket(ServerSocket): def __init__(self, addr): super().__init__(socket.AF_UNIX, socket.SOCK_STREAM, 0, addr) - def get_connect_address(self): + def get_connect_address(self) -> str: return self._server_socket.getsockname() - def get_connect_url(self): + def get_connect_url(self) -> str: return "unix-connect://" + self.get_connect_address() @@ -471,7 +479,7 @@ class PtyServerSocket(ServerChannel): self._primary = io.FileIO(primary, "r+b") self._secondary = io.FileIO(secondary, "r+b") - def get_connect_address(self): + def get_connect_address(self) -> str: libc = ctypes.CDLL(None) libc.ptsname.argtypes = (ctypes.c_int,) libc.ptsname.restype = ctypes.c_char_p @@ -484,7 +492,7 @@ class PtyServerSocket(ServerChannel): self._secondary.close() self._primary.close() - def recv(self): + def recv(self) -> bytes: try: return self._primary.read(4096) except OSError as e: @@ -493,8 +501,8 @@ class PtyServerSocket(ServerChannel): return b"" raise - def sendall(self, data): - return self._primary.write(data) + def sendall(self, data: bytes) -> None: + self._primary.write(data) class MockGDBServer: @@ -527,18 +535,21 @@ class MockGDBServer: self._thread.join() self._thread = None - def get_connect_address(self): + def get_connect_address(self) -> str: + assert self._socket is not None return self._socket.get_connect_address() - def get_connect_url(self): + def get_connect_url(self) -> str: + assert self._socket is not None return self._socket.get_connect_url() def run(self): + assert self._socket is not None # For testing purposes, we only need to worry about one client # connecting just one time. try: self._socket.accept() - except: + except Exception: traceback.print_exc() return self._shouldSendAck = True @@ -553,7 +564,7 @@ class MockGDBServer: self._receive(data) except self.TerminateConnectionException: pass - except Exception as e: + except Exception: print( "An exception happened when receiving the response from the gdb server. Closing the client..." ) @@ -586,7 +597,9 @@ class MockGDBServer: Once a complete packet is found at the front of self._receivedData, its data is removed form self._receivedData. """ + assert self._receivedData is not None data = self._receivedData + assert self._receivedDataOffset is not None i = self._receivedDataOffset data_len = len(data) if data_len == 0: @@ -639,10 +652,13 @@ class MockGDBServer: self._receivedDataOffset = 0 return packet - def _sendPacket(self, packet): - self._socket.sendall(seven.bitcast_to_bytes(frame_packet(packet))) + def _sendPacket(self, packet: str): + assert self._socket is not None + framed_packet = seven.bitcast_to_bytes(frame_packet(packet)) + self._socket.sendall(framed_packet) def _handlePacket(self, packet): + assert self._socket is not None if packet is self.PACKET_ACK: # Ignore ACKs from the client. For the future, we can consider # adding validation code to make sure the client only sends ACKs diff --git a/lldb/packages/Python/lldbsuite/test/lldbtest.py b/lldb/packages/Python/lldbsuite/test/lldbtest.py index 8074922..b92de94 100644 --- a/lldb/packages/Python/lldbsuite/test/lldbtest.py +++ b/lldb/packages/Python/lldbsuite/test/lldbtest.py @@ -36,6 +36,7 @@ import json import os.path import re import shutil +import shlex import signal from subprocess import * import sys @@ -56,7 +57,6 @@ from . import lldbutil from . import test_categories from lldbsuite.support import encoded_file from lldbsuite.support import funcutils -from lldbsuite.support import seven from lldbsuite.test_event import build_exception # See also dotest.parseOptionsAndInitTestdirs(), where the environment variables @@ -1508,7 +1508,7 @@ class Base(unittest.TestCase): self.runBuildCommand(command) def runBuildCommand(self, command): - self.trace(seven.join_for_shell(command)) + self.trace(shlex.join(command)) try: output = check_output(command, stderr=STDOUT, errors="replace") except CalledProcessError as cpe: diff --git a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py index 8eb64b4..a3d924d 100644 --- a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py +++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py @@ -27,6 +27,10 @@ from typing import ( Literal, ) +# set timeout based on whether ASAN was enabled or not. Increase +# timeout by a factor of 10 if ASAN is enabled. +DEFAULT_TIMEOUT = 10 * (10 if ("ASAN_OPTIONS" in os.environ) else 1) + ## DAP type references @@ -282,26 +286,24 @@ class DebugCommunication(object): def collect_output( self, category: str, - timeout: float, pattern: Optional[str] = None, clear=True, ) -> str: """Collect output from 'output' events. Args: category: The category to collect. - timeout: The max duration for collecting output. pattern: Optional, if set, return once this pattern is detected in the collected output. Returns: The collected output. """ - deadline = time.monotonic() + timeout + deadline = time.monotonic() + DEFAULT_TIMEOUT output = self.get_output(category, clear) while deadline >= time.monotonic() and ( pattern is None or pattern not in output ): - event = self.wait_for_event(["output"], timeout=deadline - time.monotonic()) + event = self.wait_for_event(["output"]) if not event: # Timeout or EOF break output += self.get_output(category, clear=clear) @@ -339,7 +341,7 @@ class DebugCommunication(object): self, *, predicate: Optional[Callable[[ProtocolMessage], bool]] = None, - timeout: Optional[float] = None, + timeout: Optional[float] = DEFAULT_TIMEOUT, ) -> Optional[ProtocolMessage]: """Processes received packets from the adapter. Updates the DebugCommunication stateful properties based on the received @@ -555,25 +557,20 @@ class DebugCommunication(object): return cast(Optional[Response], self._recv_packet(predicate=predicate)) - def wait_for_event( - self, filter: List[str] = [], timeout: Optional[float] = None - ) -> Optional[Event]: + def wait_for_event(self, filter: List[str] = []) -> Optional[Event]: """Wait for the first event that matches the filter.""" def predicate(p: ProtocolMessage): return p["type"] == "event" and p["event"] in filter return cast( - Optional[Event], self._recv_packet(predicate=predicate, timeout=timeout) + Optional[Event], + self._recv_packet(predicate=predicate), ) - def wait_for_stopped( - self, timeout: Optional[float] = None - ) -> Optional[List[Event]]: + def wait_for_stopped(self) -> Optional[List[Event]]: stopped_events = [] - stopped_event = self.wait_for_event( - filter=["stopped", "exited"], timeout=timeout - ) + stopped_event = self.wait_for_event(filter=["stopped", "exited"]) while stopped_event: stopped_events.append(stopped_event) # If we exited, then we are done @@ -582,26 +579,28 @@ class DebugCommunication(object): # Otherwise we stopped and there might be one or more 'stopped' # events for each thread that stopped with a reason, so keep # checking for more 'stopped' events and return all of them - stopped_event = self.wait_for_event( - filter=["stopped", "exited"], timeout=0.25 + # Use a shorter timeout for additional stopped events + def predicate(p: ProtocolMessage): + return p["type"] == "event" and p["event"] in ["stopped", "exited"] + + stopped_event = cast( + Optional[Event], self._recv_packet(predicate=predicate, timeout=0.25) ) return stopped_events - def wait_for_breakpoint_events(self, timeout: Optional[float] = None): + def wait_for_breakpoint_events(self): breakpoint_events: list[Event] = [] while True: - event = self.wait_for_event(["breakpoint"], timeout=timeout) + event = self.wait_for_event(["breakpoint"]) if not event: break breakpoint_events.append(event) return breakpoint_events - def wait_for_breakpoints_to_be_verified( - self, breakpoint_ids: list[str], timeout: Optional[float] = None - ): + def wait_for_breakpoints_to_be_verified(self, breakpoint_ids: list[str]): """Wait for all breakpoints to be verified. Return all unverified breakpoints.""" while any(id not in self.resolved_breakpoints for id in breakpoint_ids): - breakpoint_event = self.wait_for_event(["breakpoint"], timeout=timeout) + breakpoint_event = self.wait_for_event(["breakpoint"]) if breakpoint_event is None: break @@ -614,14 +613,14 @@ class DebugCommunication(object): ) ] - def wait_for_exited(self, timeout: Optional[float] = None): - event_dict = self.wait_for_event(["exited"], timeout=timeout) + def wait_for_exited(self): + event_dict = self.wait_for_event(["exited"]) if event_dict is None: raise ValueError("didn't get exited event") return event_dict - def wait_for_terminated(self, timeout: Optional[float] = None): - event_dict = self.wait_for_event(["terminated"], timeout) + def wait_for_terminated(self): + event_dict = self.wait_for_event(["terminated"]) if event_dict is None: raise ValueError("didn't get terminated event") return event_dict @@ -1610,7 +1609,7 @@ class DebugAdapterServer(DebugCommunication): # new messages will arrive and it should shutdown on its # own. process.stdin.close() - process.wait(timeout=20) + process.wait(timeout=DEFAULT_TIMEOUT) except subprocess.TimeoutExpired: process.kill() process.wait() diff --git a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py index f7b1ed8..29935bb 100644 --- a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py +++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py @@ -18,7 +18,7 @@ import base64 class DAPTestCaseBase(TestBase): # set timeout based on whether ASAN was enabled or not. Increase # timeout by a factor of 10 if ASAN is enabled. - DEFAULT_TIMEOUT = 10 * (10 if ("ASAN_OPTIONS" in os.environ) else 1) + DEFAULT_TIMEOUT = dap_server.DEFAULT_TIMEOUT NO_DEBUG_INFO_TESTCASE = True def create_debug_adapter( @@ -118,11 +118,9 @@ class DAPTestCaseBase(TestBase): self.wait_for_breakpoints_to_resolve(breakpoint_ids) return breakpoint_ids - def wait_for_breakpoints_to_resolve( - self, breakpoint_ids: list[str], timeout: Optional[float] = DEFAULT_TIMEOUT - ): + def wait_for_breakpoints_to_resolve(self, breakpoint_ids: list[str]): unresolved_breakpoints = self.dap_server.wait_for_breakpoints_to_be_verified( - breakpoint_ids, timeout + breakpoint_ids ) self.assertEqual( len(unresolved_breakpoints), @@ -134,11 +132,10 @@ class DAPTestCaseBase(TestBase): self, predicate: Callable[[], bool], delay: float = 0.5, - timeout: float = DEFAULT_TIMEOUT, ) -> bool: """Repeatedly run the predicate until either the predicate returns True or a timeout has occurred.""" - deadline = time.monotonic() + timeout + deadline = time.monotonic() + self.DEFAULT_TIMEOUT while deadline > time.monotonic(): if predicate(): return True @@ -155,15 +152,13 @@ class DAPTestCaseBase(TestBase): if key in self.dap_server.capabilities: self.assertEqual(self.dap_server.capabilities[key], False, msg) - def verify_breakpoint_hit( - self, breakpoint_ids: List[Union[int, str]], timeout: float = DEFAULT_TIMEOUT - ): + def verify_breakpoint_hit(self, breakpoint_ids: List[Union[int, str]]): """Wait for the process we are debugging to stop, and verify we hit any breakpoint location in the "breakpoint_ids" array. "breakpoint_ids" should be a list of breakpoint ID strings (["1", "2"]). The return value from self.set_source_breakpoints() or self.set_function_breakpoints() can be passed to this function""" - stopped_events = self.dap_server.wait_for_stopped(timeout) + stopped_events = self.dap_server.wait_for_stopped() normalized_bp_ids = [str(b) for b in breakpoint_ids] for stopped_event in stopped_events: if "body" in stopped_event: @@ -186,11 +181,11 @@ class DAPTestCaseBase(TestBase): f"breakpoint not hit, wanted breakpoint_ids {breakpoint_ids} in stopped_events {stopped_events}", ) - def verify_all_breakpoints_hit(self, breakpoint_ids, timeout=DEFAULT_TIMEOUT): + def verify_all_breakpoints_hit(self, breakpoint_ids): """Wait for the process we are debugging to stop, and verify we hit all of the breakpoint locations in the "breakpoint_ids" array. "breakpoint_ids" should be a list of int breakpoint IDs ([1, 2]).""" - stopped_events = self.dap_server.wait_for_stopped(timeout) + stopped_events = self.dap_server.wait_for_stopped() for stopped_event in stopped_events: if "body" in stopped_event: body = stopped_event["body"] @@ -208,12 +203,12 @@ class DAPTestCaseBase(TestBase): return self.assertTrue(False, f"breakpoints not hit, stopped_events={stopped_events}") - def verify_stop_exception_info(self, expected_description, timeout=DEFAULT_TIMEOUT): + def verify_stop_exception_info(self, expected_description): """Wait for the process we are debugging to stop, and verify the stop reason is 'exception' and that the description matches 'expected_description' """ - stopped_events = self.dap_server.wait_for_stopped(timeout) + stopped_events = self.dap_server.wait_for_stopped() for stopped_event in stopped_events: if "body" in stopped_event: body = stopped_event["body"] @@ -338,26 +333,14 @@ class DAPTestCaseBase(TestBase): def get_important(self): return self.dap_server.get_output("important") - def collect_stdout( - self, timeout: float = DEFAULT_TIMEOUT, pattern: Optional[str] = None - ) -> str: - return self.dap_server.collect_output( - "stdout", timeout=timeout, pattern=pattern - ) + def collect_stdout(self, pattern: Optional[str] = None) -> str: + return self.dap_server.collect_output("stdout", pattern=pattern) - def collect_console( - self, timeout: float = DEFAULT_TIMEOUT, pattern: Optional[str] = None - ) -> str: - return self.dap_server.collect_output( - "console", timeout=timeout, pattern=pattern - ) + def collect_console(self, pattern: Optional[str] = None) -> str: + return self.dap_server.collect_output("console", pattern=pattern) - def collect_important( - self, timeout: float = DEFAULT_TIMEOUT, pattern: Optional[str] = None - ) -> str: - return self.dap_server.collect_output( - "important", timeout=timeout, pattern=pattern - ) + def collect_important(self, pattern: Optional[str] = None) -> str: + return self.dap_server.collect_output("important", pattern=pattern) def get_local_as_int(self, name, threadId=None): value = self.dap_server.get_local_variable_value(name, threadId=threadId) @@ -393,14 +376,13 @@ class DAPTestCaseBase(TestBase): targetId=None, waitForStop=True, granularity="statement", - timeout=DEFAULT_TIMEOUT, ): response = self.dap_server.request_stepIn( threadId=threadId, targetId=targetId, granularity=granularity ) self.assertTrue(response["success"]) if waitForStop: - return self.dap_server.wait_for_stopped(timeout) + return self.dap_server.wait_for_stopped() return None def stepOver( @@ -408,7 +390,6 @@ class DAPTestCaseBase(TestBase): threadId=None, waitForStop=True, granularity="statement", - timeout=DEFAULT_TIMEOUT, ): response = self.dap_server.request_next( threadId=threadId, granularity=granularity @@ -417,40 +398,40 @@ class DAPTestCaseBase(TestBase): response["success"], f"next request failed: response {response}" ) if waitForStop: - return self.dap_server.wait_for_stopped(timeout) + return self.dap_server.wait_for_stopped() return None - def stepOut(self, threadId=None, waitForStop=True, timeout=DEFAULT_TIMEOUT): + def stepOut(self, threadId=None, waitForStop=True): self.dap_server.request_stepOut(threadId=threadId) if waitForStop: - return self.dap_server.wait_for_stopped(timeout) + return self.dap_server.wait_for_stopped() return None def do_continue(self): # `continue` is a keyword. resp = self.dap_server.request_continue() self.assertTrue(resp["success"], f"continue request failed: {resp}") - def continue_to_next_stop(self, timeout=DEFAULT_TIMEOUT): + def continue_to_next_stop(self): self.do_continue() - return self.dap_server.wait_for_stopped(timeout) + return self.dap_server.wait_for_stopped() - def continue_to_breakpoint(self, breakpoint_id: str, timeout=DEFAULT_TIMEOUT): - self.continue_to_breakpoints((breakpoint_id), timeout) + def continue_to_breakpoint(self, breakpoint_id: str): + self.continue_to_breakpoints((breakpoint_id)) - def continue_to_breakpoints(self, breakpoint_ids, timeout=DEFAULT_TIMEOUT): + def continue_to_breakpoints(self, breakpoint_ids): self.do_continue() - self.verify_breakpoint_hit(breakpoint_ids, timeout) + self.verify_breakpoint_hit(breakpoint_ids) - def continue_to_exception_breakpoint(self, filter_label, timeout=DEFAULT_TIMEOUT): + def continue_to_exception_breakpoint(self, filter_label): self.do_continue() self.assertTrue( - self.verify_stop_exception_info(filter_label, timeout), + self.verify_stop_exception_info(filter_label), 'verify we got "%s"' % (filter_label), ) - def continue_to_exit(self, exitCode=0, timeout=DEFAULT_TIMEOUT): + def continue_to_exit(self, exitCode=0): self.do_continue() - stopped_events = self.dap_server.wait_for_stopped(timeout) + stopped_events = self.dap_server.wait_for_stopped() self.assertEqual( len(stopped_events), 1, "stopped_events = {}".format(stopped_events) ) diff --git a/lldb/packages/Python/lldbsuite/test/tools/lldb-server/gdbremote_testcase.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-server/gdbremote_testcase.py index aea6b9f..5ba642b 100644 --- a/lldb/packages/Python/lldbsuite/test/tools/lldb-server/gdbremote_testcase.py +++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-server/gdbremote_testcase.py @@ -931,6 +931,7 @@ class GdbRemoteTestCaseBase(Base, metaclass=GdbRemoteTestCaseFactory): "QNonStop", "SupportedWatchpointTypes", "SupportedCompressions", + "MultiMemRead", ] def parse_qSupported_response(self, context): diff --git a/lldb/packages/Python/lldbsuite/test_event/build_exception.py b/lldb/packages/Python/lldbsuite/test_event/build_exception.py index 931c15d..c3ae2cd 100644 --- a/lldb/packages/Python/lldbsuite/test_event/build_exception.py +++ b/lldb/packages/Python/lldbsuite/test_event/build_exception.py @@ -1,10 +1,10 @@ -from lldbsuite.support import seven +import shlex class BuildError(Exception): def __init__(self, called_process_error): super(BuildError, self).__init__("Error when building test subject") - self.command = seven.join_for_shell(called_process_error.cmd) + self.command = shlex.join(called_process_error.cmd) self.build_error = called_process_error.output def __str__(self): |