Files
codex/codex-rs/pr2md
Daniel Edrisian 1cf16f35f5 update
2025-09-02 14:05:05 -07:00

361 lines
12 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import json
import shutil
import subprocess
import sys
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import quote
def _run(cmd: List[str]) -> Tuple[int, str, str]:
    """Run *cmd* without a shell and capture its output.

    Args:
        cmd: Program and arguments, e.g. ``["gh", "pr", "view", "1"]``.

    Returns:
        ``(returncode, stdout, stderr)``. Output is decoded as UTF-8 with
        undecodable bytes replaced, so a diff containing binary or
        non-UTF-8 content cannot crash the script regardless of the locale.
        If the program cannot be started at all (e.g. ``git`` is not
        installed), returns ``(127, "", <reason>)`` instead of raising,
        mirroring the shell's "command not found" status so callers'
        ``code != 0`` checks handle it.
    """
    try:
        proc = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            encoding="utf-8",
            errors="replace",
            check=False,
        )
    except OSError as exc:
        return 127, "", str(exc)
    return proc.returncode, proc.stdout, proc.stderr
def require_gh():
    """Exit with status 1 unless the GitHub CLI (``gh``) is on PATH."""
    gh_path = shutil.which("gh")
    if gh_path is not None:
        return
    print("Error: GitHub CLI 'gh' not found. Please install and authenticate.", file=sys.stderr)
    sys.exit(1)
def iso_to_utc_str(iso: Optional[str]) -> str:
    """Render an ISO-8601 timestamp as ``YYYY-MM-DD HH:MM:SS UTC``.

    Accepts a trailing ``Z`` (which ``datetime.fromisoformat`` only parses
    natively from Python 3.11) as well as explicit ``+HH:MM`` offsets.
    A timestamp with no offset is taken to already be UTC rather than local
    time, so the "UTC" label stays truthful.

    Args:
        iso: Timestamp string, or ``None``/empty.

    Returns:
        The formatted UTC string; ``""`` for ``None``/empty input; the input
        unchanged if it cannot be parsed.
    """
    if not iso:
        return ""
    # Only a trailing "Z" designates UTC; rewrite it as an explicit offset.
    text = iso[:-1] + "+00:00" if iso.endswith("Z") else iso
    try:
        dt = datetime.fromisoformat(text)
        if dt.tzinfo is None:
            # Naive: astimezone() would assume *local* time here.
            dt = dt.replace(tzinfo=timezone.utc)
        dt_utc = dt.astimezone(timezone.utc)
    except (ValueError, OverflowError):
        return iso
    return dt_utc.strftime("%Y-%m-%d %H:%M:%S UTC")
def pr_view(repo: str, pr_number: int) -> Dict[str, Any]:
    """Fetch metadata for PR *pr_number* in *repo* via ``gh pr view --json``.

    Returns the decoded JSON object. Exits with status 1 if ``gh`` fails or
    its output is not valid JSON.
    """
    json_fields = ",".join((
        "number", "title", "body", "url", "author",
        "createdAt", "updatedAt",
        "additions", "deletions", "changedFiles", "commits",
        "baseRefName", "headRefName", "headRepositoryOwner",
    ))
    returncode, stdout, stderr = _run(
        ["gh", "pr", "view", str(pr_number), "-R", repo, "--json", json_fields]
    )
    if returncode != 0:
        print(f"Error: failed to fetch PR via gh: {stderr.strip()}", file=sys.stderr)
        sys.exit(1)
    try:
        return json.loads(stdout)
    except json.JSONDecodeError as exc:
        print(f"Error: failed to parse gh JSON output: {exc}", file=sys.stderr)
        sys.exit(1)
def pr_combined_diff(repo: str, pr: Dict[str, Any]) -> str:
    """Return the PR's full base...head diff as unified-diff text.

    Prefers GitHub's compare API (``/repos/{repo}/compare/{base}...{head}``
    requested with the ``application/vnd.github.v3.diff`` media type). Falls
    back to ``gh pr diff`` in three cases: branch names are missing, the head
    repository's owner is unknown, or the compare call fails or returns nothing.

    Args:
        repo: Base repository as ``owner/repo``.
        pr: PR metadata as returned by ``pr_view``.

    Returns:
        The diff text with trailing whitespace removed. Exits with status 1 if
        the ``gh pr diff`` fallback fails.
    """
    base = pr.get("baseRefName")
    head_branch = pr.get("headRefName")
    head_owner = (pr.get("headRepositoryOwner") or {}).get("login")
    # Without the head owner (e.g. the fork was deleted) a fork branch cannot
    # be told apart from a same-named branch in the base repo, so comparing
    # could silently diff the wrong branch.
    if not base or not head_branch or not head_owner:
        return _pr_diff_via_gh(repo, pr)

    base_owner = repo.split("/", 1)[0]
    # Cross-repository (fork) heads use the "owner:branch" compare syntax.
    head = head_branch if head_owner == base_owner else f"{head_owner}:{head_branch}"
    path = f"/repos/{repo}/compare/{quote(base, safe='')}...{quote(head, safe='')}"
    code, out, _ = _run(["gh", "api", "-H", "Accept: application/vnd.github.v3.diff", path])
    if code == 0 and out.strip():
        return out.rstrip()
    return _pr_diff_via_gh(repo, pr)


