Files
codex/sdk/python/src/codex_app_server/_message_router.py
Ahmed Ibrahim ebe75bb683 Route Python SDK turn notifications by ID (#21778)
## Why

The Python SDK previously protected the stdio transport with a single
active turn-consumer guard. That avoided competing reads from stdout,
but it also meant one `Codex`/`AsyncCodex` client could not stream
multiple active turns at the same time. Notifications could also arrive
before the caller received a `TurnHandle` and registered for streaming,
so the SDK needed an explicit routing layer instead of letting
individual API calls read directly from the shared transport.

## What Changed

- Added a private `MessageRouter` that owns per-request response queues,
per-turn notification queues, pending turn-notification replay, and
global notification delivery behind a single stdout reader thread.
- Generated typed notification routing metadata so turn IDs come from
known payload shapes instead of router-side attribute guessing, with
explicit fallback handling for unknown notification payloads.
- Updated sync and async turn streaming so `TurnHandle.stream()`/`run()`
and `stream_text()` consume only notifications for their own turn ID,
while `AsyncAppServerClient` no longer serializes all transport calls
behind one async lock.
- Cleared pending turn-notification buffers when unregistered turns
complete so never-consumed turn handles do not leave stale queues
behind.
- Removed the internal stream-until helper now that turn completion
waiting can register directly with routed turn notifications.
- Updated Python SDK docs and focused tests for concurrent transport
calls, interleaved turn routing, buffered early notifications, unknown
notification routing, async delegation, and routed turn completion
behavior.

## Validation

- `uv run --extra dev ruff format scripts/update_sdk_artifacts.py
src/codex_app_server/_message_router.py src/codex_app_server/client.py
src/codex_app_server/generated/notification_registry.py
tests/test_client_rpc_methods.py
tests/test_public_api_runtime_behavior.py
tests/test_async_client_behavior.py`
- `uv run --extra dev ruff check scripts/update_sdk_artifacts.py
src/codex_app_server/_message_router.py src/codex_app_server/client.py
src/codex_app_server/generated/notification_registry.py
tests/test_client_rpc_methods.py
tests/test_public_api_runtime_behavior.py
tests/test_async_client_behavior.py`
- `uv run --extra dev pytest tests/test_client_rpc_methods.py
tests/test_public_api_runtime_behavior.py
tests/test_async_client_behavior.py`
- `git diff --check`

---------

Co-authored-by: Codex <noreply@openai.com>
2026-05-09 04:16:23 +00:00

159 lines
6.1 KiB
Python

from __future__ import annotations
import queue
import threading
from collections import deque
from .errors import AppServerError, map_jsonrpc_error
from .generated.notification_registry import notification_turn_id
from .models import JsonValue, Notification, UnknownNotification
ResponseQueueItem = JsonValue | BaseException
NotificationQueueItem = Notification | BaseException
class MessageRouter:
"""Route reader-thread messages to the SDK operation waiting for them.
The app-server stdio transport is a single ordered stream, so only the
reader thread should consume stdout. This router keeps the rest of the SDK
from competing for that stream by giving each in-flight JSON-RPC request
and active turn stream its own queue.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {}
self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
self._pending_turn_notifications: dict[str, deque[Notification]] = {}
self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue()
def create_response_waiter(self, request_id: str) -> queue.Queue[ResponseQueueItem]:
"""Register a one-shot queue for a JSON-RPC response id."""
waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1)
with self._lock:
self._response_waiters[request_id] = waiter
return waiter
def discard_response_waiter(self, request_id: str) -> None:
"""Remove a response waiter when the request could not be written."""
with self._lock:
self._response_waiters.pop(request_id, None)
def next_global_notification(self) -> Notification:
"""Block until the next notification that is not scoped to a turn."""
item = self._global_notifications.get()
if isinstance(item, BaseException):
raise item
return item
def register_turn(self, turn_id: str) -> None:
"""Register a queue for a turn stream and replay early events."""
turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue()
with self._lock:
if turn_id in self._turn_notifications:
return
# A turn can emit events immediately after turn/start, before the
# caller receives the TurnHandle and starts streaming.
pending = self._pending_turn_notifications.pop(turn_id, deque())
self._turn_notifications[turn_id] = turn_queue
for notification in pending:
turn_queue.put(notification)
def unregister_turn(self, turn_id: str) -> None:
"""Stop routing future turn events to the stream queue."""
with self._lock:
self._turn_notifications.pop(turn_id, None)
def next_turn_notification(self, turn_id: str) -> Notification:
"""Block until the next notification for a registered turn."""
with self._lock:
turn_queue = self._turn_notifications.get(turn_id)
if turn_queue is None:
raise RuntimeError(f"turn {turn_id!r} is not registered for streaming")
item = turn_queue.get()
if isinstance(item, BaseException):
raise item
return item
def route_response(self, msg: dict[str, JsonValue]) -> None:
"""Deliver a JSON-RPC response or error to its request waiter."""
request_id = msg.get("id")
with self._lock:
waiter = self._response_waiters.pop(str(request_id), None)
if waiter is None:
return
if "error" in msg:
err = msg["error"]
if isinstance(err, dict):
waiter.put(
map_jsonrpc_error(
int(err.get("code", -32000)),
str(err.get("message", "unknown")),
err.get("data"),
)
)
else:
waiter.put(AppServerError("Malformed JSON-RPC error response"))
return
waiter.put(msg.get("result"))
def route_notification(self, notification: Notification) -> None:
"""Deliver a notification to a turn queue or the global queue."""
turn_id = self._notification_turn_id(notification)
if turn_id is None:
self._global_notifications.put(notification)
return
with self._lock:
turn_queue = self._turn_notifications.get(turn_id)
if turn_queue is None:
if notification.method == "turn/completed":
self._pending_turn_notifications.pop(turn_id, None)
return
self._pending_turn_notifications.setdefault(turn_id, deque()).append(
notification
)
return
turn_queue.put(notification)
def fail_all(self, exc: BaseException) -> None:
"""Wake every blocked waiter when the reader thread exits."""
with self._lock:
response_waiters = list(self._response_waiters.values())
self._response_waiters.clear()
turn_queues = list(self._turn_notifications.values())
self._pending_turn_notifications.clear()
# Put the same transport failure into every queue so no SDK call blocks
# forever waiting for a response that cannot arrive.
for waiter in response_waiters:
waiter.put(exc)
for turn_queue in turn_queues:
turn_queue.put(exc)
self._global_notifications.put(exc)
def _notification_turn_id(self, notification: Notification) -> str | None:
payload = notification.payload
if isinstance(payload, UnknownNotification):
raw_turn_id = payload.params.get("turnId")
if isinstance(raw_turn_id, str):
return raw_turn_id
raw_turn = payload.params.get("turn")
if isinstance(raw_turn, dict):
raw_nested_turn_id = raw_turn.get("id")
if isinstance(raw_nested_turn_id, str):
return raw_nested_turn_id
return None
return notification_turn_id(payload)