mirror of
https://github.com/openai/codex.git
synced 2026-05-24 13:04:29 +00:00
## Why The Python SDK can already create threads and run turns, but authentication still has to be arranged outside the SDK. App-server already exposes account login, account inspection, logout, and `account/login/completed` notifications, so SDK users currently have to work around a missing public client layer for a core setup step. This change makes authentication a normal SDK workflow while preserving the backend flow shape: API-key login completes immediately, and interactive ChatGPT flows return live handles that complete later through app-server notifications. ## What changed - Added public sync and async auth methods on `Codex` / `AsyncCodex`: - `login_api_key(...)` - `login_chatgpt()` - `login_chatgpt_device_code()` - `account(...)` - `logout()` - Added public browser-login and device-code handle types with attempt-local `wait()` and `cancel()` helpers. Cancellation stays on the handle instead of a root-level SDK method. - Extended the Python app-server client and notification router so login completion events are routed by `login_id` without consuming unrelated global notifications. - Kept login request/handle logic in a focused internal `_login.py` module so `api.py` remains the public facade instead of absorbing more auth plumbing. - Exported the new handle types plus curated account/login response types from the SDK surfaces. - Updated SDK docs, added sync/async login walkthrough examples, and added a notebook login walkthrough cell. ## Verification Added SDK coverage for: - API-key login, account readback, and logout through the app-server harness in both sync and async clients. - Browser login cancellation plus `handle.wait()` completion through the real app-server boundary used by the Python SDK harness. - Waiter routing that stays scoped across replaced interactive login attempts, plus async handle cancellation coverage. - Login notification demuxing, replay of early completion events, and async client delegation. - Public export/signature assertions. - Real integration-suite smoke coverage for the new examples and notebook login cell.
223 lines
8.9 KiB
Python
223 lines
8.9 KiB
Python
from __future__ import annotations
|
|
|
|
import queue
|
|
import threading
|
|
from collections import deque
|
|
|
|
from .errors import AppServerError, map_jsonrpc_error
|
|
from .generated.notification_registry import notification_turn_id
|
|
from .generated.v2_all import AccountLoginCompletedNotification
|
|
from .models import JsonValue, Notification, UnknownNotification
|
|
|
|
ResponseQueueItem = JsonValue | BaseException
|
|
NotificationQueueItem = Notification | BaseException
|
|
|
|
|
|
class MessageRouter:
|
|
"""Route reader-thread messages to the SDK operation waiting for them.
|
|
|
|
The app-server stdio transport is a single ordered stream, so only the
|
|
reader thread should consume stdout. This router keeps the rest of the SDK
|
|
from competing for that stream by giving each in-flight JSON-RPC request
|
|
and active turn stream its own queue.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
"""Create empty response, turn, and global notification queues."""
|
|
self._lock = threading.Lock()
|
|
self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {}
|
|
self._login_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
|
|
self._pending_login_notifications: dict[str, deque[Notification]] = {}
|
|
self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
|
|
self._pending_turn_notifications: dict[str, deque[Notification]] = {}
|
|
self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue()
|
|
|
|
def create_response_waiter(self, request_id: str) -> queue.Queue[ResponseQueueItem]:
|
|
"""Register a one-shot queue for a JSON-RPC response id."""
|
|
|
|
waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1)
|
|
with self._lock:
|
|
self._response_waiters[request_id] = waiter
|
|
return waiter
|
|
|
|
def discard_response_waiter(self, request_id: str) -> None:
|
|
"""Remove a response waiter when the request could not be written."""
|
|
|
|
with self._lock:
|
|
self._response_waiters.pop(request_id, None)
|
|
|
|
def next_global_notification(self) -> Notification:
|
|
"""Block until the next notification that is not scoped to a turn."""
|
|
|
|
item = self._global_notifications.get()
|
|
if isinstance(item, BaseException):
|
|
raise item
|
|
return item
|
|
|
|
def register_login(self, login_id: str) -> None:
|
|
"""Register a queue for one interactive login attempt."""
|
|
|
|
login_queue: queue.Queue[NotificationQueueItem] = queue.Queue()
|
|
with self._lock:
|
|
if login_id in self._login_notifications:
|
|
return
|
|
pending = self._pending_login_notifications.pop(login_id, deque())
|
|
self._login_notifications[login_id] = login_queue
|
|
for notification in pending:
|
|
login_queue.put(notification)
|
|
|
|
def unregister_login(self, login_id: str) -> None:
|
|
"""Stop routing future notifications for one login attempt."""
|
|
|
|
with self._lock:
|
|
self._login_notifications.pop(login_id, None)
|
|
|
|
def next_login_notification(self, login_id: str) -> Notification:
|
|
"""Block until the next notification for a registered login attempt."""
|
|
|
|
with self._lock:
|
|
login_queue = self._login_notifications.get(login_id)
|
|
if login_queue is None:
|
|
raise RuntimeError(f"login {login_id!r} is not registered for waiting")
|
|
item = login_queue.get()
|
|
if isinstance(item, BaseException):
|
|
raise item
|
|
return item
|
|
|
|
def register_turn(self, turn_id: str) -> None:
|
|
"""Register a queue for a turn stream and replay early events."""
|
|
|
|
turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue()
|
|
with self._lock:
|
|
if turn_id in self._turn_notifications:
|
|
return
|
|
# A turn can emit events immediately after turn/start, before the
|
|
# caller receives the TurnHandle and starts streaming.
|
|
pending = self._pending_turn_notifications.pop(turn_id, deque())
|
|
self._turn_notifications[turn_id] = turn_queue
|
|
for notification in pending:
|
|
turn_queue.put(notification)
|
|
|
|
def unregister_turn(self, turn_id: str) -> None:
|
|
"""Stop routing future turn events to the stream queue."""
|
|
|
|
with self._lock:
|
|
self._turn_notifications.pop(turn_id, None)
|
|
|
|
def next_turn_notification(self, turn_id: str) -> Notification:
|
|
"""Block until the next notification for a registered turn."""
|
|
|
|
with self._lock:
|
|
turn_queue = self._turn_notifications.get(turn_id)
|
|
if turn_queue is None:
|
|
raise RuntimeError(f"turn {turn_id!r} is not registered for streaming")
|
|
item = turn_queue.get()
|
|
if isinstance(item, BaseException):
|
|
raise item
|
|
return item
|
|
|
|
def route_response(self, msg: dict[str, JsonValue]) -> None:
|
|
"""Deliver a JSON-RPC response or error to its request waiter."""
|
|
|
|
request_id = msg.get("id")
|
|
with self._lock:
|
|
waiter = self._response_waiters.pop(str(request_id), None)
|
|
if waiter is None:
|
|
return
|
|
|
|
if "error" in msg:
|
|
err = msg["error"]
|
|
if isinstance(err, dict):
|
|
waiter.put(
|
|
map_jsonrpc_error(
|
|
int(err.get("code", -32000)),
|
|
str(err.get("message", "unknown")),
|
|
err.get("data"),
|
|
)
|
|
)
|
|
else:
|
|
waiter.put(AppServerError("Malformed JSON-RPC error response"))
|
|
return
|
|
|
|
waiter.put(msg.get("result"))
|
|
|
|
def route_notification(self, notification: Notification) -> None:
|
|
"""Deliver a notification to a turn queue or the global queue."""
|
|
|
|
login_id = self._notification_login_id(notification)
|
|
if login_id is not None:
|
|
with self._lock:
|
|
login_queue = self._login_notifications.get(login_id)
|
|
if login_queue is None:
|
|
self._pending_login_notifications.setdefault(login_id, deque()).append(
|
|
notification
|
|
)
|
|
return
|
|
login_queue.put(notification)
|
|
return
|
|
|
|
turn_id = self._notification_turn_id(notification)
|
|
if turn_id is None:
|
|
self._global_notifications.put(notification)
|
|
return
|
|
|
|
with self._lock:
|
|
turn_queue = self._turn_notifications.get(turn_id)
|
|
if turn_queue is None:
|
|
if notification.method == "turn/completed":
|
|
self._pending_turn_notifications.pop(turn_id, None)
|
|
return
|
|
self._pending_turn_notifications.setdefault(turn_id, deque()).append(notification)
|
|
return
|
|
turn_queue.put(notification)
|
|
|
|
def fail_all(self, exc: BaseException) -> None:
|
|
"""Wake every blocked waiter when the reader thread exits."""
|
|
|
|
with self._lock:
|
|
response_waiters = list(self._response_waiters.values())
|
|
self._response_waiters.clear()
|
|
login_queues = list(self._login_notifications.values())
|
|
self._login_notifications.clear()
|
|
self._pending_login_notifications.clear()
|
|
turn_queues = list(self._turn_notifications.values())
|
|
self._pending_turn_notifications.clear()
|
|
# Put the same transport failure into every queue so no SDK call blocks
|
|
# forever waiting for a response that cannot arrive.
|
|
for waiter in response_waiters:
|
|
waiter.put(exc)
|
|
for login_queue in login_queues:
|
|
login_queue.put(exc)
|
|
for turn_queue in turn_queues:
|
|
turn_queue.put(exc)
|
|
self._global_notifications.put(exc)
|
|
|
|
def _notification_login_id(self, notification: Notification) -> str | None:
|
|
"""Extract the login attempt id from completion notifications."""
|
|
if notification.method != "account/login/completed":
|
|
return None
|
|
|
|
payload = notification.payload
|
|
if isinstance(payload, AccountLoginCompletedNotification):
|
|
return payload.login_id
|
|
if isinstance(payload, UnknownNotification):
|
|
raw_login_id = payload.params.get("loginId")
|
|
if isinstance(raw_login_id, str):
|
|
return raw_login_id
|
|
return None
|
|
|
|
def _notification_turn_id(self, notification: Notification) -> str | None:
|
|
"""Extract routing ids from known generated payloads or raw unknown payloads."""
|
|
payload = notification.payload
|
|
if isinstance(payload, UnknownNotification):
|
|
raw_turn_id = payload.params.get("turnId")
|
|
if isinstance(raw_turn_id, str):
|
|
return raw_turn_id
|
|
raw_turn = payload.params.get("turn")
|
|
if isinstance(raw_turn, dict):
|
|
raw_nested_turn_id = raw_turn.get("id")
|
|
if isinstance(raw_nested_turn_id, str):
|
|
return raw_nested_turn_id
|
|
return None
|
|
return notification_turn_id(payload)
|