def _pr_diff_via_gh(repo: str, pr: Dict[str, Any]) -> str:
    """Fetch the PR diff with ``gh pr diff``; exit with status 1 on failure."""
    code, out, err = _run(["gh", "pr", "diff", str(pr.get("number")), "-R", repo, "--color=never"])
    if code != 0:
        print(f"Error: failed to fetch PR diff: {err.strip()}", file=sys.stderr)
        sys.exit(1)
    return out.rstrip()
def pr_review_comments(repo: str, pr_number: int, max_pages: int = 10) -> List[Dict[str, Any]]:
    """Fetch the PR's review (inline code) comments via the REST API.

    Pages through ``/repos/{repo}/pulls/{pr_number}/comments`` 100 at a time.
    It stops at the first short or empty page, or after *max_pages* pages.
    Hitting the cap is reported on stderr instead of silently dropping comments.

    Args:
        repo: Repository as ``owner/repo``.
        pr_number: Pull request number.
        max_pages: Safety cap on pages fetched (default 10, i.e. up to 1000
            comments).

    Returns:
        The comment objects in the order the API returned them. Exits with
        status 1 if ``gh`` fails or does not return a JSON array.
    """
    per_page = 100
    all_comments: List[Dict[str, Any]] = []
    for page in range(1, max_pages + 1):
        path = f"/repos/{repo}/pulls/{pr_number}/comments?per_page={per_page}&page={page}"
        code, out, err = _run(["gh", "api", path])
        if code != 0:
            print(f"Error: failed to fetch review comments: {err.strip()}", file=sys.stderr)
            sys.exit(1)
        try:
            batch = json.loads(out)
        except json.JSONDecodeError:
            batch = None
        # Anything but a list (e.g. an error object) must not be merged in:
        # extending with a dict would add its keys as "comments".
        if not isinstance(batch, list):
            print("Error: could not parse review comments JSON.", file=sys.stderr)
            sys.exit(1)
        all_comments.extend(batch)
        if len(batch) < per_page:
            return all_comments
    print(
        f"Warning: stopped after {max_pages} page(s) ({len(all_comments)} review comments); "
        "more may exist.",
        file=sys.stderr,
    )
    return all_comments
def parse_repo_from_url(url: str) -> Optional[str]:
    """Extract ``owner/repo`` from a github.com git remote URL.

    Handles these forms; a trailing ``.git`` and trailing slashes are ignored:

    - SSH scp-like: ``<user>@github.com:owner/repo.git``
    - SSH URL: ``ssh://<user>@github.com[:port]/owner/repo.git``
    - HTTPS: ``https://github.com[:port]/owner/repo(.git)``
    - Bare: ``github.com/owner/repo(.git)``

    Returns:
        ``"owner/repo"``, or ``None`` if *url* is empty, is not a github.com
        URL, or lacks an owner/repo pair.
    """
    u = url.strip()
    if not u:
        return None
    if "github.com:" in u:
        path = u.split("github.com:", 1)[1]
        # In scheme URLs ("ssh://", "https://") a colon after the host starts
        # a port number rather than the path as in scp-like syntax; skip it.
        if "://" in u:
            port, sep, rest = path.partition("/")
            if sep and port.isdigit():
                path = rest
    elif "github.com/" in u:
        path = u.split("github.com/", 1)[1]
    else:
        return None
    # Drop empty segments so trailing/duplicate slashes don't matter.
    segments = [seg for seg in path.split("/") if seg]
    if len(segments) < 2:
        return None
    owner, name = segments[0], segments[1]
    if name.endswith(".git"):
        name = name[:-4]
    if not name:
        return None
    return f"{owner}/{name}"
def detect_repo_from_git() -> Optional[str]:
    """Infer ``owner/repo`` from the current git checkout's ``origin`` remote.

    Returns ``None`` in any of these cases: not inside a git work tree,
    ``remote.origin.url`` is not configured, or ``parse_repo_from_url``
    cannot read the URL.
    """
    status, inside, _ = _run(["git", "rev-parse", "--is-inside-work-tree"])
    if status == 0 and inside.strip() == "true":
        status, origin_url, _ = _run(["git", "config", "--get", "remote.origin.url"])
        if status == 0:
            return parse_repo_from_url(origin_url)
    return None
