#!/usr/bin/env python3
import argparse
import json
import shutil
import subprocess
import sys
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import quote


def _run(cmd: List[str]) -> Tuple[int, str, str]:
    proc = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        check=False,
    )
    return proc.returncode, proc.stdout, proc.stderr


def require_gh():
    """Abort with exit status 1 unless the GitHub CLI ``gh`` is found on PATH."""
    if shutil.which("gh") is not None:
        return
    print("Error: GitHub CLI 'gh' not found. Please install and authenticate.", file=sys.stderr)
    sys.exit(1)


def iso_to_utc_str(iso: Optional[str]) -> str:
    """Format an ISO-8601 timestamp as ``YYYY-MM-DD HH:MM:SS UTC``.

    - Empty or ``None`` input yields ``""``.
    - A trailing ``Z``/``z`` designator is rewritten to ``+00:00``, because
      ``datetime.fromisoformat`` on Python < 3.11 does not accept it.
    - A timestamp without an offset is taken to be UTC already.
      ``astimezone`` would otherwise read it as local time, and the output
      would carry a wrong "UTC" label.
    - Unparseable input is returned unchanged (as ``str``).
    """
    if not iso:
        return ""
    if not isinstance(iso, str):
        return str(iso)
    text = iso.strip()
    # Only the final character can be the zone designator; don't touch others.
    if text[-1:] in ("Z", "z"):
        text = text[:-1] + "+00:00"
    try:
        dt = datetime.fromisoformat(text)
    except (TypeError, ValueError):
        return iso
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    try:
        return dt.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
    except (OverflowError, ValueError):
        # E.g. datetimes at the edge of the supported range after conversion.
        return iso


def pr_view(repo: str, pr_number: int) -> Dict[str, Any]:
    """Return PR metadata decoded from ``gh pr view --json``.

    Only the fields listed below are requested. Exits the process with
    status 1 if ``gh`` fails or its output is not valid JSON.
    """
    wanted = (
        "number", "title", "body", "url", "author",
        "createdAt", "updatedAt",
        "additions", "deletions", "changedFiles", "commits",
        "baseRefName", "headRefName", "headRepositoryOwner",
    )
    cmd = ["gh", "pr", "view", str(pr_number), "-R", repo, "--json", ",".join(wanted)]
    rc, stdout, stderr = _run(cmd)
    if rc != 0:
        print(f"Error: failed to fetch PR via gh: {stderr.strip()}", file=sys.stderr)
        sys.exit(1)
    try:
        return json.loads(stdout)
    except json.JSONDecodeError as exc:
        print(f"Error: failed to parse gh JSON output: {exc}", file=sys.stderr)
        sys.exit(1)


def _gh_pr_diff(repo: str, pr_number: Any) -> str:
    """Return ``gh pr diff`` output (trailing whitespace stripped); exit 1 on failure."""
    code, out, err = _run(["gh", "pr", "diff", str(pr_number), "-R", repo, "--color=never"])
    if code != 0:
        print(f"Error: failed to fetch PR diff: {err.strip()}", file=sys.stderr)
        sys.exit(1)
    return out.rstrip()


def pr_combined_diff(repo: str, pr: Dict[str, Any]) -> str:
    """Return one combined diff between the PR's base and head branches.

    Uses the REST compare endpoint (``base...head``) with the diff media type.
    Falls back to ``gh pr diff`` in two cases:
    - the PR dict lacks branch names;
    - the compare request fails or returns an empty body.
    Exits the process with status 1 if the fallback also fails.

    Args:
        repo: Base repository as ``owner/repo``.
        pr: PR metadata as returned by ``pr_view``.
    """
    base = pr.get("baseRefName")
    head_branch = pr.get("headRefName")
    if not base or not head_branch:
        return _gh_pr_diff(repo, pr.get("number"))

    head_owner = (pr.get("headRepositoryOwner") or {}).get("login")
    base_owner = repo.split("/", 1)[0]
    # For a PR from another owner's fork, the compare API names the head as
    # "owner:branch". GitHub logins are case-insensitive, so compare them that way.
    if head_owner and head_owner.lower() != base_owner.lower():
        head = f"{head_owner}:{head_branch}"
    else:
        head = head_branch

    path = f"/repos/{repo}/compare/{quote(base, safe='')}...{quote(head, safe='')}"
    code, out, _ = _run(["gh", "api", "-H", "Accept: application/vnd.github.v3.diff", path])
    if code == 0 and out.strip():
        return out.rstrip()
    return _gh_pr_diff(repo, pr.get("number"))


