#!/usr/bin/env python3 """A script to automate the creation and landing of a stack of Pull Requests.""" import argparse import json import os import re import subprocess import sys import time import urllib.error import urllib.request # TODO: Remove typing workarounds when we use a newer python. from typing import List, Optional, Tuple from http.client import HTTPResponse from dataclasses import dataclass # --- Constants --- # BASE_BRANCH = "main" GITHUB_REMOTE_NAME = "origin" UPSTREAM_REMOTE_NAME = "upstream" DEFAULT_LABEL = "skip-precommit-approval" LLVM_GITHUB_TOKEN_VAR = "LLVM_GITHUB_TOKEN" LLVM_REPO = "llvm/llvm-project" GITHUB_API = "https://api.github.com" MERGE_MAX_RETRIES = 10 MERGE_RETRY_DELAY = 5 # seconds REQUEST_TIMEOUT = 30 # seconds class LlvmPrError(Exception): """Custom exception for errors in the PR automator script.""" @dataclass class PRAutomatorConfig: """Configuration Data.""" user_login: str token: str base_branch: str upstream_remote: str prefix: str draft: bool labels: List[str] no_merge: bool auto_merge: bool class CommandRunner: """Handles command execution and output. Supports dry runs and verbosity level.""" def __init__( self, dry_run: bool = False, verbose: bool = False, quiet: bool = False ): self.dry_run = dry_run self.verbose = verbose self.quiet = quiet def print(self, message: str, file=sys.stdout) -> None: if self.quiet and file == sys.stdout: return print(message, file=file) def verbose_print(self, message: str, file=sys.stdout) -> None: if self.verbose: print(message, file) def run_command( self, command: List[str], check: bool = True, capture_output: bool = False, text: bool = False, stdin_input: Optional[str] = None, read_only: bool = False, env: Optional[dict] = None, ) -> subprocess.CompletedProcess: if self.dry_run and not read_only: self.print(f"[Dry Run] Would run: {' '.join(command)}") return subprocess.CompletedProcess(command, 0, "", "") self.verbose_print(f"Running: {' '.join(command)}") try: return subprocess.run( command, check=check, capture_output=capture_output, text=text, input=stdin_input, env=env, ) except FileNotFoundError as e: raise LlvmPrError( f"Command '{command[0]}' not found. Is it installed and in your PATH?" ) from e except subprocess.CalledProcessError as e: self.print(f"Error running command: {' '.join(command)}", file=sys.stderr) if e.stdout: self.print(f"--- stdout ---\n{e.stdout}", file=sys.stderr) if e.stderr: self.print(f"--- stderr ---\n{e.stderr}", file=sys.stderr) raise e class GitHubAPI: """A wrapper for the GitHub API.""" def __init__(self, runner: CommandRunner, token: str): self.runner = runner self.headers = { "Authorization": f"token {token}", "Accept": "application/vnd.github.v3+json", "User-Agent": "llvm-push-pr", } self.opener = urllib.request.build_opener( urllib.request.HTTPHandler(), urllib.request.HTTPSHandler() ) def _request( self, method: str, endpoint: str, json_payload: Optional[dict] = None ) -> HTTPResponse: url = f"{GITHUB_API}{endpoint}" self.runner.verbose_print(f"API Request: {method.upper()} {url}") if json_payload: self.runner.verbose_print(f"Payload: {json_payload}") data = None headers = self.headers.copy() if json_payload: data = json.dumps(json_payload).encode("utf-8") headers["Content-Type"] = "application/json" req = urllib.request.Request( url, data=data, headers=headers, method=method.upper() ) try: return self.opener.open(req, timeout=REQUEST_TIMEOUT) except urllib.error.HTTPError as e: self.runner.print( f"Error making API request to {url}: {e}", file=sys.stderr ) self.runner.verbose_print( f"Error response body: {e.read().decode()}", file=sys.stderr ) raise e def _request_and_parse_json( self, method: str, endpoint: str, json_payload: Optional[dict] = None ) -> dict: with self._request(method, endpoint, json_payload) as response: # Expect a 200 'OK' or 201 'Created' status on success and JSON body. self._log_unexpected_status([200, 201], response.status) response_text = response.read().decode("utf-8") if response_text: return json.loads(response_text) return {} def _request_no_content( self, method: str, endpoint: str, json_payload: Optional[dict] = None ) -> None: with self._request(method, endpoint, json_payload) as response: # Expected a 204 No Content status on success, indicating the # operation was successful but there is no body. self._log_unexpected_status([204], response.status) def _log_unexpected_status( self, expected_statuses: List[int], actual_status: int ) -> None: if actual_status not in expected_statuses: self.runner.print( f"Warning: Expected status {', '.join(map(str, expected_statuses))}, but got {actual_status}", file=sys.stderr, ) def get_user_login(self) -> str: return self._request_and_parse_json("GET", "/user")["login"] def create_pr( self, head_branch: str, base_branch: str, title: str, body: str, draft: bool, ) -> int: if self.runner.dry_run: self.runner.print(f"[Dry Run] Would create pull request for '{head_branch}'...") return 0 self.runner.print(f"Creating pull request for '{head_branch}'...") data = { "title": title, "body": body, "head": head_branch, "base": base_branch, "draft": draft, } response_data = self._request_and_parse_json( "POST", f"/repos/{LLVM_REPO}/pulls", json_payload=data ) self.runner.print(f"Pull request created: {response_data.get("html_url")}") return response_data.get("number") def get_repo_settings(self) -> dict: return self._request_and_parse_json("GET", f"/repos/{LLVM_REPO}") def _get_pr_details(self, pr_number: str) -> dict: """Fetches the JSON details for a given pull request number.""" return self._request_and_parse_json( "GET", f"/repos/{LLVM_REPO}/pulls/{pr_number}" ) def add_labels( self, pr_number: int, labels: List[str], ) -> None: if self.runner.dry_run: self.runner.print(f"[Dry Run] Would set labels for #{pr_number}: {' '.join(labels)}") return None self.runner.print(f"Setting labels for #{pr_number}: {' '.join(labels)}") self._request_and_parse_json( "POST", f"/repos/{LLVM_REPO}/issues/{pr_number}/labels", json_payload={"labels": labels}, ) def _attempt_squash_merge(self, pr_number: str) -> bool: """Attempts to squash merge a PR, returning True on success.""" try: self._request_and_parse_json( "PUT", f"/repos/{LLVM_REPO}/pulls/{pr_number}/merge", json_payload={"merge_method": "squash"}, ) return True except urllib.error.HTTPError as e: # A 405 status code means the PR is not in a mergeable state. if e.code == 405: return False # Re-raise other HTTP errors. raise e def merge_pr(self, pr_number: int) -> Optional[str]: if self.runner.dry_run: self.runner.print(f"[Dry Run] Would merge #{pr_number}") return None for i in range(MERGE_MAX_RETRIES): self.runner.print( f"Attempting to merge #{pr_number} (attempt {i + 1}/{MERGE_MAX_RETRIES})..." ) pr_data = self._get_pr_details(pr_number) head_branch = pr_data["head"]["ref"] if pr_data.get("mergeable_state") == "dirty": raise LlvmPrError("Merge conflict.") if pr_data.get("mergeable"): if self._attempt_squash_merge(pr_number): self.runner.print("Successfully merged.") time.sleep(2) # Allow GitHub's API to update. return head_branch self.runner.print( f"PR not mergeable yet (state: {pr_data.get('mergeable_state', 'unknown')}). Retrying in {MERGE_RETRY_DELAY} seconds..." ) time.sleep(MERGE_RETRY_DELAY) raise LlvmPrError(f"PR was not mergeable after {MERGE_MAX_RETRIES} attempts.") def enable_auto_merge(self, pr_number: int) -> None: if self.runner.dry_run: self.runner.print(f"[Dry Run] Would enable auto-merge for #{pr_number}") return self.runner.print(f"Enabling auto-merge for #{pr_number}...") data = { "enabled": True, "merge_method": "squash", } self._request_no_content( "PUT", f"/repos/{LLVM_REPO}/pulls/{pr_number}/auto-merge", json_payload=data, ) self.runner.print("Auto-merge enabled.") def delete_branch( self, branch_name: str, default_branch: Optional[str] = None ) -> None: if default_branch and branch_name == default_branch: self.runner.print( f"Error: Refusing to delete the default branch '{branch_name}'.", file=sys.stderr, ) return try: self._request_no_content( "DELETE", f"/repos/{LLVM_REPO}/git/refs/heads/{branch_name}" ) except urllib.error.HTTPError as e: if e.code == 422: self.runner.print( f"Warning: Remote branch '{branch_name}' was already deleted, skipping deletion.", file=sys.stderr, ) else: raise e class LLVMPRAutomator: """Automates the process of creating and landing a stack of GitHub Pull Requests.""" def __init__( self, runner: CommandRunner, github_api: "GitHubAPI", config: "PRAutomatorConfig", remote: str, ): self.runner = runner self.github_api = github_api self.config = config self.remote = remote self.original_branch: str = "" self.created_branches: List[str] = [] self.repo_settings: dict = {} def _get_git_env(self) -> dict: git_env = os.environ.copy() git_env[LLVM_GITHUB_TOKEN_VAR] = self.config.token git_env["GIT_TERMINAL_PROMPT"] = "0" git_env["GIT_CONFIG_COUNT"] = "1" git_env["GIT_CONFIG_KEY_0"] = "credential.helper" git_env[ "GIT_CONFIG_VALUE_0" ] = f"!{sys.executable} -c \"import os; print('username=x'); print('password=' + os.environ['{LLVM_GITHUB_TOKEN_VAR}']);\"" return git_env def _get_current_branch(self) -> str: result = self.runner.run_command( ["git", "rev-parse", "--abbrev-ref", "HEAD"], capture_output=True, text=True, read_only=True, ) return result.stdout.strip() def _check_work_tree(self) -> None: result = self.runner.run_command( ["git", "status", "--porcelain"], capture_output=True, text=True, read_only=True, ) for file_path in [line.strip() for line in result.stdout.split("\n")[:-1]]: if not file_path.startswith("?"): raise LlvmPrError( "Your working tree is dirty. Please stash or commit your changes." ) def _rebase_current_branch(self) -> None: self._check_work_tree() target = f"{self.config.upstream_remote}/{self.config.base_branch}" self.runner.print( f"Fetching from '{self.config.upstream_remote}' and rebasing '{self.original_branch}' on top of '{target}'..." ) git_env = self._get_git_env() refspec = f"refs/heads/{self.config.base_branch}:refs/remotes/{self.config.upstream_remote}/{self.config.base_branch}" self.runner.run_command( ["git", "fetch", self.config.upstream_remote, refspec], env=git_env ) try: self.runner.run_command(["git", "rebase", target], env=git_env) except subprocess.CalledProcessError as e: self.runner.print( "Error: The rebase operation failed, likely due to a merge conflict.", file=sys.stderr, ) if e.stdout: self.runner.print(f"--- stdout ---\n{e.stdout}", file=sys.stderr) if e.stderr: self.runner.print(f"--- stderr ---\n{e.stderr}", file=sys.stderr) # Check if rebase is in progress before aborting rebase_status_result = self.runner.run_command( ["git", "status", "--verify-status=REBASE_HEAD"], check=False, capture_output=True, text=True, read_only=True, env=git_env, ) # REBASE_HEAD exists, so rebase is in progress if rebase_status_result.returncode == 0: self.runner.print("Aborting rebase...", file=sys.stderr) self.runner.run_command( ["git", "rebase", "--abort"], check=False, env=git_env ) raise LlvmPrError("rebase operation failed.") from e def _get_commit_stack(self) -> List[str]: target = f"{self.config.upstream_remote}/{self.config.base_branch}" result = self.runner.run_command( ["git", "rev-list", "--reverse", f"{target}..HEAD"], capture_output=True, text=True, read_only=True, ) return result.stdout.strip().splitlines() def _get_commit_details(self, commit_hash: str) -> Tuple[str, str]: # Get the subject and body from git show. Insert "\n\n" between to make # parsing simple to do w/ split. result = self.runner.run_command( ["git", "show", "-s", "--format=%B", commit_hash], capture_output=True, text=True, read_only=True, ) parts = [item.strip() for item in result.stdout.split("\n", 1)] title = parts[0] body = parts[1] if len(parts) > 1 else "" return title, body def _sanitize_branch_name(self, text: str) -> str: sanitized = re.sub(r"[^\w\s-]", "", text).strip().lower() sanitized = re.sub(r"[-\s]+", "-", sanitized) # Use "auto-pr" as a fallback. return sanitized or "auto-pr" def _validate_merge_config(self, num_commits: int) -> None: if num_commits > 1: if self.config.auto_merge: raise LlvmPrError("--auto-merge is only supported for a single commit.") if self.config.no_merge: raise LlvmPrError( "--no-merge is only supported for a single commit. " "For stacks, the script must merge sequentially." ) self.runner.print(f"Found {num_commits} commit(s) to process.") def _create_and_push_branch_for_commit( self, commit_hash: str, base_branch_name: str, index: int ) -> str: branch_name = f"{self.config.prefix}{base_branch_name}-{index + 1}" commit_title, _ = self._get_commit_details(commit_hash) self.runner.print(f"Processing commit {commit_hash[:7]}: {commit_title}") self.runner.print(f"Pushing commit to temporary branch '{branch_name}'") git_env = self._get_git_env() push_command = [ "git", "push", self.remote, f"{commit_hash}:refs/heads/{branch_name}", ] self.runner.run_command(push_command, env=git_env) self.created_branches.append(branch_name) return branch_name def _process_commit( self, commit_hash: str, base_branch_name: str, index: int ) -> None: commit_title, commit_body = self._get_commit_details(commit_hash) temp_branch = self._create_and_push_branch_for_commit( commit_hash, base_branch_name, index ) pr_number = self.github_api.create_pr( head_branch=f"{self.config.user_login}:{temp_branch}", base_branch=self.config.base_branch, title=commit_title, body=commit_body, draft=self.config.draft, ) # TODO: There's a possibility of a race with PR labelers workflow. # To avoid it, we could create the PR as a draft, set the labels # and change the status, but that requires the use of GraphQL API. if self.config.labels: self.github_api.add_labels(pr_number, self.config.labels) if self.config.no_merge: return if self.config.auto_merge: self.github_api.enable_auto_merge(pr_number) else: merged_branch = self.github_api.merge_pr(pr_number) if merged_branch and not self.repo_settings.get("delete_branch_on_merge"): # After a merge, the branch should be deleted. self.github_api.delete_branch(merged_branch) if temp_branch in self.created_branches: # If the branch was successfully merged, it should not be deleted # again during cleanup. self.created_branches.remove(temp_branch) def run(self) -> None: self.repo_settings = self.github_api.get_repo_settings() self.original_branch = self._get_current_branch() self.runner.print(f"On branch: {self.original_branch}") try: commits = self._get_commit_stack() if not commits: self.runner.print("No new commits to process.") return self._validate_merge_config(len(commits)) branch_base_name = self.original_branch if self.original_branch == "main": first_commit_title, _ = self._get_commit_details(commits[0]) branch_base_name = self._sanitize_branch_name(first_commit_title) for i in range(len(commits)): if not commits: self.runner.print("Success! All commits have been landed.") break self._process_commit(commits[0], branch_base_name, i) self._rebase_current_branch() # After a rebase, the commit hashes can change, so we need to # get the latest commit stack. commits = self._get_commit_stack() finally: self._cleanup() def _cleanup(self) -> None: self.runner.print(f"Returning to original branch: {self.original_branch}") self.runner.run_command( ["git", "checkout", self.original_branch], capture_output=True ) if self.created_branches: self.runner.print("Cleaning up temporary remote branches...") for branch in self.created_branches: self.github_api.delete_branch(branch) def check_prerequisites(runner: CommandRunner) -> None: runner.print("Checking prerequisites...") runner.run_command(["git", "--version"], capture_output=True, read_only=True) result = runner.run_command( ["git", "rev-parse", "--is-inside-work-tree"], check=False, capture_output=True, text=True, read_only=True, ) if result.returncode != 0 or result.stdout.strip() != "true": raise LlvmPrError("This script must be run inside a git repository.") runner.print("Prerequisites met.") def main() -> None: parser = argparse.ArgumentParser( description="Create and land a stack of Pull Requests." ) command_runner = CommandRunner() token = os.getenv(LLVM_GITHUB_TOKEN_VAR) if not token: raise LlvmPrError(f"{LLVM_GITHUB_TOKEN_VAR} environment variable not set.") parser.add_argument( "--base", default=BASE_BRANCH, help=f"Base branch to target (default: {BASE_BRANCH})", ) parser.add_argument( "--remote", default=GITHUB_REMOTE_NAME, help=f"Remote for your fork to push to (default: {GITHUB_REMOTE_NAME})", ) parser.add_argument( "--upstream-remote", default=UPSTREAM_REMOTE_NAME, help=f"Remote for the upstream repository (default: {UPSTREAM_REMOTE_NAME})", ) parser.add_argument( "--login", default=None, help="The user login to use. If not provided this will be queried from the TOKEN", ) parser.add_argument( "--prefix", default=None, help="Prefix for temporary branches (default: users/)", ) parser.add_argument( "--draft", action="store_true", help="Create pull requests as drafts." ) parser.add_argument( "--labels", nargs='*', default=[DEFAULT_LABEL], help=f"Set the PR labels (default: {DEFAULT_LABEL})." ) merging = parser.add_mutually_exclusive_group() merging.add_argument( "--no-merge", action="store_true", help="Create PRs but do not merge them." ) merging.add_argument( "--auto-merge", action="store_true", help="Enable auto-merge for each PR instead of attempting to merge immediately.", ) parser.add_argument( "--dry-run", action="store_true", help="Print commands without executing them." ) verbosity = parser.add_mutually_exclusive_group() verbosity.add_argument( "-v", "--verbose", action="store_true", help="Print all commands being run." ) verbosity.add_argument( "-q", "--quiet", action="store_true", help="Print only essential output and errors.", ) args = parser.parse_args() command_runner = CommandRunner( dry_run=args.dry_run, verbose=args.verbose, quiet=args.quiet ) check_prerequisites(command_runner) github_api = GitHubAPI(command_runner, token) if not args.login: try: args.login = github_api.get_user_login() except urllib.error.HTTPError as e: raise LlvmPrError(f"Could not fetch user login from GitHub: {e}") from e if not args.prefix: args.prefix = f"users/{args.login}/" if not args.prefix.endswith("/"): args.prefix += "/" try: config = PRAutomatorConfig( user_login=args.login, token=token, base_branch=args.base, upstream_remote=args.upstream_remote, prefix=args.prefix, draft=args.draft, labels=args.labels, no_merge=args.no_merge, auto_merge=args.auto_merge, ) automator = LLVMPRAutomator( runner=command_runner, github_api=github_api, config=config, remote=args.remote, ) automator.run() except LlvmPrError as e: sys.exit(f"Error: {e}") if __name__ == "__main__": main()