def blockquote(text: str) -> str:
    """Prefix every line of *text* with ``"> "`` to form a Markdown blockquote.

    Empty input yields a single ``"> "`` line so the quote is never blank.
    """
    quoted = [f"> {line}" for line in text.splitlines()]
    if not quoted:
        quoted = ["> "]
    return "\n".join(quoted)
def format_header(pr: Dict[str, Any]) -> str:
    """Render the PR title line followed by a bullet list of summary metadata.

    The commit count is read from whichever shape ``commits`` has: a
    connection dict, a list of commits, or a plain number. ``?`` is shown
    when it cannot be determined.
    """
    author = pr.get("author") or {}
    commits_count = _commit_count(pr.get("commits"))
    summary = (
        f"+{pr.get('additions', 0)}/-{pr.get('deletions', 0)}, "
        f"Files changed: {pr.get('changedFiles', 0)}, "
        f"Commits: {'?' if commits_count is None else commits_count}"
    )
    return "\n".join([
        f"# PR #{pr.get('number')}: {pr.get('title', '')}",
        "",
        f"- URL: {pr.get('url', '')}",
        f"- Author: {author.get('login', '')}",
        f"- Created: {iso_to_utc_str(pr.get('createdAt'))}",
        f"- Updated: {iso_to_utc_str(pr.get('updatedAt'))}",
        f"- Changes: {summary}",
    ])


def _commit_count(commits_obj: Any) -> Optional[int]:
    """Best-effort commit count from the ``commits`` field, or ``None``."""
    if isinstance(commits_obj, dict):
        total = commits_obj.get("totalCount")
        if isinstance(total, (int, float)):
            return int(total)  # GraphQL connection
        nodes = commits_obj.get("nodes")
        if isinstance(nodes, list):
            return len(nodes)  # fallback
        return None
    if isinstance(commits_obj, list):
        return len(commits_obj)
    if isinstance(commits_obj, (int, float)):
        return int(commits_obj)
    return None
def format_description(body: Optional[str]) -> str:
    """Render the ``## Description`` section, using a placeholder when empty."""
    text = (body or "").strip() or "(No description.)"
    return "\n## Description\n\n" + text + "\n"
def format_diff(diff_text: str) -> str:
    """Render the ``## Full Diff`` section as a fenced ``diff`` code block.

    The fence is one backtick longer than the longest backtick run in
    *diff_text*, with a minimum of three. Diffs of Markdown files routinely
    contain fence lines such as the context line " ```". CommonMark would
    otherwise treat that line as closing a plain triple-backtick block early,
    corrupting the rest of the document.
    """
    import re

    longest_run = max((len(run) for run in re.findall(r"`+", diff_text)), default=0)
    fence = "`" * max(3, longest_run + 1)
    return f"\n## Full Diff\n\n{fence}diff\n{diff_text}\n{fence}\n"
def format_review_comments(comments: List[Dict[str, Any]], reviewer: Optional[str]) -> str:
    """Render the ``## Review Comments`` section, grouped by file path.

    Args:
        comments: Review comment objects. The fields read are ``user.login``,
            ``path``, ``created_at``, ``html_url``, ``diff_hunk`` and
            ``body``; missing or null values are tolerated.
        reviewer: If given, keep only comments whose author login matches it
            case-insensitively.

    Returns:
        The section text ending in a single newline, or a placeholder section
        when no comments remain after filtering.
    """
    import re

    if reviewer:
        reviewer_lc = reviewer.lower()
        comments = [
            c for c in comments
            if ((c.get("user") or {}).get("login") or "").lower() == reviewer_lc
        ]
    if not comments:
        return "\n## Review Comments\n\n(No review comments.)\n"
    # Group by file path (paths sorted for stable output); within a file,
    # comments keep the order they were given in.
    by_path: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for c in comments:
        # `or` also maps an explicit null path to the placeholder, which keeps
        # sorted() from comparing None with str.
        by_path[c.get("path") or "(unknown)"].append(c)
    out_lines: List[str] = ["\n## Review Comments\n"]
    for path in sorted(by_path):
        out_lines.append(f"### {path}\n")
        for c in by_path[path]:
            created = iso_to_utc_str(c.get("created_at"))
            url = c.get("html_url") or ""
            diff_hunk = (c.get("diff_hunk") or "").rstrip()
            body = c.get("body") or ""
            out_lines.append(f"- Created: {created} | Link: {url}")
            out_lines.append("")
            if diff_hunk:
                # Fence longer than any backtick run in the hunk so Markdown
                # fences inside the diffed file cannot close the block early.
                longest = max((len(run) for run in re.findall(r"`+", diff_hunk)), default=0)
                fence = "`" * max(3, longest + 1)
                out_lines.append(f"{fence}diff")
                out_lines.append(diff_hunk)
                out_lines.append(fence)
                out_lines.append("")
            if body:
                out_lines.append(blockquote(body))
                out_lines.append("")
    return "\n".join(out_lines).rstrip() + "\n"