def pr_review_comments(
    repo: str,
    pr_number: int,
    per_page: int = 100,
    max_pages: int = 10,
) -> List[Dict[str, Any]]:
    """Fetch the PR's review comments (inline code comments) via the REST API.

    Requests ``/repos/{repo}/pulls/{pr_number}/comments`` page by page. It
    stops at the first page shorter than ``per_page`` or after ``max_pages``
    pages. If the page cap is reached, a warning goes to stderr so that
    truncation is not silent.

    Args:
        repo: Repository as ``owner/repo``.
        pr_number: Pull request number.
        per_page: Page size, clamped to 1..100. GitHub serves at most 100 per
            page, and a larger value would make every page look "short".
        max_pages: Safety cap on the number of pages requested (at least 1).

    Returns:
        The decoded comment objects, in the order the API returned them.

    Exits the process with status 1 in three cases: ``gh`` fails, the output
    is not JSON, or the output is not a JSON array.
    """
    per_page = max(1, min(per_page, 100))
    max_pages = max(1, max_pages)
    all_comments: List[Dict[str, Any]] = []
    for page in range(1, max_pages + 1):
        path = f"/repos/{repo}/pulls/{pr_number}/comments?per_page={per_page}&page={page}"
        code, out, err = _run(["gh", "api", path])
        if code != 0:
            print(f"Error: failed to fetch review comments: {err.strip()}", file=sys.stderr)
            sys.exit(1)
        try:
            batch = json.loads(out)
        except json.JSONDecodeError:
            print("Error: could not parse review comments JSON.", file=sys.stderr)
            sys.exit(1)
        if not isinstance(batch, list):
            # Extending with a dict would silently append its keys.
            print("Error: unexpected review comments payload (expected a JSON array).", file=sys.stderr)
            sys.exit(1)
        all_comments.extend(batch)
        if len(batch) < per_page:
            # A short (or empty) page is the last one.
            return all_comments
    print(
        f"Warning: stopped after {max_pages} pages ({len(all_comments)} review comments); "
        "more may exist.",
        file=sys.stderr,
    )
    return all_comments


def parse_repo_from_url(url: str) -> Optional[str]:
    """Extract ``owner/repo`` from a GitHub remote URL, or return ``None``.

    Accepted forms (a trailing ``.git`` and trailing slashes are ignored):
      - SSH scp-like: ``[user@]github.com:owner/repo.git``
      - URL with scheme: ``https://[user[:token]@]github.com[:port]/owner/repo``,
        ``ssh://git@github.com[:port]/owner/repo.git``,
        ``ssh://git@ssh.github.com:443/owner/repo.git``
      - Bare: ``github.com/owner/repo``

    Only the hosts in ``github_hosts`` are recognised. Other hosts, local
    paths and SSH aliases return ``None``.
    """
    from urllib.parse import urlsplit

    github_hosts = {"github.com", "www.github.com", "ssh.github.com"}

    u = url.strip()
    if not u:
        return None

    if "://" in u:
        # Let urlsplit strip userinfo and port, so "github.com:443" is not
        # mistaken for scp-style "host:path" syntax.
        split = urlsplit(u)
        host = (split.hostname or "").lower()
        path = split.path
    elif ":" in u.split("/", 1)[0]:
        # scp-like "[user@]host:owner/repo": the colon comes before any slash.
        host_part, path = u.split(":", 1)
        host = host_part.rsplit("@", 1)[-1].lower()
    else:
        # Bare "host/owner/repo".
        host_part, _, path = u.partition("/")
        host = host_part.rsplit("@", 1)[-1].lower()

    if host not in github_hosts:
        return None

    segments = [seg for seg in path.split("/") if seg]
    if len(segments) < 2:
        return None
    owner, name = segments[0], segments[1]
    # Strip ".git" from the repo segment itself, so "repo.git/" is handled too.
    if name.endswith(".git"):
        name = name[:-4]
    if not owner or not name:
        return None
    return f"{owner}/{name}"


def detect_repo_from_git() -> Optional[str]:
    """Infer ``owner/repo`` from the ``origin`` remote of the current checkout.

    Returns ``None`` in three cases: the working directory is not inside a git
    work tree, ``remote.origin.url`` cannot be read, or ``parse_repo_from_url``
    does not recognise the URL.
    """
    rc, inside, _ = _run(["git", "rev-parse", "--is-inside-work-tree"])
    if rc == 0 and inside.strip() == "true":
        rc, remote_url, _ = _run(["git", "config", "--get", "remote.origin.url"])
        if rc == 0:
            return parse_repo_from_url(remote_url)
    return None


def blockquote(text: str) -> str:
    """Prefix each line of *text* with ``"> "`` to form a Markdown blockquote.

    Empty input still produces a single ``"> "`` line.
    """
    quoted = [f"> {line}" for line in text.splitlines()]
    if not quoted:
        quoted = ["> "]
    return "\n".join(quoted)


