From 0ea082ab6626ca13bcaf65e0c2c0df71153dfe18 Mon Sep 17 00:00:00 2001 From: Ahmed Ibrahim Date: Fri, 8 May 2026 19:42:01 +0300 Subject: [PATCH] Extract Python SDK message router Co-authored-by: Codex --- sdk/python/src/codex_app_server/client.py | 214 +++++++++++--------- sdk/python/tests/test_client_rpc_methods.py | 6 +- 2 files changed, 122 insertions(+), 98 deletions(-) diff --git a/sdk/python/src/codex_app_server/client.py b/sdk/python/src/codex_app_server/client.py index 9b46e6cd84..ed1419f29a 100644 --- a/sdk/python/src/codex_app_server/client.py +++ b/sdk/python/src/codex_app_server/client.py @@ -139,6 +139,115 @@ class AppServerConfig: experimental_api: bool = True +class _MessageRouter: + def __init__(self) -> None: + self._lock = threading.Lock() + self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {} + self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {} + self._pending_turn_notifications: dict[str, deque[Notification]] = {} + self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue() + + def create_response_waiter(self, request_id: str) -> queue.Queue[ResponseQueueItem]: + waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1) + with self._lock: + self._response_waiters[request_id] = waiter + return waiter + + def discard_response_waiter(self, request_id: str) -> None: + with self._lock: + self._response_waiters.pop(request_id, None) + + def next_global_notification(self) -> Notification: + item = self._global_notifications.get() + if isinstance(item, BaseException): + raise item + return item + + def register_turn(self, turn_id: str) -> None: + turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue() + with self._lock: + if turn_id in self._turn_notifications: + return + pending = self._pending_turn_notifications.pop(turn_id, deque()) + self._turn_notifications[turn_id] = turn_queue + for notification in pending: + turn_queue.put(notification) + + def unregister_turn(self, turn_id: str) -> None: + with self._lock: + self._turn_notifications.pop(turn_id, None) + + def next_turn_notification(self, turn_id: str) -> Notification: + with self._lock: + turn_queue = self._turn_notifications.get(turn_id) + if turn_queue is None: + raise RuntimeError(f"turn {turn_id!r} is not registered for streaming") + item = turn_queue.get() + if isinstance(item, BaseException): + raise item + return item + + def route_response(self, msg: dict[str, JsonValue]) -> None: + request_id = msg.get("id") + with self._lock: + waiter = self._response_waiters.pop(str(request_id), None) + if waiter is None: + return + + if "error" in msg: + err = msg["error"] + if isinstance(err, dict): + waiter.put( + map_jsonrpc_error( + int(err.get("code", -32000)), + str(err.get("message", "unknown")), + err.get("data"), + ) + ) + else: + waiter.put(AppServerError("Malformed JSON-RPC error response")) + return + + waiter.put(msg.get("result")) + + def route_notification(self, notification: Notification) -> None: + turn_id = self._notification_turn_id(notification) + if turn_id is None: + self._global_notifications.put(notification) + return + + with self._lock: + turn_queue = self._turn_notifications.get(turn_id) + if turn_queue is None: + self._pending_turn_notifications.setdefault(turn_id, deque()).append( + notification + ) + return + turn_queue.put(notification) + + def fail_all(self, exc: BaseException) -> None: + with self._lock: + response_waiters = list(self._response_waiters.values()) + self._response_waiters.clear() + turn_queues = list(self._turn_notifications.values()) + for waiter in response_waiters: + waiter.put(exc) + for turn_queue in turn_queues: + turn_queue.put(exc) + self._global_notifications.put(exc) + + def _notification_turn_id(self, notification: Notification) -> str | None: + payload = notification.payload + turn_id = getattr(payload, "turn_id", None) + if isinstance(turn_id, str): + return turn_id + turn = getattr(payload, "turn", None) + nested_turn_id = getattr(turn, "id", None) + if isinstance(nested_turn_id, str): + return nested_turn_id + return None + + class AppServerClient: """Synchronous typed JSON-RPC client for `codex app-server` over stdio.""" @@ -151,11 +260,7 @@ class AppServerClient: self._approval_handler = approval_handler or self._default_approval_handler self._proc: subprocess.Popen[str] | None = None self._lock = threading.Lock() - self._router_lock = threading.Lock() - self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {} - self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {} - self._pending_turn_notifications: dict[str, deque[Notification]] = {} - self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue() + self._router = _MessageRouter() self._stderr_lines: deque[str] = deque(maxlen=400) self._stderr_thread: threading.Thread | None = None self._reader_thread: threading.Thread | None = None @@ -250,17 +355,14 @@ class AppServerClient: def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue: request_id = str(uuid.uuid4()) - waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1) - with self._router_lock: - self._response_waiters[request_id] = waiter + waiter = self._router.create_response_waiter(request_id) try: self._write_message( {"id": request_id, "method": method, "params": params or {}} ) except BaseException: - with self._router_lock: - self._response_waiters.pop(request_id, None) + self._router.discard_response_waiter(request_id) raise item = waiter.get() @@ -272,34 +374,16 @@ class AppServerClient: self._write_message({"method": method, "params": params or {}}) def next_notification(self) -> Notification: - item = self._global_notifications.get() - if isinstance(item, BaseException): - raise item - return item + return self._router.next_global_notification() def register_turn_notifications(self, turn_id: str) -> None: - turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue() - with self._router_lock: - if turn_id in self._turn_notifications: - return - pending = self._pending_turn_notifications.pop(turn_id, deque()) - self._turn_notifications[turn_id] = turn_queue - for notification in pending: - turn_queue.put(notification) + self._router.register_turn(turn_id) def unregister_turn_notifications(self, turn_id: str) -> None: - with self._router_lock: - self._turn_notifications.pop(turn_id, None) + self._router.unregister_turn(turn_id) def next_turn_notification(self, turn_id: str) -> Notification: - with self._router_lock: - turn_queue = self._turn_notifications.get(turn_id) - if turn_queue is None: - raise RuntimeError(f"turn {turn_id!r} is not registered for streaming") - item = turn_queue.get() - if isinstance(item, BaseException): - raise item - return item + return self._router.next_turn_notification(turn_id) def thread_start( self, params: V2ThreadStartParams | JsonObject | None = None @@ -550,77 +634,17 @@ class AppServerClient: if "method" in msg and "id" not in msg: method = msg["method"] if isinstance(method, str): - self._route_notification( + self._router.route_notification( self._coerce_notification(method, msg.get("params")) ) continue - self._route_response(msg) + self._router.route_response(msg) except BaseException as exc: - self._fail_pending_queues(exc) + self._router.fail_all(exc) def _stderr_tail(self, limit: int = 40) -> str: return "\n".join(list(self._stderr_lines)[-limit:]) - def _route_response(self, msg: dict[str, JsonValue]) -> None: - request_id = msg.get("id") - with self._router_lock: - waiter = self._response_waiters.pop(str(request_id), None) - if waiter is None: - return - - if "error" in msg: - err = msg["error"] - if isinstance(err, dict): - waiter.put( - map_jsonrpc_error( - int(err.get("code", -32000)), - str(err.get("message", "unknown")), - err.get("data"), - ) - ) - else: - waiter.put(AppServerError("Malformed JSON-RPC error response")) - return - - waiter.put(msg.get("result")) - - def _route_notification(self, notification: Notification) -> None: - turn_id = self._notification_turn_id(notification) - if turn_id is None: - self._global_notifications.put(notification) - return - - with self._router_lock: - turn_queue = self._turn_notifications.get(turn_id) - if turn_queue is None: - self._pending_turn_notifications.setdefault(turn_id, deque()).append( - notification - ) - return - turn_queue.put(notification) - - def _notification_turn_id(self, notification: Notification) -> str | None: - payload = notification.payload - turn_id = getattr(payload, "turn_id", None) - if isinstance(turn_id, str): - return turn_id - turn = getattr(payload, "turn", None) - nested_turn_id = getattr(turn, "id", None) - if isinstance(nested_turn_id, str): - return nested_turn_id - return None - - def _fail_pending_queues(self, exc: BaseException) -> None: - with self._router_lock: - response_waiters = list(self._response_waiters.values()) - self._response_waiters.clear() - turn_queues = list(self._turn_notifications.values()) - for waiter in response_waiters: - waiter.put(exc) - for turn_queue in turn_queues: - turn_queue.put(exc) - self._global_notifications.put(exc) - def _handle_server_request(self, msg: dict[str, JsonValue]) -> JsonObject: method = msg["method"] params = msg.get("params") diff --git a/sdk/python/tests/test_client_rpc_methods.py b/sdk/python/tests/test_client_rpc_methods.py index 693cd4077a..194d9211e8 100644 --- a/sdk/python/tests/test_client_rpc_methods.py +++ b/sdk/python/tests/test_client_rpc_methods.py @@ -135,7 +135,7 @@ def test_turn_notification_router_demuxes_registered_turns() -> None: client.register_turn_notifications("turn-1") client.register_turn_notifications("turn-2") - client._route_notification( + client._router.route_notification( client._coerce_notification( "item/agentMessage/delta", { @@ -146,7 +146,7 @@ def test_turn_notification_router_demuxes_registered_turns() -> None: }, ) ) - client._route_notification( + client._router.route_notification( client._coerce_notification( "item/agentMessage/delta", { @@ -172,7 +172,7 @@ def test_turn_notification_router_demuxes_registered_turns() -> None: def test_turn_notification_router_buffers_events_before_registration() -> None: client = AppServerClient() - client._route_notification( + client._router.route_notification( client._coerce_notification( "turn/completed", {