Extract Python SDK message router

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-05-08 19:42:01 +03:00
parent 064d4b4937
commit 0ea082ab66
2 changed files with 122 additions and 98 deletions

View File

@@ -139,6 +139,115 @@ class AppServerConfig:
experimental_api: bool = True
class _MessageRouter:
def __init__(self) -> None:
self._lock = threading.Lock()
self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {}
self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
self._pending_turn_notifications: dict[str, deque[Notification]] = {}
self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue()
def create_response_waiter(self, request_id: str) -> queue.Queue[ResponseQueueItem]:
waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1)
with self._lock:
self._response_waiters[request_id] = waiter
return waiter
def discard_response_waiter(self, request_id: str) -> None:
with self._lock:
self._response_waiters.pop(request_id, None)
def next_global_notification(self) -> Notification:
item = self._global_notifications.get()
if isinstance(item, BaseException):
raise item
return item
def register_turn(self, turn_id: str) -> None:
turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue()
with self._lock:
if turn_id in self._turn_notifications:
return
pending = self._pending_turn_notifications.pop(turn_id, deque())
self._turn_notifications[turn_id] = turn_queue
for notification in pending:
turn_queue.put(notification)
def unregister_turn(self, turn_id: str) -> None:
with self._lock:
self._turn_notifications.pop(turn_id, None)
def next_turn_notification(self, turn_id: str) -> Notification:
with self._lock:
turn_queue = self._turn_notifications.get(turn_id)
if turn_queue is None:
raise RuntimeError(f"turn {turn_id!r} is not registered for streaming")
item = turn_queue.get()
if isinstance(item, BaseException):
raise item
return item
def route_response(self, msg: dict[str, JsonValue]) -> None:
request_id = msg.get("id")
with self._lock:
waiter = self._response_waiters.pop(str(request_id), None)
if waiter is None:
return
if "error" in msg:
err = msg["error"]
if isinstance(err, dict):
waiter.put(
map_jsonrpc_error(
int(err.get("code", -32000)),
str(err.get("message", "unknown")),
err.get("data"),
)
)
else:
waiter.put(AppServerError("Malformed JSON-RPC error response"))
return
waiter.put(msg.get("result"))
def route_notification(self, notification: Notification) -> None:
turn_id = self._notification_turn_id(notification)
if turn_id is None:
self._global_notifications.put(notification)
return
with self._lock:
turn_queue = self._turn_notifications.get(turn_id)
if turn_queue is None:
self._pending_turn_notifications.setdefault(turn_id, deque()).append(
notification
)
return
turn_queue.put(notification)
def fail_all(self, exc: BaseException) -> None:
with self._lock:
response_waiters = list(self._response_waiters.values())
self._response_waiters.clear()
turn_queues = list(self._turn_notifications.values())
for waiter in response_waiters:
waiter.put(exc)
for turn_queue in turn_queues:
turn_queue.put(exc)
self._global_notifications.put(exc)
def _notification_turn_id(self, notification: Notification) -> str | None:
payload = notification.payload
turn_id = getattr(payload, "turn_id", None)
if isinstance(turn_id, str):
return turn_id
turn = getattr(payload, "turn", None)
nested_turn_id = getattr(turn, "id", None)
if isinstance(nested_turn_id, str):
return nested_turn_id
return None
class AppServerClient:
"""Synchronous typed JSON-RPC client for `codex app-server` over stdio."""
@@ -151,11 +260,7 @@ class AppServerClient:
self._approval_handler = approval_handler or self._default_approval_handler
self._proc: subprocess.Popen[str] | None = None
self._lock = threading.Lock()
self._router_lock = threading.Lock()
self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {}
self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
self._pending_turn_notifications: dict[str, deque[Notification]] = {}
self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue()
self._router = _MessageRouter()
self._stderr_lines: deque[str] = deque(maxlen=400)
self._stderr_thread: threading.Thread | None = None
self._reader_thread: threading.Thread | None = None
@@ -250,17 +355,14 @@ class AppServerClient:
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
request_id = str(uuid.uuid4())
waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1)
with self._router_lock:
self._response_waiters[request_id] = waiter
waiter = self._router.create_response_waiter(request_id)
try:
self._write_message(
{"id": request_id, "method": method, "params": params or {}}
)
except BaseException:
with self._router_lock:
self._response_waiters.pop(request_id, None)
self._router.discard_response_waiter(request_id)
raise
item = waiter.get()
@@ -272,34 +374,16 @@ class AppServerClient:
self._write_message({"method": method, "params": params or {}})
def next_notification(self) -> Notification:
item = self._global_notifications.get()
if isinstance(item, BaseException):
raise item
return item
return self._router.next_global_notification()
def register_turn_notifications(self, turn_id: str) -> None:
turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue()
with self._router_lock:
if turn_id in self._turn_notifications:
return
pending = self._pending_turn_notifications.pop(turn_id, deque())
self._turn_notifications[turn_id] = turn_queue
for notification in pending:
turn_queue.put(notification)
self._router.register_turn(turn_id)
def unregister_turn_notifications(self, turn_id: str) -> None:
with self._router_lock:
self._turn_notifications.pop(turn_id, None)
self._router.unregister_turn(turn_id)
def next_turn_notification(self, turn_id: str) -> Notification:
with self._router_lock:
turn_queue = self._turn_notifications.get(turn_id)
if turn_queue is None:
raise RuntimeError(f"turn {turn_id!r} is not registered for streaming")
item = turn_queue.get()
if isinstance(item, BaseException):
raise item
return item
return self._router.next_turn_notification(turn_id)
def thread_start(
self, params: V2ThreadStartParams | JsonObject | None = None
@@ -550,77 +634,17 @@ class AppServerClient:
if "method" in msg and "id" not in msg:
method = msg["method"]
if isinstance(method, str):
self._route_notification(
self._router.route_notification(
self._coerce_notification(method, msg.get("params"))
)
continue
self._route_response(msg)
self._router.route_response(msg)
except BaseException as exc:
self._fail_pending_queues(exc)
self._router.fail_all(exc)
def _stderr_tail(self, limit: int = 40) -> str:
return "\n".join(list(self._stderr_lines)[-limit:])
def _route_response(self, msg: dict[str, JsonValue]) -> None:
request_id = msg.get("id")
with self._router_lock:
waiter = self._response_waiters.pop(str(request_id), None)
if waiter is None:
return
if "error" in msg:
err = msg["error"]
if isinstance(err, dict):
waiter.put(
map_jsonrpc_error(
int(err.get("code", -32000)),
str(err.get("message", "unknown")),
err.get("data"),
)
)
else:
waiter.put(AppServerError("Malformed JSON-RPC error response"))
return
waiter.put(msg.get("result"))
def _route_notification(self, notification: Notification) -> None:
turn_id = self._notification_turn_id(notification)
if turn_id is None:
self._global_notifications.put(notification)
return
with self._router_lock:
turn_queue = self._turn_notifications.get(turn_id)
if turn_queue is None:
self._pending_turn_notifications.setdefault(turn_id, deque()).append(
notification
)
return
turn_queue.put(notification)
def _notification_turn_id(self, notification: Notification) -> str | None:
payload = notification.payload
turn_id = getattr(payload, "turn_id", None)
if isinstance(turn_id, str):
return turn_id
turn = getattr(payload, "turn", None)
nested_turn_id = getattr(turn, "id", None)
if isinstance(nested_turn_id, str):
return nested_turn_id
return None
def _fail_pending_queues(self, exc: BaseException) -> None:
with self._router_lock:
response_waiters = list(self._response_waiters.values())
self._response_waiters.clear()
turn_queues = list(self._turn_notifications.values())
for waiter in response_waiters:
waiter.put(exc)
for turn_queue in turn_queues:
turn_queue.put(exc)
self._global_notifications.put(exc)
def _handle_server_request(self, msg: dict[str, JsonValue]) -> JsonObject:
method = msg["method"]
params = msg.get("params")