#!/usr/bin/env python3

import argparse
import base64
import json
import re
import subprocess
import sys


# GitHub repository (owner/name) that every `gh api` call in this script targets.
REPO = "openai/codex"
# Ref (relative to refs/) whose head is used as the base commit when not promoting an alpha.
BRANCH_REF = "heads/main"
# Repository path of the Cargo manifest whose `version = "..."` line is rewritten for a release.
CARGO_TOML_PATH = "codex-rs/Cargo.toml"


def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse the release script's command line.

    ``argv`` is a full ``sys.argv``-style list; element 0 (the program name)
    is skipped. At least one release mode must be chosen, otherwise the
    parser reports an error and exits.
    """
    parser = argparse.ArgumentParser(description="Publish a tagged Codex release.")
    parser.add_argument(
        "-n",
        "--dry-run",
        action="store_true",
        help="Print the version that would be used and exit before making changes.",
    )
    parser.add_argument(
        "--promote-alpha",
        metavar="VERSION",
        help="Promote an existing alpha tag (e.g., 0.56.0-alpha.5) by using its merge-base with main as the base commit.",
    )

    # Only these two flags are mutually exclusive with each other.
    exclusive_modes = parser.add_mutually_exclusive_group()
    for flag, help_text in (
        ("--publish-alpha", "Publish the next alpha release for the upcoming minor version."),
        ("--publish-release", "Publish the next stable release by bumping the minor version."),
    ):
        exclusive_modes.add_argument(flag, action="store_true", help=help_text)

    parser.add_argument(
        "--emergency-version-override",
        help="Publish a specific version because tag was created for the previous release but it never succeeded. Value should be semver, e.g., `0.43.0-alpha.9`.",
    )

    parsed = parser.parse_args(argv[1:])
    selected_modes = (
        parsed.publish_alpha,
        parsed.publish_release,
        parsed.emergency_version_override,
        parsed.promote_alpha,
    )
    if not any(selected_modes):
        parser.error(
            "Must specify --publish-alpha, --publish-release, --promote-alpha, or --emergency-version-override."
        )
    return parsed


def main(argv: list[str]) -> int:
    """Run the release flow and return the process exit status.

    The version to publish comes from one of three sources:
      * ``--promote-alpha X.Y.Z-alpha.N``: release ``X.Y.Z`` based on the merge
        base of that alpha tag with main;
      * ``--emergency-version-override``: use the given version verbatim;
      * otherwise the next alpha / stable version computed from existing
        releases.

    Unless ``--dry-run`` is set, it then creates a commit on top of the base
    commit whose Cargo.toml carries the new version, an annotated tag
    ``rust-v<version>`` for that commit, and the tag ref.

    Returns:
        0 on success or after a dry run; 1 if a ReleaseError was raised.
    """
    args = parse_args(argv)

    # Accept tag-style input ("v0.56.0-alpha.5") as well as bare versions.
    promote_alpha = args.promote_alpha
    if promote_alpha and promote_alpha.startswith("v"):
        promote_alpha = promote_alpha[1:]
    emergency_version = args.emergency_version_override
    if emergency_version and emergency_version.startswith("v"):
        # The tag name already adds "rust-v" and the version is written
        # verbatim into Cargo.toml, so a leading "v" here would yield
        # "rust-vv..." and a non-numeric Cargo version.
        emergency_version = emergency_version[1:]

    try:
        if promote_alpha:
            version = derive_release_version_from_alpha(promote_alpha)
        elif args.emergency_version_override:
            version = emergency_version
        else:
            version = determine_version(args)
        print(f"Publishing version {version}")

        if promote_alpha:
            # Resolve the base commit even for a dry run so it can be reported.
            base_commit = get_promote_alpha_base_commit(promote_alpha)
            if args.dry_run:
                print(
                    f"Would publish version {version} using base commit {base_commit} derived from rust-v{promote_alpha}."
                )
                return 0
        else:
            if args.dry_run:
                return 0
            print("Fetching branch head...")
            base_commit = get_branch_head()

        print(f"Base commit: {base_commit}")
        print("Fetching commit tree...")
        base_tree = get_commit_tree(base_commit)
        print(f"Base tree: {base_tree}")
        print("Fetching Cargo.toml...")
        current_contents = fetch_file_contents(base_commit)
        print("Updating version...")
        updated_contents = replace_version(current_contents, version)
        print("Creating blob...")
        blob_sha = create_blob(updated_contents)
        print(f"Blob SHA: {blob_sha}")
        print("Creating tree...")
        tree_sha = create_tree(base_tree, blob_sha)
        print(f"Tree SHA: {tree_sha}")
        print("Creating commit...")
        commit_sha = create_commit(version, tree_sha, base_commit)
        print(f"Commit SHA: {commit_sha}")
        print("Creating tag...")
        tag_sha = create_tag(version, commit_sha)
        print(f"Tag SHA: {tag_sha}")
        print("Creating tag ref...")
        create_tag_ref(version, tag_sha)
        print("Done.")
    except ReleaseError as error:
        print(f"ERROR: {error}", file=sys.stderr)
        return 1
    return 0


class ReleaseError(RuntimeError):
    """Failure in a release step; main() prints it as ``ERROR: ...`` and exits with status 1."""


def run_gh_api(endpoint: str, *, method: str = "GET", payload: dict | None = None) -> dict | list:
    """Invoke ``gh api`` for ``endpoint`` and return the decoded JSON body.

    Args:
        endpoint: REST path, optionally with a query string
            (e.g. ``/repos/owner/repo/releases?per_page=100``).
        method: HTTP method passed to ``gh api --method``.
        payload: Optional JSON body. When given, it is serialized and piped
            to ``gh`` on stdin via ``--input -``.

    Returns:
        The parsed JSON. Most endpoints yield an object (``dict``); list
        endpoints such as ``/releases`` yield an array (``list``). Empty
        stdout is treated as ``{}``.

    Raises:
        ReleaseError: if ``gh`` cannot be started, exits non-zero, or prints
            output that is not valid JSON.
    """
    print(f"Running gh api {method} {endpoint}")
    command = [
        "gh",
        "api",
        endpoint,
        "--method",
        method,
        "-H",
        "Accept: application/vnd.github+json",
    ]
    json_payload = None
    if payload is not None:
        json_payload = json.dumps(payload)
        print(f"Payload: {json_payload}")
        command.extend(["-H", "Content-Type: application/json", "--input", "-"])
    try:
        result = subprocess.run(command, text=True, capture_output=True, input=json_payload)
    except OSError as error:
        # E.g. FileNotFoundError when the GitHub CLI is not installed or not on
        # PATH. Convert it so main() reports a clean error instead of a traceback.
        raise ReleaseError(f"Failed to run gh: {error}") from error
    if result.returncode != 0:
        message = result.stderr.strip() or result.stdout.strip() or "gh api call failed"
        raise ReleaseError(message)
    try:
        return json.loads(result.stdout or "{}")
    except json.JSONDecodeError as error:
        raise ReleaseError("Failed to parse response from gh api.") from error


def get_branch_head() -> str:
    """Return the commit SHA that ``BRANCH_REF`` currently points to.

    Raises:
        ReleaseError: if the ``gh`` call fails or the response does not have
            the ``{"object": {"sha": ...}}`` shape.
    """
    response = run_gh_api(f"/repos/{REPO}/git/refs/{BRANCH_REF}")
    try:
        return response["object"]["sha"]
    except (KeyError, TypeError) as error:
        # TypeError covers a response that is not a JSON object (run_gh_api
        # can return lists or null). Without it, such a response would escape
        # main()'s ReleaseError handler as a traceback.
        raise ReleaseError("Unable to determine branch head.") from error


def get_promote_alpha_base_commit(alpha_version: str) -> str:
    """Return the base commit for promoting ``alpha_version`` to a stable release.

    This is the merge base between main and the commit tagged
    ``rust-v<alpha_version>``.
    """
    alpha_commit = get_tag_commit_sha(f"rust-v{alpha_version}")
    return get_merge_base_with_main(alpha_commit)


def get_tag_commit_sha(tag_name: str) -> str:
    """Resolve ``tag_name`` to the SHA of the commit it ultimately references.

    A lightweight tag ref points straight at a commit. An annotated tag ref
    points at a tag object, which is followed repeatedly (tags may point at
    other tags) until a non-tag object is reached.

    Raises:
        ReleaseError: if a ``gh`` call fails, a response has an unexpected
            shape, or the chain ends at something other than a commit.
    """
    response = run_gh_api(f"/repos/{REPO}/git/refs/tags/{tag_name}")
    try:
        sha = response["object"]["sha"]
        obj_type = response["object"]["type"]
    except (KeyError, TypeError) as error:
        # TypeError: the response was not a single JSON object (e.g. a list).
        # Presumably this is the refs endpoint returning prefix matches instead
        # of an exact ref; TODO confirm against the GitHub API docs.
        raise ReleaseError(f"Unable to resolve tag {tag_name}.") from error
    while obj_type == "tag":
        tag_response = run_gh_api(f"/repos/{REPO}/git/tags/{sha}")
        try:
            sha = tag_response["object"]["sha"]
            obj_type = tag_response["object"]["type"]
        except (KeyError, TypeError) as error:
            raise ReleaseError(f"Unable to resolve annotated tag {tag_name}.") from error
    if obj_type != "commit":
        raise ReleaseError(f"Tag {tag_name} does not reference a commit.")
    return sha


def get_merge_base_with_main(commit_sha: str) -> str:
    """Return the merge-base commit SHA of main and ``commit_sha`` (via the compare API)."""
    comparison = run_gh_api(f"/repos/{REPO}/compare/main...{commit_sha}")
    try:
        merge_base = comparison["merge_base_commit"]
        return merge_base["sha"]
    except KeyError as error:
        raise ReleaseError("Unable to determine merge base with main.") from error


def get_commit_tree(commit_sha: str) -> str:
    """Return the SHA of the root tree recorded in commit ``commit_sha``."""
    commit = run_gh_api(f"/repos/{REPO}/git/commits/{commit_sha}")
    try:
        tree_sha = commit["tree"]["sha"]
    except KeyError as error:
        raise ReleaseError("Commit response missing tree SHA.") from error
    return tree_sha


def fetch_file_contents(ref_sha: str) -> str:
    """Fetch ``CARGO_TOML_PATH`` at ``ref_sha`` and return it as UTF-8 text.

    The contents API returns the file base64-encoded, with embedded newlines
    that are removed before decoding.

    Raises:
        ReleaseError: if the response has no content, uses an encoding other
            than base64, or cannot be decoded.
    """
    endpoint = f"/repos/{REPO}/contents/{CARGO_TOML_PATH}?ref={ref_sha}"
    response = run_gh_api(endpoint)
    try:
        raw_content = response["content"]
    except KeyError as error:
        raise ReleaseError("Failed to fetch Cargo.toml contents.") from error
    encoded_content = raw_content.replace("\n", "")
    encoding = response.get("encoding", "")

    if encoding != "base64":
        raise ReleaseError(f"Unexpected Cargo.toml encoding: {encoding}")

    try:
        decoded_bytes = base64.b64decode(encoded_content)
        return decoded_bytes.decode("utf-8")
    except (ValueError, UnicodeDecodeError) as error:
        raise ReleaseError("Failed to decode Cargo.toml contents.") from error


def replace_version(contents: str, version: str) -> str:
    """Return ``contents`` with its first ``version = "..."`` line set to ``version``.

    Only a line that *begins* with ``version = "`` is considered
    (``re.MULTILINE`` makes ``^`` match at every line start), and only the
    first such line is rewritten.

    Raises:
        ReleaseError: if no such line exists.
    """
    replacement = f'version = "{version}"'
    # Use a callable so that ``version`` (which may come straight from the
    # command line) is inserted literally. With a plain replacement string,
    # backslash sequences such as ``\1`` or ``\g<0>`` would be interpreted as
    # group references.
    updated, matches = re.subn(
        r'^version = "[^"]+"',
        lambda _match: replacement,
        contents,
        count=1,
        flags=re.MULTILINE,
    )
    if matches != 1:
        raise ReleaseError("Unable to update version in Cargo.toml.")
    return updated


def create_blob(content: str) -> str:
    """Upload ``content`` as a UTF-8 git blob and return the new blob's SHA."""
    blob_payload = {"content": content, "encoding": "utf-8"}
    response = run_gh_api(f"/repos/{REPO}/git/blobs", method="POST", payload=blob_payload)
    try:
        blob_sha = response["sha"]
    except KeyError as error:
        raise ReleaseError("Blob creation response missing SHA.") from error
    return blob_sha


