mirror of
https://github.com/openai/codex.git
synced 2026-05-16 01:02:48 +00:00
Document SDK turn routing helpers
Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -22,6 +22,7 @@ class MessageRouter:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Create empty response, turn, and global notification queues."""
|
||||
self._lock = threading.Lock()
|
||||
self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {}
|
||||
self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
|
||||
@@ -144,6 +145,7 @@ class MessageRouter:
|
||||
self._global_notifications.put(exc)
|
||||
|
||||
def _notification_turn_id(self, notification: Notification) -> str | None:
|
||||
"""Extract routing ids from known generated payloads or raw unknown payloads."""
|
||||
payload = notification.payload
|
||||
if isinstance(payload, UnknownNotification):
|
||||
raw_turn_id = payload.params.get("turnId")
|
||||
|
||||
@@ -678,6 +678,7 @@ class TurnHandle:
|
||||
return self._client.turn_interrupt(self.thread_id, self.id)
|
||||
|
||||
def stream(self) -> Iterator[Notification]:
|
||||
"""Yield only notifications routed to this turn handle."""
|
||||
self._client.register_turn_notifications(self.id)
|
||||
try:
|
||||
while True:
|
||||
@@ -730,6 +731,7 @@ class AsyncTurnHandle:
|
||||
return await self._codex._client.turn_interrupt(self.thread_id, self.id)
|
||||
|
||||
async def stream(self) -> AsyncIterator[Notification]:
|
||||
"""Yield only notifications routed to this async turn handle."""
|
||||
await self._codex._ensure_initialized()
|
||||
self._codex._client.register_turn_notifications(self.id)
|
||||
try:
|
||||
|
||||
@@ -40,13 +40,16 @@ class AsyncAppServerClient:
|
||||
"""Async wrapper around AppServerClient using thread offloading."""
|
||||
|
||||
def __init__(self, config: AppServerConfig | None = None) -> None:
|
||||
"""Create the wrapped sync client that owns the transport process."""
|
||||
self._sync = AppServerClient(config=config)
|
||||
|
||||
async def __aenter__(self) -> "AsyncAppServerClient":
|
||||
"""Start the app-server process when entering an async context."""
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, _exc_type, _exc, _tb) -> None:
|
||||
"""Close the app-server process when leaving an async context."""
|
||||
await self.close()
|
||||
|
||||
async def _call_sync(
|
||||
@@ -56,30 +59,37 @@ class AsyncAppServerClient:
|
||||
*args: ParamsT.args,
|
||||
**kwargs: ParamsT.kwargs,
|
||||
) -> ReturnT:
|
||||
"""Run a blocking sync-client operation without blocking the event loop."""
|
||||
return await asyncio.to_thread(fn, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _next_from_iterator(
|
||||
iterator: Iterator[AgentMessageDeltaNotification],
|
||||
) -> tuple[bool, AgentMessageDeltaNotification | None]:
|
||||
"""Convert StopIteration into a value that can cross asyncio.to_thread."""
|
||||
try:
|
||||
return True, next(iterator)
|
||||
except StopIteration:
|
||||
return False, None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the wrapped sync client in a worker thread."""
|
||||
await self._call_sync(self._sync.start)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the wrapped sync client in a worker thread."""
|
||||
await self._call_sync(self._sync.close)
|
||||
|
||||
async def initialize(self) -> InitializeResponse:
|
||||
"""Initialize the app-server session."""
|
||||
return await self._call_sync(self._sync.initialize)
|
||||
|
||||
def register_turn_notifications(self, turn_id: str) -> None:
|
||||
"""Register a turn notification queue on the wrapped sync client."""
|
||||
self._sync.register_turn_notifications(turn_id)
|
||||
|
||||
def unregister_turn_notifications(self, turn_id: str) -> None:
|
||||
"""Unregister a turn notification queue on the wrapped sync client."""
|
||||
self._sync.unregister_turn_notifications(turn_id)
|
||||
|
||||
async def request(
|
||||
@@ -89,6 +99,7 @@ class AsyncAppServerClient:
|
||||
*,
|
||||
response_model: type[ModelT],
|
||||
) -> ModelT:
|
||||
"""Send a typed JSON-RPC request through the wrapped sync client."""
|
||||
return await self._call_sync(
|
||||
self._sync.request,
|
||||
method,
|
||||
@@ -99,6 +110,7 @@ class AsyncAppServerClient:
|
||||
async def thread_start(
|
||||
self, params: V2ThreadStartParams | JsonObject | None = None
|
||||
) -> ThreadStartResponse:
|
||||
"""Start a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_start, params)
|
||||
|
||||
async def thread_resume(
|
||||
@@ -106,16 +118,19 @@ class AsyncAppServerClient:
|
||||
thread_id: str,
|
||||
params: V2ThreadResumeParams | JsonObject | None = None,
|
||||
) -> ThreadResumeResponse:
|
||||
"""Resume a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_resume, thread_id, params)
|
||||
|
||||
async def thread_list(
|
||||
self, params: V2ThreadListParams | JsonObject | None = None
|
||||
) -> ThreadListResponse:
|
||||
"""List threads using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_list, params)
|
||||
|
||||
async def thread_read(
|
||||
self, thread_id: str, include_turns: bool = False
|
||||
) -> ThreadReadResponse:
|
||||
"""Read a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_read, thread_id, include_turns)
|
||||
|
||||
async def thread_fork(
|
||||
@@ -123,18 +138,23 @@ class AsyncAppServerClient:
|
||||
thread_id: str,
|
||||
params: V2ThreadForkParams | JsonObject | None = None,
|
||||
) -> ThreadForkResponse:
|
||||
"""Fork a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_fork, thread_id, params)
|
||||
|
||||
async def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:
|
||||
"""Archive a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_archive, thread_id)
|
||||
|
||||
async def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse:
|
||||
"""Unarchive a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_unarchive, thread_id)
|
||||
|
||||
async def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse:
|
||||
"""Rename a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_set_name, thread_id, name)
|
||||
|
||||
async def thread_compact(self, thread_id: str) -> ThreadCompactStartResponse:
|
||||
"""Start thread compaction using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_compact, thread_id)
|
||||
|
||||
async def turn_start(
|
||||
@@ -143,6 +163,7 @@ class AsyncAppServerClient:
|
||||
input_items: list[JsonObject] | JsonObject | str,
|
||||
params: V2TurnStartParams | JsonObject | None = None,
|
||||
) -> TurnStartResponse:
|
||||
"""Start a turn using the wrapped sync client."""
|
||||
return await self._call_sync(
|
||||
self._sync.turn_start, thread_id, input_items, params
|
||||
)
|
||||
@@ -150,6 +171,7 @@ class AsyncAppServerClient:
|
||||
async def turn_interrupt(
|
||||
self, thread_id: str, turn_id: str
|
||||
) -> TurnInterruptResponse:
|
||||
"""Interrupt a turn using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.turn_interrupt, thread_id, turn_id)
|
||||
|
||||
async def turn_steer(
|
||||
@@ -158,6 +180,7 @@ class AsyncAppServerClient:
|
||||
expected_turn_id: str,
|
||||
input_items: list[JsonObject] | JsonObject | str,
|
||||
) -> TurnSteerResponse:
|
||||
"""Send steering input to a turn using the wrapped sync client."""
|
||||
return await self._call_sync(
|
||||
self._sync.turn_steer,
|
||||
thread_id,
|
||||
@@ -166,6 +189,7 @@ class AsyncAppServerClient:
|
||||
)
|
||||
|
||||
async def model_list(self, include_hidden: bool = False) -> ModelListResponse:
|
||||
"""List models using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.model_list, include_hidden)
|
||||
|
||||
async def request_with_retry_on_overload(
|
||||
@@ -178,6 +202,7 @@ class AsyncAppServerClient:
|
||||
initial_delay_s: float = 0.25,
|
||||
max_delay_s: float = 2.0,
|
||||
) -> ModelT:
|
||||
"""Send a typed request with the sync client's overload retry policy."""
|
||||
return await self._call_sync(
|
||||
self._sync.request_with_retry_on_overload,
|
||||
method,
|
||||
@@ -189,12 +214,15 @@ class AsyncAppServerClient:
|
||||
)
|
||||
|
||||
async def next_notification(self) -> Notification:
|
||||
"""Wait for the next global notification without blocking the event loop."""
|
||||
return await self._call_sync(self._sync.next_notification)
|
||||
|
||||
async def next_turn_notification(self, turn_id: str) -> Notification:
|
||||
"""Wait for the next notification routed to one turn."""
|
||||
return await self._call_sync(self._sync.next_turn_notification, turn_id)
|
||||
|
||||
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
||||
"""Wait for the completion notification routed to one turn."""
|
||||
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
|
||||
|
||||
async def stream_text(
|
||||
@@ -203,6 +231,7 @@ class AsyncAppServerClient:
|
||||
text: str,
|
||||
params: V2TurnStartParams | JsonObject | None = None,
|
||||
) -> AsyncIterator[AgentMessageDeltaNotification]:
|
||||
"""Stream text deltas from one turn without monopolizing the event loop."""
|
||||
iterator = self._sync.stream_text(thread_id, text, params)
|
||||
while True:
|
||||
has_value, chunk = await asyncio.to_thread(
|
||||
|
||||
@@ -243,6 +243,7 @@ class AppServerClient:
|
||||
return response_model.model_validate(result)
|
||||
|
||||
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
|
||||
"""Send a JSON-RPC request and wait for the reader thread to route its response."""
|
||||
request_id = str(uuid.uuid4())
|
||||
waiter = self._router.create_response_waiter(request_id)
|
||||
|
||||
@@ -260,18 +261,23 @@ class AppServerClient:
|
||||
return item
|
||||
|
||||
def notify(self, method: str, params: JsonObject | None = None) -> None:
|
||||
"""Send a JSON-RPC notification without waiting for a response."""
|
||||
self._write_message({"method": method, "params": params or {}})
|
||||
|
||||
def next_notification(self) -> Notification:
|
||||
"""Return the next notification that is not scoped to an active turn."""
|
||||
return self._router.next_global_notification()
|
||||
|
||||
def register_turn_notifications(self, turn_id: str) -> None:
|
||||
"""Start routing notifications for one turn into its dedicated queue."""
|
||||
self._router.register_turn(turn_id)
|
||||
|
||||
def unregister_turn_notifications(self, turn_id: str) -> None:
|
||||
"""Stop routing notifications for one turn into its dedicated queue."""
|
||||
self._router.unregister_turn(turn_id)
|
||||
|
||||
def next_turn_notification(self, turn_id: str) -> Notification:
|
||||
"""Return the next routed notification for the requested turn id."""
|
||||
return self._router.next_turn_notification(turn_id)
|
||||
|
||||
def thread_start(
|
||||
@@ -349,6 +355,7 @@ class AppServerClient:
|
||||
input_items: list[JsonObject] | JsonObject | str,
|
||||
params: V2TurnStartParams | JsonObject | None = None,
|
||||
) -> TurnStartResponse:
|
||||
"""Start a turn and register its notification queue as early as possible."""
|
||||
payload = {
|
||||
**_params_dict(params),
|
||||
"threadId": thread_id,
|
||||
@@ -406,6 +413,7 @@ class AppServerClient:
|
||||
)
|
||||
|
||||
def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
||||
"""Block on the routed turn stream until the matching completion arrives."""
|
||||
self.register_turn_notifications(turn_id)
|
||||
try:
|
||||
while True:
|
||||
@@ -425,6 +433,7 @@ class AppServerClient:
|
||||
text: str,
|
||||
params: V2TurnStartParams | JsonObject | None = None,
|
||||
) -> Iterator[AgentMessageDeltaNotification]:
|
||||
"""Start a text turn and yield only its agent-message delta payloads."""
|
||||
started = self.turn_start(thread_id, text, params=params)
|
||||
turn_id = started.turn.id
|
||||
self.register_turn_notifications(turn_id)
|
||||
@@ -477,6 +486,7 @@ class AppServerClient:
|
||||
def _default_approval_handler(
|
||||
self, method: str, params: JsonObject | None
|
||||
) -> JsonObject:
|
||||
"""Accept approval requests when the caller did not provide a handler."""
|
||||
if method == "item/commandExecution/requestApproval":
|
||||
return {"decision": "accept"}
|
||||
if method == "item/fileChange/requestApproval":
|
||||
@@ -498,6 +508,7 @@ class AppServerClient:
|
||||
self._stderr_thread.start()
|
||||
|
||||
def _start_reader_thread(self) -> None:
|
||||
"""Start the sole stdout reader that fans messages into router queues."""
|
||||
if self._proc is None or self._proc.stdout is None:
|
||||
return
|
||||
|
||||
@@ -505,6 +516,7 @@ class AppServerClient:
|
||||
self._reader_thread.start()
|
||||
|
||||
def _reader_loop(self) -> None:
|
||||
"""Continuously classify transport messages into requests, responses, and events."""
|
||||
try:
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
|
||||
@@ -165,6 +165,7 @@ NESTED_TURN_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = (
|
||||
|
||||
|
||||
def notification_turn_id(payload: BaseModel) -> str | None:
|
||||
"""Return the turn id carried by generated notification payload metadata."""
|
||||
if isinstance(payload, DIRECT_TURN_ID_NOTIFICATION_TYPES):
|
||||
return payload.turn_id if isinstance(payload.turn_id, str) else None
|
||||
if isinstance(payload, NESTED_TURN_NOTIFICATION_TYPES):
|
||||
|
||||
Reference in New Issue
Block a user