Document SDK turn routing helpers

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-05-09 10:23:06 +03:00
parent 11e31d7d38
commit becbd2a127
10 changed files with 103 additions and 0 deletions

View File

@@ -22,6 +22,7 @@ class MessageRouter:
"""
def __init__(self) -> None:
"""Create empty response, turn, and global notification queues."""
self._lock = threading.Lock()
self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {}
self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
@@ -144,6 +145,7 @@ class MessageRouter:
self._global_notifications.put(exc)
def _notification_turn_id(self, notification: Notification) -> str | None:
"""Extract routing ids from known generated payloads or raw unknown payloads."""
payload = notification.payload
if isinstance(payload, UnknownNotification):
raw_turn_id = payload.params.get("turnId")

View File

@@ -678,6 +678,7 @@ class TurnHandle:
return self._client.turn_interrupt(self.thread_id, self.id)
def stream(self) -> Iterator[Notification]:
"""Yield only notifications routed to this turn handle."""
self._client.register_turn_notifications(self.id)
try:
while True:
@@ -730,6 +731,7 @@ class AsyncTurnHandle:
return await self._codex._client.turn_interrupt(self.thread_id, self.id)
async def stream(self) -> AsyncIterator[Notification]:
"""Yield only notifications routed to this async turn handle."""
await self._codex._ensure_initialized()
self._codex._client.register_turn_notifications(self.id)
try:

View File

@@ -40,13 +40,16 @@ class AsyncAppServerClient:
"""Async wrapper around AppServerClient using thread offloading."""
def __init__(self, config: AppServerConfig | None = None) -> None:
"""Create the wrapped sync client that owns the transport process."""
self._sync = AppServerClient(config=config)
async def __aenter__(self) -> "AsyncAppServerClient":
"""Start the app-server process when entering an async context."""
await self.start()
return self
async def __aexit__(self, _exc_type, _exc, _tb) -> None:
"""Close the app-server process when leaving an async context."""
await self.close()
async def _call_sync(
@@ -56,30 +59,37 @@ class AsyncAppServerClient:
*args: ParamsT.args,
**kwargs: ParamsT.kwargs,
) -> ReturnT:
"""Run a blocking sync-client operation without blocking the event loop."""
return await asyncio.to_thread(fn, *args, **kwargs)
@staticmethod
def _next_from_iterator(
iterator: Iterator[AgentMessageDeltaNotification],
) -> tuple[bool, AgentMessageDeltaNotification | None]:
"""Convert StopIteration into a value that can cross asyncio.to_thread."""
try:
return True, next(iterator)
except StopIteration:
return False, None
async def start(self) -> None:
"""Start the wrapped sync client in a worker thread."""
await self._call_sync(self._sync.start)
async def close(self) -> None:
"""Close the wrapped sync client in a worker thread."""
await self._call_sync(self._sync.close)
async def initialize(self) -> InitializeResponse:
"""Initialize the app-server session."""
return await self._call_sync(self._sync.initialize)
def register_turn_notifications(self, turn_id: str) -> None:
"""Register a turn notification queue on the wrapped sync client."""
self._sync.register_turn_notifications(turn_id)
def unregister_turn_notifications(self, turn_id: str) -> None:
"""Unregister a turn notification queue on the wrapped sync client."""
self._sync.unregister_turn_notifications(turn_id)
async def request(
@@ -89,6 +99,7 @@ class AsyncAppServerClient:
*,
response_model: type[ModelT],
) -> ModelT:
"""Send a typed JSON-RPC request through the wrapped sync client."""
return await self._call_sync(
self._sync.request,
method,
@@ -99,6 +110,7 @@ class AsyncAppServerClient:
async def thread_start(
self, params: V2ThreadStartParams | JsonObject | None = None
) -> ThreadStartResponse:
"""Start a thread using the wrapped sync client."""
return await self._call_sync(self._sync.thread_start, params)
async def thread_resume(
@@ -106,16 +118,19 @@ class AsyncAppServerClient:
thread_id: str,
params: V2ThreadResumeParams | JsonObject | None = None,
) -> ThreadResumeResponse:
"""Resume a thread using the wrapped sync client."""
return await self._call_sync(self._sync.thread_resume, thread_id, params)
async def thread_list(
self, params: V2ThreadListParams | JsonObject | None = None
) -> ThreadListResponse:
"""List threads using the wrapped sync client."""
return await self._call_sync(self._sync.thread_list, params)
async def thread_read(
self, thread_id: str, include_turns: bool = False
) -> ThreadReadResponse:
"""Read a thread using the wrapped sync client."""
return await self._call_sync(self._sync.thread_read, thread_id, include_turns)
async def thread_fork(
@@ -123,18 +138,23 @@ class AsyncAppServerClient:
thread_id: str,
params: V2ThreadForkParams | JsonObject | None = None,
) -> ThreadForkResponse:
"""Fork a thread using the wrapped sync client."""
return await self._call_sync(self._sync.thread_fork, thread_id, params)
async def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:
"""Archive a thread using the wrapped sync client."""
return await self._call_sync(self._sync.thread_archive, thread_id)
async def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse:
"""Unarchive a thread using the wrapped sync client."""
return await self._call_sync(self._sync.thread_unarchive, thread_id)
async def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse:
"""Rename a thread using the wrapped sync client."""
return await self._call_sync(self._sync.thread_set_name, thread_id, name)
async def thread_compact(self, thread_id: str) -> ThreadCompactStartResponse:
"""Start thread compaction using the wrapped sync client."""
return await self._call_sync(self._sync.thread_compact, thread_id)
async def turn_start(
@@ -143,6 +163,7 @@ class AsyncAppServerClient:
input_items: list[JsonObject] | JsonObject | str,
params: V2TurnStartParams | JsonObject | None = None,
) -> TurnStartResponse:
"""Start a turn using the wrapped sync client."""
return await self._call_sync(
self._sync.turn_start, thread_id, input_items, params
)
@@ -150,6 +171,7 @@ class AsyncAppServerClient:
async def turn_interrupt(
self, thread_id: str, turn_id: str
) -> TurnInterruptResponse:
"""Interrupt a turn using the wrapped sync client."""
return await self._call_sync(self._sync.turn_interrupt, thread_id, turn_id)
async def turn_steer(
@@ -158,6 +180,7 @@ class AsyncAppServerClient:
expected_turn_id: str,
input_items: list[JsonObject] | JsonObject | str,
) -> TurnSteerResponse:
"""Send steering input to a turn using the wrapped sync client."""
return await self._call_sync(
self._sync.turn_steer,
thread_id,
@@ -166,6 +189,7 @@ class AsyncAppServerClient:
)
async def model_list(self, include_hidden: bool = False) -> ModelListResponse:
"""List models using the wrapped sync client."""
return await self._call_sync(self._sync.model_list, include_hidden)
async def request_with_retry_on_overload(
@@ -178,6 +202,7 @@ class AsyncAppServerClient:
initial_delay_s: float = 0.25,
max_delay_s: float = 2.0,
) -> ModelT:
"""Send a typed request with the sync client's overload retry policy."""
return await self._call_sync(
self._sync.request_with_retry_on_overload,
method,
@@ -189,12 +214,15 @@ class AsyncAppServerClient:
)
async def next_notification(self) -> Notification:
"""Wait for the next global notification without blocking the event loop."""
return await self._call_sync(self._sync.next_notification)
async def next_turn_notification(self, turn_id: str) -> Notification:
"""Wait for the next notification routed to one turn."""
return await self._call_sync(self._sync.next_turn_notification, turn_id)
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
"""Wait for the completion notification routed to one turn."""
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
async def stream_text(
@@ -203,6 +231,7 @@ class AsyncAppServerClient:
text: str,
params: V2TurnStartParams | JsonObject | None = None,
) -> AsyncIterator[AgentMessageDeltaNotification]:
"""Stream text deltas from one turn without monopolizing the event loop."""
iterator = self._sync.stream_text(thread_id, text, params)
while True:
has_value, chunk = await asyncio.to_thread(

View File

@@ -243,6 +243,7 @@ class AppServerClient:
return response_model.model_validate(result)
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
"""Send a JSON-RPC request and wait for the reader thread to route its response."""
request_id = str(uuid.uuid4())
waiter = self._router.create_response_waiter(request_id)
@@ -260,18 +261,23 @@ class AppServerClient:
return item
def notify(self, method: str, params: JsonObject | None = None) -> None:
"""Send a JSON-RPC notification without waiting for a response."""
self._write_message({"method": method, "params": params or {}})
def next_notification(self) -> Notification:
"""Return the next notification that is not scoped to an active turn."""
return self._router.next_global_notification()
def register_turn_notifications(self, turn_id: str) -> None:
"""Start routing notifications for one turn into its dedicated queue."""
self._router.register_turn(turn_id)
def unregister_turn_notifications(self, turn_id: str) -> None:
"""Stop routing notifications for one turn into its dedicated queue."""
self._router.unregister_turn(turn_id)
def next_turn_notification(self, turn_id: str) -> Notification:
"""Return the next routed notification for the requested turn id."""
return self._router.next_turn_notification(turn_id)
def thread_start(
@@ -349,6 +355,7 @@ class AppServerClient:
input_items: list[JsonObject] | JsonObject | str,
params: V2TurnStartParams | JsonObject | None = None,
) -> TurnStartResponse:
"""Start a turn and register its notification queue as early as possible."""
payload = {
**_params_dict(params),
"threadId": thread_id,
@@ -406,6 +413,7 @@ class AppServerClient:
)
def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
"""Block on the routed turn stream until the matching completion arrives."""
self.register_turn_notifications(turn_id)
try:
while True:
@@ -425,6 +433,7 @@ class AppServerClient:
text: str,
params: V2TurnStartParams | JsonObject | None = None,
) -> Iterator[AgentMessageDeltaNotification]:
"""Start a text turn and yield only its agent-message delta payloads."""
started = self.turn_start(thread_id, text, params=params)
turn_id = started.turn.id
self.register_turn_notifications(turn_id)
@@ -477,6 +486,7 @@ class AppServerClient:
def _default_approval_handler(
self, method: str, params: JsonObject | None
) -> JsonObject:
"""Accept approval requests when the caller did not provide a handler."""
if method == "item/commandExecution/requestApproval":
return {"decision": "accept"}
if method == "item/fileChange/requestApproval":
@@ -498,6 +508,7 @@ class AppServerClient:
self._stderr_thread.start()
def _start_reader_thread(self) -> None:
"""Start the sole stdout reader that fans messages into router queues."""
if self._proc is None or self._proc.stdout is None:
return
@@ -505,6 +516,7 @@ class AppServerClient:
self._reader_thread.start()
def _reader_loop(self) -> None:
"""Continuously classify transport messages into requests, responses, and events."""
try:
while True:
msg = self._read_message()

View File

@@ -165,6 +165,7 @@ NESTED_TURN_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = (
def notification_turn_id(payload: BaseModel) -> str | None:
"""Return the turn id carried by generated notification payload metadata."""
if isinstance(payload, DIRECT_TURN_ID_NOTIFICATION_TYPES):
return payload.turn_id if isinstance(payload.turn_id, str) else None
if isinstance(payload, NESTED_TURN_NOTIFICATION_TYPES):