def create_tree(base_tree_sha: str, blob_sha: str) -> str:
    """Create a tree equal to ``base_tree_sha`` but with Cargo.toml set to ``blob_sha``.

    Returns the new tree's SHA.
    """
    cargo_entry = {
        "path": CARGO_TOML_PATH,
        "mode": "100644",
        "type": "blob",
        "sha": blob_sha,
    }
    tree_payload = {"base_tree": base_tree_sha, "tree": [cargo_entry]}
    response = run_gh_api(f"/repos/{REPO}/git/trees", method="POST", payload=tree_payload)
    try:
        new_tree_sha = response["sha"]
    except KeyError as error:
        raise ReleaseError("Tree creation response missing SHA.") from error
    return new_tree_sha


def create_commit(version: str, tree_sha: str, parent_sha: str) -> str:
    """Create a commit of ``tree_sha`` with the single parent ``parent_sha``.

    The message is ``Release <version>``. Returns the new commit's SHA.
    """
    commit_payload = {
        "message": f"Release {version}",
        "tree": tree_sha,
        "parents": [parent_sha],
    }
    response = run_gh_api(f"/repos/{REPO}/git/commits", method="POST", payload=commit_payload)
    try:
        new_commit_sha = response["sha"]
    except KeyError as error:
        raise ReleaseError("Commit creation response missing SHA.") from error
    return new_commit_sha