def format_header(pr: Dict[str, Any]) -> str:
    """Render the PR title heading followed by a bullet list of metadata.

    The commit count is taken from ``pr["commits"]``, which may be:
    - a dict with a numeric ``totalCount``, or else a ``nodes`` list;
    - a list (its length is used);
    - a bare number.
    Any other shape renders as ``?``.
    """
    commits_obj = pr.get("commits")
    count: Optional[int] = None
    if isinstance(commits_obj, dict):
        total = commits_obj.get("totalCount")
        nodes = commits_obj.get("nodes")
        if isinstance(total, (int, float)):
            count = int(total)  # GraphQL connection
        elif isinstance(nodes, list):
            count = len(nodes)  # fallback
    elif isinstance(commits_obj, list):
        count = len(commits_obj)
    elif isinstance(commits_obj, (int, float)):
        count = int(commits_obj)

    author = (pr.get("author") or {}).get("login", "")
    changes = (
        f"+{pr.get('additions', 0)}/-{pr.get('deletions', 0)}, "
        f"Files changed: {pr.get('changedFiles', 0)}, "
        f"Commits: {'?' if count is None else count}"
    )
    return "\n".join([
        f"# PR #{pr.get('number')}: {pr.get('title', '')}",
        "",
        f"- URL: {pr.get('url', '')}",
        f"- Author: {author}",
        f"- Created: {iso_to_utc_str(pr.get('createdAt'))}",
        f"- Updated: {iso_to_utc_str(pr.get('updatedAt'))}",
        f"- Changes: {changes}",
    ])


def format_description(body: Optional[str]) -> str:
    """Render the PR body under a ``## Description`` heading.

    A missing or whitespace-only body is replaced by ``(No description.)``.
    """
    text = (body or "").strip() or "(No description.)"
    return "\n## Description\n\n" + text + "\n"


def format_diff(diff_text: str) -> str:
    """Render *diff_text* in a fenced ``diff`` code block under ``## Full Diff``.

    The fence is made longer than any backtick run in the diff. A diff of a
    Markdown file can contain ```` ``` ```` on a context line (prefixed by a
    single space). CommonMark allows a closing fence to be indented, so a
    fixed three-backtick fence would end early and corrupt the rest of the
    document.
    """
    fence = "```"
    while fence in diff_text:
        fence += "`"
    return f"\n## Full Diff\n\n{fence}diff\n{diff_text}\n{fence}\n"


def format_review_comments(comments: List[Dict[str, Any]], reviewer: Optional[str]) -> str:
    """Render review comments as Markdown, grouped under one heading per file.

    Args:
        comments: Review comment objects as fetched from the pulls comments
            endpoint. The fields read are ``user.login``, ``path``,
            ``created_at``, ``html_url``, ``diff_hunk`` and ``body``. Any of
            them may be missing or ``None``.
        reviewer: If given, keep only comments whose author login matches,
            case-insensitively.

    Returns:
        A ``## Review Comments`` section ending in a newline. Paths are sorted;
        within a path the input order is preserved. Each comment shows its
        creation time and link, its diff hunk in a fenced block, and its body
        as a blockquote.
    """
    if reviewer:
        wanted = reviewer.lower()
        # `or ""` guards against an explicit null login (or null user).
        comments = [
            c for c in comments
            if ((c.get("user") or {}).get("login") or "").lower() == wanted
        ]

    if not comments:
        return "\n## Review Comments\n\n(No review comments.)\n"

    # Group by file path. A null path would break sorted(), so map it to a label.
    by_path: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for c in comments:
        by_path[c.get("path") or "(unknown)"].append(c)

    out_lines: List[str] = ["\n## Review Comments\n"]
    for path in sorted(by_path):
        out_lines.append(f"### {path}\n")
        for c in by_path[path]:
            created = iso_to_utc_str(c.get("created_at"))
            url = c.get("html_url") or ""
            diff_hunk = (c.get("diff_hunk") or "").rstrip()
            body = c.get("body") or ""
            out_lines.append(f"- Created: {created} | Link: {url}")
            out_lines.append("")
            if diff_hunk:
                # Make the fence longer than any backtick run inside the hunk,
                # so that content cannot terminate the code block early.
                fence = "```"
                while fence in diff_hunk:
                    fence += "`"
                out_lines.extend([f"{fence}diff", diff_hunk, fence, ""])
            if body:
                out_lines.extend([blockquote(body), ""])
    return "\n".join(out_lines).rstrip() + "\n"


