Files
codex/sdk/python/src/openai_codex/client.py
Ahmed Ibrahim 4c89772314 sdk/python: add first-class login support (#23093)
## Why

The Python SDK can already create threads and run turns, but
authentication still has to be arranged outside the SDK. App-server
already exposes account login, account inspection, logout, and
`account/login/completed` notifications, so SDK users currently have to
work around a missing public client layer for a core setup step.

This change makes authentication a normal SDK workflow while preserving
the backend flow shape: API-key login completes immediately, and
interactive ChatGPT flows return live handles that complete later
through app-server notifications.

## What changed

- Added public sync and async auth methods on `Codex` / `AsyncCodex`:
  - `login_api_key(...)`
  - `login_chatgpt()`
  - `login_chatgpt_device_code()`
  - `account(...)`
  - `logout()`
- Added public browser-login and device-code handle types with
attempt-local `wait()` and `cancel()` helpers. Cancellation stays on the
handle instead of a root-level SDK method.
- Extended the Python app-server client and notification router so login
completion events are routed by `login_id` without consuming unrelated
global notifications.
- Kept login request/handle logic in a focused internal `_login.py`
module so `api.py` remains the public facade instead of absorbing more
auth plumbing.
- Exported the new handle types plus curated account/login response
types from the SDK surfaces.
- Updated SDK docs, added sync/async login walkthrough examples, and
added a notebook login walkthrough cell.

## Verification

Added SDK coverage for:

- API-key login, account readback, and logout through the app-server
harness in both sync and async clients.
- Browser login cancellation plus `handle.wait()` completion through the
real app-server boundary used by the Python SDK harness.
- Waiter routing that stays scoped across replaced interactive login
attempts, plus async handle cancellation coverage.
- Login notification demuxing, replay of early completion events, and
async client delegation.
- Public export/signature assertions.
- Real integration-suite smoke coverage for the new examples and
notebook login cell.
2026-05-16 19:49:28 -07:00

648 lines
22 KiB
Python

from __future__ import annotations
import json
import os
import subprocess
import threading
import uuid
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterator, TypeVar
from pydantic import BaseModel
from ._message_router import MessageRouter
from ._version import __version__ as SDK_VERSION
from .errors import AppServerError, TransportClosedError
from .generated.notification_registry import NOTIFICATION_MODELS
from .generated.v2_all import (
AccountLoginCompletedNotification,
AgentMessageDeltaNotification,
CancelLoginAccountResponse,
ChatgptDeviceCodeLoginAccountResponse,
ChatgptLoginAccountResponse,
GetAccountParams as V2GetAccountParams,
GetAccountResponse,
LoginAccountParams as V2LoginAccountParams,
LoginAccountResponse,
LogoutAccountResponse,
ModelListResponse,
ThreadArchiveResponse,
ThreadCompactStartResponse,
ThreadForkParams as V2ThreadForkParams,
ThreadForkResponse,
ThreadListParams as V2ThreadListParams,
ThreadListResponse,
ThreadReadResponse,
ThreadResumeParams as V2ThreadResumeParams,
ThreadResumeResponse,
ThreadSetNameResponse,
ThreadStartParams as V2ThreadStartParams,
ThreadStartResponse,
ThreadUnarchiveResponse,
TurnCompletedNotification,
TurnInterruptResponse,
TurnStartParams as V2TurnStartParams,
TurnStartResponse,
TurnSteerResponse,
)
from .models import (
InitializeResponse,
JsonObject,
JsonValue,
Notification,
UnknownNotification,
)
from .retry import retry_on_overload
ModelT = TypeVar("ModelT", bound=BaseModel)
ApprovalHandler = Callable[[str, JsonObject | None], JsonObject]
RUNTIME_PKG_NAME = "openai-codex-cli-bin"
def _params_dict(
params: (
V2ThreadStartParams
| V2ThreadResumeParams
| V2ThreadListParams
| V2ThreadForkParams
| V2TurnStartParams
| V2GetAccountParams
| V2LoginAccountParams
| JsonObject
| None
),
) -> JsonObject:
if params is None:
return {}
if hasattr(params, "model_dump"):
dumped = params.model_dump(
by_alias=True,
exclude_none=True,
mode="json",
)
if not isinstance(dumped, dict):
raise TypeError("Expected model_dump() to return dict")
return dumped
if isinstance(params, dict):
return params
raise TypeError(f"Expected generated params model or dict, got {type(params).__name__}")
def _installed_codex_path() -> Path:
try:
from codex_cli_bin import bundled_codex_path
except ImportError as exc:
raise FileNotFoundError(
"Unable to locate the pinned Codex runtime. Install the published SDK build "
f"with its {RUNTIME_PKG_NAME} dependency, or set AppServerConfig.codex_bin "
"explicitly."
) from exc
return bundled_codex_path()
@dataclass(frozen=True)
class CodexBinResolverOps:
installed_codex_path: Callable[[], Path]
path_exists: Callable[[Path], bool]
def _default_codex_bin_resolver_ops() -> CodexBinResolverOps:
return CodexBinResolverOps(
installed_codex_path=_installed_codex_path,
path_exists=lambda path: path.exists(),
)
def resolve_codex_bin(config: "AppServerConfig", ops: CodexBinResolverOps) -> Path:
if config.codex_bin is not None:
codex_bin = Path(config.codex_bin)
if not ops.path_exists(codex_bin):
raise FileNotFoundError(
f"Codex binary not found at {codex_bin}. Set AppServerConfig.codex_bin "
"to a valid binary path."
)
return codex_bin
return ops.installed_codex_path()
def _resolve_codex_bin(config: "AppServerConfig") -> Path:
return resolve_codex_bin(config, _default_codex_bin_resolver_ops())
@dataclass(slots=True)
class AppServerConfig:
codex_bin: str | None = None
launch_args_override: tuple[str, ...] | None = None
config_overrides: tuple[str, ...] = ()
cwd: str | None = None
env: dict[str, str] | None = None
client_name: str = "codex_python_sdk"
client_title: str = "Codex Python SDK"
client_version: str = SDK_VERSION
experimental_api: bool = True
class AppServerClient:
"""Synchronous typed JSON-RPC client for `codex app-server` over stdio."""
def __init__(
self,
config: AppServerConfig | None = None,
approval_handler: ApprovalHandler | None = None,
) -> None:
self.config = config or AppServerConfig()
self._approval_handler = approval_handler or self._default_approval_handler
self._proc: subprocess.Popen[str] | None = None
self._lock = threading.Lock()
self._router = MessageRouter()
self._stderr_lines: deque[str] = deque(maxlen=400)
self._stderr_thread: threading.Thread | None = None
self._reader_thread: threading.Thread | None = None
def __enter__(self) -> "AppServerClient":
self.start()
return self
def __exit__(self, _exc_type, _exc, _tb) -> None:
self.close()
def start(self) -> None:
if self._proc is not None:
return
if self.config.launch_args_override is not None:
args = list(self.config.launch_args_override)
else:
codex_bin = _resolve_codex_bin(self.config)
args = [str(codex_bin)]
for kv in self.config.config_overrides:
args.extend(["--config", kv])
args.extend(["app-server", "--listen", "stdio://"])
env = os.environ.copy()
if self.config.env:
env.update(self.config.env)
self._proc = subprocess.Popen(
args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding="utf-8",
cwd=self.config.cwd,
env=env,
bufsize=1,
)
self._start_stderr_drain_thread()
self._start_reader_thread()
def close(self) -> None:
if self._proc is None:
return
proc = self._proc
self._proc = None
if proc.stdin:
proc.stdin.close()
try:
proc.terminate()
proc.wait(timeout=2)
except Exception:
proc.kill()
if self._stderr_thread and self._stderr_thread.is_alive():
self._stderr_thread.join(timeout=0.5)
if self._reader_thread and self._reader_thread.is_alive():
self._reader_thread.join(timeout=0.5)
def initialize(self) -> InitializeResponse:
result = self.request(
"initialize",
{
"clientInfo": {
"name": self.config.client_name,
"title": self.config.client_title,
"version": self.config.client_version,
},
"capabilities": {
"experimentalApi": self.config.experimental_api,
},
},
response_model=InitializeResponse,
)
self.notify("initialized", None)
return result
def request(
self,
method: str,
params: JsonObject | None,
*,
response_model: type[ModelT],
) -> ModelT:
result = self._request_raw(method, params)
if not isinstance(result, dict):
raise AppServerError(f"{method} response must be a JSON object")
return response_model.model_validate(result)
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
"""Send a JSON-RPC request and wait for the reader thread to route its response."""
request_id = str(uuid.uuid4())
waiter = self._router.create_response_waiter(request_id)
try:
message: JsonObject = {"id": request_id, "method": method}
if params is not None:
message["params"] = params
self._write_message(message)
except BaseException:
self._router.discard_response_waiter(request_id)
raise
item = waiter.get()
if isinstance(item, BaseException):
raise item
return item
def notify(self, method: str, params: JsonObject | None = None) -> None:
"""Send a JSON-RPC notification without waiting for a response."""
message: JsonObject = {"method": method}
if params is not None:
message["params"] = params
self._write_message(message)
def next_notification(self) -> Notification:
"""Return the next notification that is not scoped to an active turn."""
return self._router.next_global_notification()
def register_login_notifications(self, login_id: str) -> None:
"""Start routing notifications for one interactive login attempt."""
self._router.register_login(login_id)
def unregister_login_notifications(self, login_id: str) -> None:
"""Stop routing notifications for one interactive login attempt."""
self._router.unregister_login(login_id)
def next_login_notification(self, login_id: str) -> Notification:
"""Return the next routed notification for the requested login id."""
return self._router.next_login_notification(login_id)
def register_turn_notifications(self, turn_id: str) -> None:
"""Start routing notifications for one turn into its dedicated queue."""
self._router.register_turn(turn_id)
def unregister_turn_notifications(self, turn_id: str) -> None:
"""Stop routing notifications for one turn into its dedicated queue."""
self._router.unregister_turn(turn_id)
def next_turn_notification(self, turn_id: str) -> Notification:
"""Return the next routed notification for the requested turn id."""
return self._router.next_turn_notification(turn_id)
def account_login_start(
self,
params: V2LoginAccountParams | JsonObject,
) -> LoginAccountResponse:
response = self.request(
"account/login/start",
_params_dict(params),
response_model=LoginAccountResponse,
)
response_root = response.root
if isinstance(
response_root,
ChatgptLoginAccountResponse | ChatgptDeviceCodeLoginAccountResponse,
):
self.register_login_notifications(response_root.login_id)
return response
def account_login_cancel(self, login_id: str) -> CancelLoginAccountResponse:
return self.request(
"account/login/cancel",
{"loginId": login_id},
response_model=CancelLoginAccountResponse,
)
def account_read(
self,
params: V2GetAccountParams | JsonObject | None = None,
) -> GetAccountResponse:
return self.request(
"account/read",
_params_dict(params),
response_model=GetAccountResponse,
)
def account_logout(self) -> LogoutAccountResponse:
return self.request("account/logout", None, response_model=LogoutAccountResponse)
def thread_start(
self, params: V2ThreadStartParams | JsonObject | None = None
) -> ThreadStartResponse:
return self.request(
"thread/start", _params_dict(params), response_model=ThreadStartResponse
)
def thread_resume(
self,
thread_id: str,
params: V2ThreadResumeParams | JsonObject | None = None,
) -> ThreadResumeResponse:
payload = {"threadId": thread_id, **_params_dict(params)}
return self.request("thread/resume", payload, response_model=ThreadResumeResponse)
def thread_list(
self, params: V2ThreadListParams | JsonObject | None = None
) -> ThreadListResponse:
return self.request("thread/list", _params_dict(params), response_model=ThreadListResponse)
def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
return self.request(
"thread/read",
{"threadId": thread_id, "includeTurns": include_turns},
response_model=ThreadReadResponse,
)
def thread_fork(
self,
thread_id: str,
params: V2ThreadForkParams | JsonObject | None = None,
) -> ThreadForkResponse:
payload = {"threadId": thread_id, **_params_dict(params)}
return self.request("thread/fork", payload, response_model=ThreadForkResponse)
def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:
return self.request(
"thread/archive",
{"threadId": thread_id},
response_model=ThreadArchiveResponse,
)
def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse:
return self.request(
"thread/unarchive",
{"threadId": thread_id},
response_model=ThreadUnarchiveResponse,
)
def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse:
return self.request(
"thread/name/set",
{"threadId": thread_id, "name": name},
response_model=ThreadSetNameResponse,
)
def thread_compact(self, thread_id: str) -> ThreadCompactStartResponse:
return self.request(
"thread/compact/start",
{"threadId": thread_id},
response_model=ThreadCompactStartResponse,
)
def turn_start(
self,
thread_id: str,
input_items: list[JsonObject] | JsonObject | str,
params: V2TurnStartParams | JsonObject | None = None,
) -> TurnStartResponse:
"""Start a turn and register its notification queue as early as possible."""
payload = {
**_params_dict(params),
"threadId": thread_id,
"input": self._normalize_input_items(input_items),
}
started = self.request("turn/start", payload, response_model=TurnStartResponse)
self.register_turn_notifications(started.turn.id)
return started
def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
return self.request(
"turn/interrupt",
{"threadId": thread_id, "turnId": turn_id},
response_model=TurnInterruptResponse,
)
def turn_steer(
self,
thread_id: str,
expected_turn_id: str,
input_items: list[JsonObject] | JsonObject | str,
) -> TurnSteerResponse:
return self.request(
"turn/steer",
{
"threadId": thread_id,
"expectedTurnId": expected_turn_id,
"input": self._normalize_input_items(input_items),
},
response_model=TurnSteerResponse,
)
def model_list(self, include_hidden: bool = False) -> ModelListResponse:
return self.request(
"model/list",
{"includeHidden": include_hidden},
response_model=ModelListResponse,
)
def request_with_retry_on_overload(
self,
method: str,
params: JsonObject | None,
*,
response_model: type[ModelT],
max_attempts: int = 3,
initial_delay_s: float = 0.25,
max_delay_s: float = 2.0,
) -> ModelT:
return retry_on_overload(
lambda: self.request(method, params, response_model=response_model),
max_attempts=max_attempts,
initial_delay_s=initial_delay_s,
max_delay_s=max_delay_s,
)
def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
"""Block on the routed turn stream until the matching completion arrives."""
self.register_turn_notifications(turn_id)
try:
while True:
notification = self.next_turn_notification(turn_id)
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
return notification.payload
finally:
self.unregister_turn_notifications(turn_id)
def wait_for_login_completed(
self,
login_id: str,
) -> AccountLoginCompletedNotification:
"""Block until the matching interactive login attempt completes."""
self.register_login_notifications(login_id)
try:
while True:
notification = self.next_login_notification(login_id)
if (
notification.method == "account/login/completed"
and isinstance(notification.payload, AccountLoginCompletedNotification)
and notification.payload.login_id == login_id
):
return notification.payload
finally:
self.unregister_login_notifications(login_id)
def stream_text(
self,
thread_id: str,
text: str,
params: V2TurnStartParams | JsonObject | None = None,
) -> Iterator[AgentMessageDeltaNotification]:
"""Start a text turn and yield only its agent-message delta payloads."""
started = self.turn_start(thread_id, text, params=params)
turn_id = started.turn.id
self.register_turn_notifications(turn_id)
try:
while True:
notification = self.next_turn_notification(turn_id)
if (
notification.method == "item/agentMessage/delta"
and isinstance(notification.payload, AgentMessageDeltaNotification)
and notification.payload.turn_id == turn_id
):
yield notification.payload
continue
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
break
finally:
self.unregister_turn_notifications(turn_id)
def _coerce_notification(self, method: str, params: object) -> Notification:
params_dict = params if isinstance(params, dict) else {}
model = NOTIFICATION_MODELS.get(method)
if model is None:
return Notification(method=method, payload=UnknownNotification(params=params_dict))
try:
payload = model.model_validate(params_dict)
except Exception: # noqa: BLE001
return Notification(method=method, payload=UnknownNotification(params=params_dict))
return Notification(method=method, payload=payload)
def _normalize_input_items(
self,
input_items: list[JsonObject] | JsonObject | str,
) -> list[JsonObject]:
if isinstance(input_items, str):
return [{"type": "text", "text": input_items}]
if isinstance(input_items, dict):
return [input_items]
return input_items
def _default_approval_handler(self, method: str, params: JsonObject | None) -> JsonObject:
"""Accept approval requests when the caller did not provide a handler."""
if method == "item/commandExecution/requestApproval":
return {"decision": "accept"}
if method == "item/fileChange/requestApproval":
return {"decision": "accept"}
return {}
def _start_stderr_drain_thread(self) -> None:
if self._proc is None or self._proc.stderr is None:
return
def _drain() -> None:
stderr = self._proc.stderr
if stderr is None:
return
for line in stderr:
self._stderr_lines.append(line.rstrip("\n"))
self._stderr_thread = threading.Thread(target=_drain, daemon=True)
self._stderr_thread.start()
def _start_reader_thread(self) -> None:
"""Start the sole stdout reader that fans messages into router queues."""
if self._proc is None or self._proc.stdout is None:
return
self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True)
self._reader_thread.start()
def _reader_loop(self) -> None:
"""Continuously classify transport messages into requests, responses, and events."""
try:
while True:
msg = self._read_message()
if "method" in msg and "id" in msg:
response = self._handle_server_request(msg)
self._write_message({"id": msg["id"], "result": response})
continue
if "method" in msg and "id" not in msg:
method = msg["method"]
if isinstance(method, str):
self._router.route_notification(
self._coerce_notification(method, msg.get("params"))
)
continue
self._router.route_response(msg)
except BaseException as exc:
self._router.fail_all(exc)
def _stderr_tail(self, limit: int = 40) -> str:
return "\n".join(list(self._stderr_lines)[-limit:])
def _handle_server_request(self, msg: dict[str, JsonValue]) -> JsonObject:
method = msg["method"]
params = msg.get("params")
if not isinstance(method, str):
return {}
return self._approval_handler(
method,
params if isinstance(params, dict) else None,
)
def _write_message(self, payload: JsonObject) -> None:
if self._proc is None or self._proc.stdin is None:
raise TransportClosedError("app-server is not running")
with self._lock:
self._proc.stdin.write(json.dumps(payload) + "\n")
self._proc.stdin.flush()
def _read_message(self) -> dict[str, JsonValue]:
if self._proc is None or self._proc.stdout is None:
raise TransportClosedError("app-server is not running")
line = self._proc.stdout.readline()
if not line:
raise TransportClosedError(
f"app-server closed stdout. stderr_tail={self._stderr_tail()[:2000]}"
)
try:
message = json.loads(line)
except json.JSONDecodeError as exc:
raise AppServerError(f"Invalid JSON-RPC line: {line!r}") from exc
if not isinstance(message, dict):
raise AppServerError(f"Invalid JSON-RPC payload: {message!r}")
return message
def default_codex_home() -> str:
return str(Path.home() / ".codex")