def create_tag(version: str, commit_sha: str) -> str:
    """Create an annotated tag object ``rust-v<version>`` pointing at ``commit_sha``.

    Only the tag *object* is created here; no ``refs/tags/...`` ref is written.
    Returns the tag object's SHA.
    """
    tag_payload = {
        "tag": f"rust-v{version}",
        "message": f"Release {version}",
        "object": commit_sha,
        "type": "commit",
    }
    response = run_gh_api(f"/repos/{REPO}/git/tags", method="POST", payload=tag_payload)
    try:
        tag_sha = response["sha"]
    except KeyError as error:
        raise ReleaseError("Tag creation response missing SHA.") from error
    return tag_sha


def create_tag_ref(version: str, tag_sha: str) -> None:
    """Create the ref ``refs/tags/rust-v<version>`` pointing at tag object ``tag_sha``."""
    ref_payload = {"ref": f"refs/tags/rust-v{version}", "sha": tag_sha}
    run_gh_api(f"/repos/{REPO}/git/refs", method="POST", payload=ref_payload)


def determine_version(args: argparse.Namespace) -> str:
    """Compute the next version from the latest release.

    The next version always bumps the minor component and resets the patch
    to zero. With ``args.publish_release`` that stable version is returned.
    Otherwise the result is ``<next>-alpha.<k>``, where ``k`` is one more than
    the highest alpha number among listed releases for ``<next>`` (floored at
    0, so the first alpha is ``.1``).
    """
    # A non-zero patch should only exist when --emergency-version-override
    # was used; it is discarded either way.
    major, minor, _ = parse_semver(get_latest_release_version())
    next_minor_version = format_version(major, minor + 1, 0)
    if args.publish_release:
        return next_minor_version

    alpha_prefix = f"{next_minor_version}-alpha."
    alpha_numbers = []
    for release in list_releases():
        candidate = strip_tag_prefix(release.get("tag_name", ""))
        if not candidate or not candidate.startswith(alpha_prefix):
            continue
        try:
            alpha_numbers.append(int(candidate[len(alpha_prefix) :]))
        except ValueError:
            # Suffix is not an integer (e.g. "alpha.x"); ignore this release.
            continue

    next_alpha = max([0, *alpha_numbers]) + 1
    return f"{alpha_prefix}{next_alpha}"