def main():
    """Command-line entry point: render a PR as Markdown on stdout.

    Both positional arguments are optional:
    - The repository defaults to the git ``origin`` remote.
    - The PR defaults to the open PR whose head is the current branch.
    A lone ``owner/repo`` argument is treated as the repository.

    Exits with status 2 for usage or detection errors and status 1 when a
    ``gh`` call fails.
    """
    parser = argparse.ArgumentParser(
        prog="pr2md",
        description=(
            "Render a GitHub PR into Markdown including description, full diff, and review comments.\n"
            "Requires GitHub CLI (gh) to be installed and authenticated."
        ),
    )
    parser.add_argument(
        "pr_number",
        nargs="?",
        help="Pull request number (optional; auto-detect from current branch if omitted)",
    )
    parser.add_argument(
        "repo",
        nargs="?",
        help="Repository in 'owner/repo' form; inferred from git origin if omitted",
    )
    parser.add_argument("--reviewer", help="Only include comments from this reviewer (login)")
    args = parser.parse_args()
    require_gh()
    # Disambiguate single positional arg: if only one is provided and it looks like owner/repo,
    # treat it as repo, not PR number.
    if args.pr_number and not args.repo and "/" in args.pr_number and not args.pr_number.isdigit():
        args.repo, args.pr_number = args.pr_number, None
    repo = args.repo or detect_repo_from_git()
    if not repo:
        print(
            "Error: Could not determine repository from git origin. Pass repo as 'owner/repo'.",
            file=sys.stderr,
        )
        sys.exit(2)
    pr_number: int
    if args.pr_number:
        try:
            pr_number = int(args.pr_number)
        except ValueError:
            print("Error: PR number must be an integer.", file=sys.stderr)
            sys.exit(2)
    else:
        pr_number = _find_open_pr_for_current_branch(repo)
    pr = pr_view(repo, pr_number)
    diff_text = pr_combined_diff(repo, pr)
    comments = pr_review_comments(repo, pr_number)
    parts = [
        format_header(pr),
        format_description(pr.get("body")),
        format_diff(diff_text),
        format_review_comments(comments, args.reviewer),
    ]
    # End with a newline so redirected output is a well-formed text file.
    sys.stdout.write("\n".join(p.rstrip() for p in parts if p) + "\n")


def _find_open_pr_for_current_branch(repo: str) -> int:
    """Return the number of an open PR in *repo* whose head is the current branch.

    Prefers a PR whose head repository owner matches *repo*'s owner. Otherwise
    it falls back to any PR with a matching head branch name (e.g. one opened
    from a fork). Exits with status 2 when not on a branch or no PR matches,
    and status 1 when ``gh pr list`` fails.
    """
    code, branch_out, _ = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    branch = branch_out.strip() if code == 0 else ""
    if not branch or branch == "HEAD":
        print("Error: Not on a branch. Provide a PR number explicitly.", file=sys.stderr)
        sys.exit(2)
    owner = repo.split("/", 1)[0]
    # Filter by head branch server-side: without --head, `gh pr list` returns
    # only the 30 most recent open PRs, so in a busy repo the PR for this
    # branch could be missed entirely.
    code, out, err = _run([
        "gh", "pr", "list", "-R", repo, "--state", "open",
        "--head", branch, "--limit", "100",
        "--json", "number,headRefName,isDraft,headRepositoryOwner",
    ])
    if code != 0:
        print(f"Error: failed to list PRs: {err.strip()}", file=sys.stderr)
        sys.exit(1)
    try:
        pr_list = json.loads(out)
    except json.JSONDecodeError:
        print("Error: failed to parse PR list JSON.", file=sys.stderr)
        sys.exit(1)
    matching = [pr for pr in pr_list if pr.get("headRefName") == branch]
    same_owner = [
        pr for pr in matching
        if (pr.get("headRepositoryOwner") or {}).get("login") == owner
    ]
    # Relax the owner constraint if no same-owner PR exists.
    candidates = same_owner or matching
    if not candidates:
        print(
            f"Error: No open PR found for branch '{branch}'. Provide a PR number.",
            file=sys.stderr,
        )
        sys.exit(2)
    # If several match, take the first one gh listed.
    return int(candidates[0]["number"])


if __name__ == "__main__":
    main()