def _pr_number_for_current_branch(repo: str) -> int:
    """Return the number of the open PR whose head branch is the current git branch.

    Matching rules:
    - PRs whose head repository owner equals *repo*'s owner (case-insensitive)
      are preferred; otherwise any open PR from a branch with the same name
      is accepted.
    - If several match, the first is used and a warning is printed.

    Exits with status 2 when HEAD is detached or no PR matches, and with
    status 1 when ``gh`` fails or returns unparseable output.
    """
    code, branch_out, _ = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    branch = branch_out.strip() if code == 0 else ""
    if not branch or branch == "HEAD":
        print("Error: Not on a branch. Provide a PR number explicitly.", file=sys.stderr)
        sys.exit(2)

    # `gh pr list` returns only 30 PRs by default, so a client-side scan of an
    # unfiltered list misses branches in busy repos. Filter by head branch
    # server-side instead.
    code, out, err = _run([
        "gh", "pr", "list", "-R", repo, "--state", "open",
        "--head", branch, "--limit", "100",
        "--json", "number,headRefName,headRepositoryOwner",
    ])
    if code != 0:
        print(f"Error: failed to list PRs: {err.strip()}", file=sys.stderr)
        sys.exit(1)
    try:
        pr_list = json.loads(out)
    except json.JSONDecodeError:
        pr_list = None
    if not isinstance(pr_list, list):
        print("Error: failed to parse PR list JSON.", file=sys.stderr)
        sys.exit(1)

    owner = repo.split("/", 1)[0].lower()
    same_branch = [pr for pr in pr_list if pr.get("headRefName") == branch]
    same_owner = [
        pr for pr in same_branch
        if ((pr.get("headRepositoryOwner") or {}).get("login") or "").lower() == owner
    ]
    candidates = same_owner or same_branch
    if not candidates:
        print(
            f"Error: No open PR found for branch '{branch}'. Provide a PR number.",
            file=sys.stderr,
        )
        sys.exit(2)
    if len(candidates) > 1:
        others = ", ".join(f"#{pr.get('number')}" for pr in candidates[1:])
        print(
            f"Warning: multiple open PRs for branch '{branch}'; "
            f"using #{candidates[0].get('number')} (others: {others}).",
            file=sys.stderr,
        )
    return int(candidates[0]["number"])


def main():
    """CLI entry point: resolve the repo and PR, then print the PR as Markdown.

    Positional args are ``[pr_number] [repo]``. A lone ``owner/repo`` is
    treated as the repo, and the reversed order ``owner/repo N`` is accepted.
    Without a PR number, the open PR for the current branch is used. The
    rendered Markdown is written to stdout, newline-terminated. On errors the
    process exits non-zero, with status 2 for usage problems and 1 for
    ``gh`` failures.
    """
    parser = argparse.ArgumentParser(
        prog="pr2md",
        description=(
            "Render a GitHub PR into Markdown including description, full diff, and review comments.\n"
            "Requires GitHub CLI (gh) to be installed and authenticated."
        ),
    )
    parser.add_argument(
        "pr_number",
        nargs="?",
        help="Pull request number (optional; auto-detect from current branch if omitted)",
    )
    parser.add_argument(
        "repo",
        nargs="?",
        help="Repository in 'owner/repo' form; inferred from git origin if omitted",
    )
    parser.add_argument("--reviewer", help="Only include comments from this reviewer (login)")

    args = parser.parse_args()

    require_gh()

    # Disambiguate positionals: a value containing "/" is a repo, not a PR number.
    if args.pr_number and "/" in args.pr_number:
        if not args.repo:
            args.repo, args.pr_number = args.pr_number, None
        elif "/" not in args.repo:
            # Reversed order: `pr2md owner/repo 123`.
            args.repo, args.pr_number = args.pr_number, args.repo

    repo = args.repo or detect_repo_from_git()
    if not repo:
        print(
            "Error: Could not determine repository from git origin. Pass repo as 'owner/repo'.",
            file=sys.stderr,
        )
        sys.exit(2)

    pr_number: int
    if args.pr_number:
        try:
            pr_number = int(args.pr_number)
        except ValueError:
            print("Error: PR number must be an integer.", file=sys.stderr)
            sys.exit(2)
    else:
        pr_number = _pr_number_for_current_branch(repo)

    pr = pr_view(repo, pr_number)
    diff_text = pr_combined_diff(repo, pr)
    comments = pr_review_comments(repo, pr_number)

    parts = [
        format_header(pr),
        format_description(pr.get("body")),
        format_diff(diff_text),
        format_review_comments(comments, args.reviewer),
    ]
    # Terminate with a newline so the output is a well-formed text file.
    sys.stdout.write("\n".join(p.rstrip() for p in parts if p) + "\n")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