def get_latest_release_version() -> str:
    """Return the version of the release GitHub reports as latest, without the ``rust-v`` prefix.

    Raises:
        ReleaseError: if the latest release's tag is missing or lacks the prefix.
    """
    latest = run_gh_api(f"/repos/{REPO}/releases/latest")
    version = strip_tag_prefix(latest.get("tag_name"))
    if version:
        return version
    raise ReleaseError("Latest release tag has unexpected format.")


def list_releases() -> list[dict]:
    """Return the repository's releases as a list of release objects.

    Only a single page of at most 100 releases is requested (``per_page=100``).

    Raises:
        ReleaseError: if the response is not a JSON array.
    """
    releases = run_gh_api(f"/repos/{REPO}/releases?per_page=100")
    if isinstance(releases, list):
        return releases
    raise ReleaseError("Unexpected response when listing releases.")


def strip_tag_prefix(tag: str | None) -> str | None:
    if not tag:
        return None
    prefix = "rust-v"
    if not tag.startswith(prefix):
        return None
    return tag[len(prefix) :]


def parse_semver(version: str) -> tuple[int, int, int]:
    """Split a plain ``MAJOR.MINOR.PATCH`` string into three integers.

    Raises:
        ReleaseError: if there are not exactly three dot-separated parts, or
            any part is not an integer.
    """
    components = version.split(".")
    if len(components) != 3:
        raise ReleaseError(f"Unexpected version format: {version}")
    try:
        major, minor, patch = map(int, components)
    except ValueError as error:
        raise ReleaseError(f"Version components must be integers: {version}") from error
    return major, minor, patch


def format_version(major: int, minor: int, patch: int) -> str:
    """Join the three version components as ``MAJOR.MINOR.PATCH``."""
    return ".".join(str(component) for component in (major, minor, patch))


def derive_release_version_from_alpha(alpha_version: str) -> str:
    """Map an alpha version ``X.Y.Z-alpha.N`` to its stable version ``X.Y.Z``.

    The whole string must match, using ASCII digits only.

    Raises:
        ReleaseError: if ``alpha_version`` is not of that form.
    """
    # fullmatch rather than match(r"^...$"): "$" also matches just before a
    # trailing "\n", so "0.56.0-alpha.5\n" used to be accepted and later
    # produced a broken tag name. re.ASCII keeps \d from matching non-ASCII
    # digits, which would otherwise flow into the tag and Cargo.toml.
    match = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)-alpha\.(\d+)", alpha_version, flags=re.ASCII)
    if match is None:
        raise ReleaseError(f"Unexpected alpha version format: {alpha_version}")
    return ".".join(match.group(1, 2, 3))


if __name__ == "__main__":
    # main() returns the exit status; SystemExit propagates it to the shell.
    raise SystemExit(main(sys.argv))
