#!/usr/bin/env python3

import argparse
import base64
import json
import re
import subprocess
import sys


# GitHub repository ("owner/name") that every API call in this script targets.
REPO: str = "openai/codex"
# Branch ref, written without the leading "refs/", whose head the release builds on.
BRANCH_REF: str = "heads/main"
# Repository-relative path of the Cargo manifest whose version gets rewritten.
CARGO_TOML_PATH: str = "codex-rs/Cargo.toml"


def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse the command line, ignoring the program name in ``argv[0]``.

    Exactly one of ``--publish-alpha`` / ``--publish-release`` must be given;
    argparse reports an error and exits otherwise. ``-n/--dry-run`` is optional.
    """
    cli = argparse.ArgumentParser(description="Publish a tagged Codex release.")
    cli.add_argument(
        "-n",
        "--dry-run",
        action="store_true",
        help="Print the version that would be used and exit before making changes.",
    )

    release_kind = cli.add_mutually_exclusive_group(required=True)
    release_flags = (
        ("--publish-alpha", "Publish the next alpha release for the upcoming minor version."),
        ("--publish-release", "Publish the next stable release by bumping the minor version."),
    )
    for flag, help_text in release_flags:
        release_kind.add_argument(flag, action="store_true", help=help_text)

    program_args = argv[1:]
    return cli.parse_args(program_args)


def main(argv: list[str]) -> int:
    """Compute the release version and, unless ``--dry-run`` is set, publish it.

    Returns 0 on success (a dry run counts as success). Returns 1 after
    printing any ``ReleaseError`` to stderr.
    """
    args = parse_args(argv)
    try:
        version = determine_version(args)
        print(f"Publishing version {version}")
        if not args.dry_run:
            _publish_version(version)
    except ReleaseError as error:
        print(f"ERROR: {error}", file=sys.stderr)
        return 1
    return 0


def _publish_version(version: str) -> None:
    """Commit the Cargo.toml version bump on top of the branch head, tag it, and push the tag ref."""
    print("Fetching branch head...")
    base_commit = get_branch_head()
    print(f"Base commit: {base_commit}")

    print("Fetching commit tree...")
    base_tree = get_commit_tree(base_commit)
    print(f"Base tree: {base_tree}")

    print("Fetching Cargo.toml...")
    manifest = fetch_file_contents(base_commit)
    print("Updating version...")
    bumped_manifest = replace_version(manifest, version)

    print("Creating blob...")
    blob_sha = create_blob(bumped_manifest)
    print(f"Blob SHA: {blob_sha}")

    print("Creating tree...")
    tree_sha = create_tree(base_tree, blob_sha)
    print(f"Tree SHA: {tree_sha}")

    print("Creating commit...")
    commit_sha = create_commit(version, tree_sha, base_commit)
    print(f"Commit SHA: {commit_sha}")

    print("Creating tag...")
    tag_sha = create_tag(version, commit_sha)
    print(f"Tag SHA: {tag_sha}")

    print("Creating tag ref...")
    create_tag_ref(version, tag_sha)
    print("Done.")


class ReleaseError(RuntimeError):
    """Signals that a step of computing or publishing the release failed.

    ``main`` catches it, prints the message to stderr and returns exit code 1.
    """


def run_gh_api(
    endpoint: str, *, method: str = "GET", payload: dict | None = None
) -> dict | list:
    """Invoke ``gh api`` and return the decoded JSON response.

    Args:
        endpoint: API path appended to ``gh api``, e.g. ``/repos/owner/name/...``.
        method: HTTP method passed via ``--method``.
        payload: Optional request body, serialized to JSON and sent on stdin.

    Returns:
        The parsed JSON body. This is an object for most endpoints and an array
        for list endpoints such as releases. An empty body decodes to ``{}``.

    Raises:
        ReleaseError: if ``gh`` cannot be started, exits non-zero, or prints
            output that is not valid JSON.
    """
    print(f"Running gh api {method} {endpoint}")
    command = [
        "gh",
        "api",
        endpoint,
        "--method",
        method,
        "-H",
        "Accept: application/vnd.github+json",
    ]
    json_payload = None
    if payload is not None:
        json_payload = json.dumps(payload)
        print(f"Payload: {json_payload}")
        # "--input -" tells gh to read the request body from stdin.
        command.extend(["-H", "Content-Type: application/json", "--input", "-"])
    try:
        result = subprocess.run(command, text=True, capture_output=True, input=json_payload)
    except OSError as error:
        # e.g. the gh CLI is not installed or not on PATH; report it like any
        # other release failure instead of crashing with a traceback.
        raise ReleaseError(f"Failed to run gh: {error}") from error
    if result.returncode != 0:
        message = result.stderr.strip() or result.stdout.strip() or "gh api call failed"
        raise ReleaseError(message)
    try:
        return json.loads(result.stdout.strip() or "{}")
    except json.JSONDecodeError as error:
        raise ReleaseError("Failed to parse response from gh api.") from error


def get_branch_head() -> str:
    """Return the commit SHA that ``BRANCH_REF`` currently points at.

    Raises:
        ReleaseError: if the API call fails or the response has no ``object.sha``.
    """
    # The singular ``git/ref/`` route is GitHub's "get a single reference"
    # endpoint. The plural ``git/refs/<ref>`` form may answer with a list of
    # prefix matches instead of one object. Catching TypeError below also covers
    # any non-object response.
    response = run_gh_api(f"/repos/{REPO}/git/ref/{BRANCH_REF}")
    try:
        return response["object"]["sha"]
    except (KeyError, TypeError) as error:
        raise ReleaseError("Unable to determine branch head.") from error


def get_commit_tree(commit_sha: str) -> str:
    """Return the SHA of the root tree recorded in commit ``commit_sha``.

    Raises:
        ReleaseError: if the API call fails or the response has no ``tree.sha``.
    """
    commit = run_gh_api(f"/repos/{REPO}/git/commits/{commit_sha}")
    try:
        tree = commit["tree"]
        return tree["sha"]
    except KeyError as error:
        raise ReleaseError("Commit response missing tree SHA.") from error


def fetch_file_contents(ref_sha: str) -> str:
    """Download ``CARGO_TOML_PATH`` as of ``ref_sha`` and return it as UTF-8 text.

    Newlines are stripped from the response's ``content`` field before it is
    base64-decoded.

    Raises:
        ReleaseError: if the API call fails, ``content`` is missing, the
            reported encoding is not ``base64``, or decoding fails.
    """
    endpoint = f"/repos/{REPO}/contents/{CARGO_TOML_PATH}?ref={ref_sha}"
    response = run_gh_api(endpoint)
    try:
        raw_content = response["content"]
    except KeyError as error:
        raise ReleaseError("Failed to fetch Cargo.toml contents.") from error
    encoded_content = raw_content.replace("\n", "")
    encoding = response.get("encoding", "")

    if encoding != "base64":
        raise ReleaseError(f"Unexpected Cargo.toml encoding: {encoding}")

    try:
        decoded_bytes = base64.b64decode(encoded_content)
        return decoded_bytes.decode("utf-8")
    except (ValueError, UnicodeDecodeError) as error:
        raise ReleaseError("Failed to decode Cargo.toml contents.") from error


def replace_version(contents: str, version: str) -> str:
    """Return ``contents`` with the package version set to ``version``.

    Only a ``version = "..."`` line inside a ``[workspace.package]`` or
    ``[package]`` table is rewritten. The first such table, in file order, that
    contains one is used. A ``version`` key belonging to any other table (e.g. a
    ``[workspace.dependencies.<name>]`` entry) is never touched. The original
    spacing around ``=`` and the line endings are preserved.

    Raises:
        ReleaseError: if no package table contains a literal version string.
    """
    # Table header line, optionally followed by a comment and/or a CR (CRLF files).
    header_pattern = re.compile(
        r"^\[(?:workspace\.)?package\][ \t]*(?:#.*)?\r?$", re.MULTILINE
    )
    # The next table header ("[x]" or "[[x]]") ends the current table.
    next_table_pattern = re.compile(r"^\[", re.MULTILINE)
    version_pattern = re.compile(r'^(version[ \t]*=[ \t]*)"[^"]+"', re.MULTILINE)

    for header in header_pattern.finditer(contents):
        section_start = header.end()
        next_table = next_table_pattern.search(contents, section_start)
        section_end = next_table.start() if next_table else len(contents)
        section = contents[section_start:section_end]
        # A callable replacement inserts ``version`` literally. A template
        # string would interpret backslashes and group references.
        updated_section, matches = version_pattern.subn(
            lambda match: f'{match.group(1)}"{version}"', section, count=1
        )
        if matches:
            return contents[:section_start] + updated_section + contents[section_end:]
    raise ReleaseError("Unable to update version in Cargo.toml.")


def create_blob(content: str) -> str:
    """Upload ``content`` as a new Git blob (sent as UTF-8) and return its SHA.

    Raises:
        ReleaseError: if the API call fails or the response has no ``sha``.
    """
    blob_request = {"content": content, "encoding": "utf-8"}
    created = run_gh_api(f"/repos/{REPO}/git/blobs", method="POST", payload=blob_request)
    try:
        blob_sha = created["sha"]
    except KeyError as error:
        raise ReleaseError("Blob creation response missing SHA.") from error
    return blob_sha


def create_tree(base_tree_sha: str, blob_sha: str) -> str:
    """Create a tree based on ``base_tree_sha`` whose ``CARGO_TOML_PATH`` entry is ``blob_sha``.

    Returns the new tree's SHA.

    Raises:
        ReleaseError: if the API call fails or the response has no ``sha``.
    """
    cargo_entry = {
        "path": CARGO_TOML_PATH,
        "mode": "100644",  # Git mode for a regular, non-executable file
        "type": "blob",
        "sha": blob_sha,
    }
    tree_request = {"base_tree": base_tree_sha, "tree": [cargo_entry]}
    created = run_gh_api(f"/repos/{REPO}/git/trees", method="POST", payload=tree_request)
    try:
        tree_sha = created["sha"]
    except KeyError as error:
        raise ReleaseError("Tree creation response missing SHA.") from error
    return tree_sha


def create_commit(version: str, tree_sha: str, parent_sha: str) -> str:
    """Create a ``Release <version>`` commit of ``tree_sha`` with ``parent_sha`` as sole parent.

    Only the commit object is created; this function moves no branch ref.
    Returns the new commit's SHA.

    Raises:
        ReleaseError: if the API call fails or the response has no ``sha``.
    """
    commit_request = {
        "message": f"Release {version}",
        "tree": tree_sha,
        "parents": [parent_sha],
    }
    created = run_gh_api(f"/repos/{REPO}/git/commits", method="POST", payload=commit_request)
    try:
        commit_sha = created["sha"]
    except KeyError as error:
        raise ReleaseError("Commit creation response missing SHA.") from error
    return commit_sha


def create_tag(version: str, commit_sha: str) -> str:
    """Create a ``rust-v<version>`` tag object pointing at ``commit_sha``.

    This creates only the tag object, not the ``refs/tags/...`` ref.
    Returns the tag object's SHA.

    Raises:
        ReleaseError: if the API call fails or the response has no ``sha``.
    """
    tag_request = {
        "tag": f"rust-v{version}",
        "message": f"Release {version}",
        "object": commit_sha,
        "type": "commit",
    }
    created = run_gh_api(f"/repos/{REPO}/git/tags", method="POST", payload=tag_request)
    try:
        tag_sha = created["sha"]
    except KeyError as error:
        raise ReleaseError("Tag creation response missing SHA.") from error
    return tag_sha


def create_tag_ref(version: str, tag_sha: str) -> None:
    """Create the ref ``refs/tags/rust-v<version>`` pointing at tag object ``tag_sha``.

    Raises:
        ReleaseError: if the API call fails.
    """
    ref_request = dict(ref=f"refs/tags/rust-v{version}", sha=tag_sha)
    run_gh_api(f"/repos/{REPO}/git/refs", method="POST", payload=ref_request)


def determine_version(args: argparse.Namespace) -> str:
    """Compute the version to publish from the latest GitHub release.

    Both modes target the next minor version after the latest release, i.e.
    ``X.Y.Z`` becomes ``X.(Y+1).0``:

    * ``--publish-release`` returns that version unchanged.
    * ``--publish-alpha`` returns ``X.(Y+1).0-alpha.N``. ``N`` is one more than
      the highest numeric alpha suffix among ``list_releases()`` tags for that
      version, or ``1`` if there is none.

    Raises:
        ReleaseError: if a GitHub lookup fails or the latest release tag is not
            a plain ``major.minor.patch`` version.
    """
    latest_version = get_latest_release_version()
    major, minor, _patch = parse_semver(latest_version)
    # Bumping the minor version resets the patch number (SemVer 2.0.0, item 7).
    # Otherwise a latest release of 0.5.1 would yield 0.6.1 instead of 0.6.0.
    next_minor_version = format_version(major, minor + 1, 0)

    if args.publish_release:
        return next_minor_version

    alpha_prefix = f"{next_minor_version}-alpha."
    highest_alpha = 0  # 0 means "no alpha yet", so the first one becomes .1
    for release in list_releases():
        candidate = strip_tag_prefix(release.get("tag_name"))
        if not candidate or not candidate.startswith(alpha_prefix):
            continue
        suffix = candidate[len(alpha_prefix) :]
        # Accept only plain ASCII digits; int() would also accept "+1", " 1" or "1_0".
        if re.fullmatch(r"[0-9]+", suffix):
            highest_alpha = max(highest_alpha, int(suffix))
    return f"{alpha_prefix}{highest_alpha + 1}"


def get_latest_release_version() -> str:
    """Return the version of the repository's latest release (its tag minus ``rust-v``).

    Raises:
        ReleaseError: if the API call fails, or the tag is missing, lacks the
            ``rust-v`` prefix, or has nothing after it.
    """
    latest = run_gh_api(f"/repos/{REPO}/releases/latest")
    version = strip_tag_prefix(latest.get("tag_name"))
    if version:
        return version
    raise ReleaseError("Latest release tag has unexpected format.")


def list_releases() -> list[dict]:
    """Return the releases from a single ``per_page=100`` request (no further pages).

    Raises:
        ReleaseError: if the API call fails or the response is not a JSON array.
    """
    # NOTE(review): no pagination, so releases beyond the first page of 100 are
    # not seen by callers. Confirm that this is acceptable for alpha numbering.
    releases = run_gh_api(f"/repos/{REPO}/releases?per_page=100")
    if isinstance(releases, list):
        return releases
    raise ReleaseError("Unexpected response when listing releases.")


def strip_tag_prefix(tag: str | None) -> str | None:
    if not tag:
        return None
    prefix = "rust-v"
    if not tag.startswith(prefix):
        return None
    return tag[len(prefix) :]


def parse_semver(version: str) -> tuple[int, int, int]:
    """Split ``"major.minor.patch"`` into a tuple of three ints.

    Raises:
        ReleaseError: if there are not exactly three dot-separated parts, or a
            part cannot be converted with ``int()``.
    """
    components = version.split(".")
    if len(components) == 3:
        try:
            major, minor, patch = map(int, components)
        except ValueError as error:
            raise ReleaseError(f"Version components must be integers: {version}") from error
        return major, minor, patch
    raise ReleaseError(f"Unexpected version format: {version}")


def format_version(major: int, minor: int, patch: int) -> str:
    """Join the three version components as ``"major.minor.patch"``."""
    return ".".join(str(part) for part in (major, minor, patch))


if __name__ == "__main__":
    # main() returns the process exit status (0 on success, 1 on ReleaseError).
    raise SystemExit(main(sys.argv))
