From 787566c19e7bcc5a2c4ea4e8443e13e701124da8 Mon Sep 17 00:00:00 2001 From: Ahmed Ibrahim Date: Fri, 8 May 2026 19:43:43 +0300 Subject: [PATCH] Move Python SDK message router to module Co-authored-by: Codex --- .../src/codex_app_server/_message_router.py | 150 ++++++++++++++++++ sdk/python/src/codex_app_server/client.py | 117 +------------- 2 files changed, 153 insertions(+), 114 deletions(-) create mode 100644 sdk/python/src/codex_app_server/_message_router.py diff --git a/sdk/python/src/codex_app_server/_message_router.py b/sdk/python/src/codex_app_server/_message_router.py new file mode 100644 index 0000000000..ceac0304a1 --- /dev/null +++ b/sdk/python/src/codex_app_server/_message_router.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import queue +import threading +from collections import deque + +from .errors import AppServerError, map_jsonrpc_error +from .models import JsonValue, Notification + +ResponseQueueItem = JsonValue | BaseException +NotificationQueueItem = Notification | BaseException + + +class MessageRouter: + """Route reader-thread messages to the SDK operation waiting for them. + + The app-server stdio transport is a single ordered stream, so only the + reader thread should consume stdout. This router keeps the rest of the SDK + from competing for that stream by giving each in-flight JSON-RPC request + and active turn stream its own queue. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {} + self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {} + self._pending_turn_notifications: dict[str, deque[Notification]] = {} + self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue() + + def create_response_waiter(self, request_id: str) -> queue.Queue[ResponseQueueItem]: + """Register a one-shot queue for a JSON-RPC response id.""" + + waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1) + with self._lock: + self._response_waiters[request_id] = waiter + return waiter + + def discard_response_waiter(self, request_id: str) -> None: + """Remove a response waiter when the request could not be written.""" + + with self._lock: + self._response_waiters.pop(request_id, None) + + def next_global_notification(self) -> Notification: + """Block until the next notification that is not scoped to a turn.""" + + item = self._global_notifications.get() + if isinstance(item, BaseException): + raise item + return item + + def register_turn(self, turn_id: str) -> None: + """Register a queue for a turn stream and replay early events.""" + + turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue() + with self._lock: + if turn_id in self._turn_notifications: + return + # A turn can emit events immediately after turn/start, before the + # caller receives the TurnHandle and starts streaming. + pending = self._pending_turn_notifications.pop(turn_id, deque()) + self._turn_notifications[turn_id] = turn_queue + for notification in pending: + turn_queue.put(notification) + + def unregister_turn(self, turn_id: str) -> None: + """Stop routing future turn events to the stream queue.""" + + with self._lock: + self._turn_notifications.pop(turn_id, None) + + def next_turn_notification(self, turn_id: str) -> Notification: + """Block until the next notification for a registered turn.""" + + with self._lock: + turn_queue = self._turn_notifications.get(turn_id) + if turn_queue is None: + raise RuntimeError(f"turn {turn_id!r} is not registered for streaming") + item = turn_queue.get() + if isinstance(item, BaseException): + raise item + return item + + def route_response(self, msg: dict[str, JsonValue]) -> None: + """Deliver a JSON-RPC response or error to its request waiter.""" + + request_id = msg.get("id") + with self._lock: + waiter = self._response_waiters.pop(str(request_id), None) + if waiter is None: + return + + if "error" in msg: + err = msg["error"] + if isinstance(err, dict): + waiter.put( + map_jsonrpc_error( + int(err.get("code", -32000)), + str(err.get("message", "unknown")), + err.get("data"), + ) + ) + else: + waiter.put(AppServerError("Malformed JSON-RPC error response")) + return + + waiter.put(msg.get("result")) + + def route_notification(self, notification: Notification) -> None: + """Deliver a notification to a turn queue or the global queue.""" + + turn_id = self._notification_turn_id(notification) + if turn_id is None: + self._global_notifications.put(notification) + return + + with self._lock: + turn_queue = self._turn_notifications.get(turn_id) + if turn_queue is None: + self._pending_turn_notifications.setdefault(turn_id, deque()).append( + notification + ) + return + turn_queue.put(notification) + + def fail_all(self, exc: BaseException) -> None: + """Wake every blocked waiter when the reader thread exits.""" + + with self._lock: + response_waiters = list(self._response_waiters.values()) + self._response_waiters.clear() + turn_queues = list(self._turn_notifications.values()) + # Put the same transport failure into every queue so no SDK call blocks + # forever waiting for a response that cannot arrive. + for waiter in response_waiters: + waiter.put(exc) + for turn_queue in turn_queues: + turn_queue.put(exc) + self._global_notifications.put(exc) + + def _notification_turn_id(self, notification: Notification) -> str | None: + payload = notification.payload + turn_id = getattr(payload, "turn_id", None) + if isinstance(turn_id, str): + return turn_id + turn = getattr(payload, "turn", None) + nested_turn_id = getattr(turn, "id", None) + if isinstance(nested_turn_id, str): + return nested_turn_id + return None diff --git a/sdk/python/src/codex_app_server/client.py b/sdk/python/src/codex_app_server/client.py index ed1419f29a..59e2a3ca88 100644 --- a/sdk/python/src/codex_app_server/client.py +++ b/sdk/python/src/codex_app_server/client.py @@ -2,7 +2,6 @@ from __future__ import annotations import json import os -import queue import subprocess import threading import uuid @@ -13,7 +12,7 @@ from typing import Callable, Iterable, Iterator, TypeVar from pydantic import BaseModel -from .errors import AppServerError, TransportClosedError, map_jsonrpc_error +from .errors import AppServerError, TransportClosedError from .generated.notification_registry import NOTIFICATION_MODELS from .generated.v2_all import ( AgentMessageDeltaNotification, @@ -44,14 +43,13 @@ from .models import ( Notification, UnknownNotification, ) +from ._message_router import MessageRouter from .retry import retry_on_overload from ._version import __version__ as SDK_VERSION ModelT = TypeVar("ModelT", bound=BaseModel) ApprovalHandler = Callable[[str, JsonObject | None], JsonObject] RUNTIME_PKG_NAME = "openai-codex-cli-bin" -ResponseQueueItem = JsonValue | BaseException -NotificationQueueItem = Notification | BaseException def _params_dict( @@ -139,115 +137,6 @@ class AppServerConfig: experimental_api: bool = True -class _MessageRouter: - def __init__(self) -> None: - self._lock = threading.Lock() - self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {} - self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {} - self._pending_turn_notifications: dict[str, deque[Notification]] = {} - self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue() - - def create_response_waiter(self, request_id: str) -> queue.Queue[ResponseQueueItem]: - waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1) - with self._lock: - self._response_waiters[request_id] = waiter - return waiter - - def discard_response_waiter(self, request_id: str) -> None: - with self._lock: - self._response_waiters.pop(request_id, None) - - def next_global_notification(self) -> Notification: - item = self._global_notifications.get() - if isinstance(item, BaseException): - raise item - return item - - def register_turn(self, turn_id: str) -> None: - turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue() - with self._lock: - if turn_id in self._turn_notifications: - return - pending = self._pending_turn_notifications.pop(turn_id, deque()) - self._turn_notifications[turn_id] = turn_queue - for notification in pending: - turn_queue.put(notification) - - def unregister_turn(self, turn_id: str) -> None: - with self._lock: - self._turn_notifications.pop(turn_id, None) - - def next_turn_notification(self, turn_id: str) -> Notification: - with self._lock: - turn_queue = self._turn_notifications.get(turn_id) - if turn_queue is None: - raise RuntimeError(f"turn {turn_id!r} is not registered for streaming") - item = turn_queue.get() - if isinstance(item, BaseException): - raise item - return item - - def route_response(self, msg: dict[str, JsonValue]) -> None: - request_id = msg.get("id") - with self._lock: - waiter = self._response_waiters.pop(str(request_id), None) - if waiter is None: - return - - if "error" in msg: - err = msg["error"] - if isinstance(err, dict): - waiter.put( - map_jsonrpc_error( - int(err.get("code", -32000)), - str(err.get("message", "unknown")), - err.get("data"), - ) - ) - else: - waiter.put(AppServerError("Malformed JSON-RPC error response")) - return - - waiter.put(msg.get("result")) - - def route_notification(self, notification: Notification) -> None: - turn_id = self._notification_turn_id(notification) - if turn_id is None: - self._global_notifications.put(notification) - return - - with self._lock: - turn_queue = self._turn_notifications.get(turn_id) - if turn_queue is None: - self._pending_turn_notifications.setdefault(turn_id, deque()).append( - notification - ) - return - turn_queue.put(notification) - - def fail_all(self, exc: BaseException) -> None: - with self._lock: - response_waiters = list(self._response_waiters.values()) - self._response_waiters.clear() - turn_queues = list(self._turn_notifications.values()) - for waiter in response_waiters: - waiter.put(exc) - for turn_queue in turn_queues: - turn_queue.put(exc) - self._global_notifications.put(exc) - - def _notification_turn_id(self, notification: Notification) -> str | None: - payload = notification.payload - turn_id = getattr(payload, "turn_id", None) - if isinstance(turn_id, str): - return turn_id - turn = getattr(payload, "turn", None) - nested_turn_id = getattr(turn, "id", None) - if isinstance(nested_turn_id, str): - return nested_turn_id - return None - - class AppServerClient: """Synchronous typed JSON-RPC client for `codex app-server` over stdio.""" @@ -260,7 +149,7 @@ class AppServerClient: self._approval_handler = approval_handler or self._default_approval_handler self._proc: subprocess.Popen[str] | None = None self._lock = threading.Lock() - self._router = _MessageRouter() + self._router = MessageRouter() self._stderr_lines: deque[str] = deque(maxlen=400) self._stderr_thread: threading.Thread | None = None self._reader_thread: threading.Thread | None = None