diff --git a/sdk/python/docs/api-reference.md b/sdk/python/docs/api-reference.md index ddeaf39cd0..c7a763498d 100644 --- a/sdk/python/docs/api-reference.md +++ b/sdk/python/docs/api-reference.md @@ -2,7 +2,7 @@ Public surface of `codex_app_server` for app-server v2. -This SDK surface is experimental. The current implementation intentionally allows only one active turn consumer (`Thread.run()`, `TurnHandle.stream()`, or `TurnHandle.run()`) per client instance at a time. +This SDK surface is experimental. Turn streams are routed by turn ID so one client can consume multiple active turns concurrently. ## Package Entry @@ -137,8 +137,8 @@ Use `turn(...)` when you need low-level turn control (`stream()`, `steer()`, Behavior notes: -- `stream()` and `run()` are exclusive per client instance in the current experimental build -- starting a second turn consumer on the same `Codex` instance raises `RuntimeError` +- `stream()` and `run()` consume only notifications for their own turn ID +- one `Codex` instance can stream multiple active turns concurrently ### AsyncTurnHandle @@ -149,8 +149,8 @@ Behavior notes: Behavior notes: -- `stream()` and `run()` are exclusive per client instance in the current experimental build -- starting a second turn consumer on the same `AsyncCodex` instance raises `RuntimeError` +- `stream()` and `run()` consume only notifications for their own turn ID +- one `AsyncCodex` instance can stream multiple active turns concurrently ## Inputs diff --git a/sdk/python/docs/getting-started.md b/sdk/python/docs/getting-started.md index 45ad1eb51f..1794d39f70 100644 --- a/sdk/python/docs/getting-started.md +++ b/sdk/python/docs/getting-started.md @@ -45,7 +45,7 @@ What happened: - `thread.run("...")` started a turn, consumed events until completion, and returned the final assistant response plus collected items and usage. - `result.final_response` is `None` when no final-answer or phase-less assistant message item completes for the turn. 
def _notification_turn_id_specs(
    specs: list[tuple[str, str]],
) -> tuple[list[str], list[str]]:
    """Classify notification payload classes by where they carry a turn ID.

    Reads ``ServerNotification.json`` from the schema root and inspects the
    ``properties`` of each class named in *specs*.

    Args:
        specs: ``(method, class_name)`` pairs; only the class name is used.

    Returns:
        ``(direct, nested)`` where *direct* lists classes that declare a
        top-level ``turnId`` property and *nested* lists classes whose
        ``turn`` property is a ``$ref`` to ``#/definitions/Turn``. Both lists
        are de-duplicated and sorted. Classes with no usable definition are
        skipped, and two empty lists are returned when the schema has no
        ``definitions`` mapping.
    """
    schema_path = schema_root_dir() / "ServerNotification.json"
    # Decode explicitly as UTF-8: the platform default encoding is
    # locale-dependent and would make code generation machine-specific.
    server_notifications = json.loads(schema_path.read_text(encoding="utf-8"))
    definitions = server_notifications.get("definitions", {})
    if not isinstance(definitions, dict):
        return ([], [])

    direct: set[str] = set()
    nested: set[str] = set()
    for _, class_name in specs:
        definition = definitions.get(class_name)
        if not isinstance(definition, dict):
            continue
        props = definition.get("properties", {})
        if not isinstance(props, dict):
            continue
        # A direct ``turnId`` wins over a nested ``turn`` object.
        if "turnId" in props:
            direct.add(class_name)
            continue
        turn = props.get("turn")
        if isinstance(turn, dict) and turn.get("$ref") == "#/definitions/Turn":
            nested.add(class_name)

    return (sorted(direct), sorted(nested))


def _type_tuple_source(class_names: list[str]) -> str:
    """Render *class_names* as Python source for a tuple of bare class names.

    An empty list renders as ``()``, a single name as a one-line ``(Name,)``
    and anything longer as a multi-line tuple with one name per line.
    """
    if not class_names:
        return "()"
    if len(class_names) == 1:
        # The trailing comma keeps this a tuple rather than a parenthesised name.
        return f"({class_names[0]},)"
    body = "".join(f"    {class_name},\n" for class_name in class_names)
    return "(\n" + body + ")"
class MessageRouter:
    """Route reader-thread messages to the SDK operation waiting for them.

    The app-server stdio transport is a single ordered stream, so only the
    reader thread should consume stdout. This router keeps the rest of the SDK
    from competing for that stream by giving each in-flight JSON-RPC request
    and active turn stream its own queue.

    Once ``fail_all`` has been called the router stays failed: every queue
    handed out afterwards is pre-loaded with the failure, and a failure read
    from a notification queue is left in place so repeated reads keep raising
    instead of blocking.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {}
        self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
        self._pending_turn_notifications: dict[str, deque[Notification]] = {}
        self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue()
        # First exception passed to fail_all(); None while the router is healthy.
        self._failure: BaseException | None = None

    def create_response_waiter(self, request_id: str) -> queue.Queue[ResponseQueueItem]:
        """Register a one-shot queue for a JSON-RPC response id.

        If the router has already failed, the returned queue is not registered
        and already holds the failure, so the caller's ``get()`` returns
        immediately instead of waiting for a response that cannot arrive.
        """

        waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1)
        with self._lock:
            if self._failure is not None:
                waiter.put(self._failure)
            else:
                self._response_waiters[request_id] = waiter
        return waiter

    def discard_response_waiter(self, request_id: str) -> None:
        """Remove a response waiter when the request could not be written."""

        with self._lock:
            self._response_waiters.pop(request_id, None)

    def next_global_notification(self) -> Notification:
        """Block until the next notification that is not scoped to a turn.

        Raises:
            BaseException: the exception given to ``fail_all``; it is put back
                on the queue so later calls raise it again.
        """

        item = self._global_notifications.get()
        if isinstance(item, BaseException):
            # Keep the failure sticky for any subsequent caller.
            self._global_notifications.put(item)
            raise item
        return item

    def register_turn(self, turn_id: str) -> None:
        """Register a queue for a turn stream and replay early events.

        Registering an already-registered turn is a no-op.
        """

        turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue()
        with self._lock:
            if turn_id in self._turn_notifications:
                return
            # A turn can emit events immediately after turn/start, before the
            # caller receives the TurnHandle and starts streaming. Replay them
            # before the queue becomes visible to route_notification() so a
            # newer event can never be enqueued ahead of the buffered ones.
            # The queue is unbounded, so put() cannot block under the lock.
            for notification in self._pending_turn_notifications.pop(turn_id, ()):
                turn_queue.put(notification)
            if self._failure is not None:
                turn_queue.put(self._failure)
            self._turn_notifications[turn_id] = turn_queue

    def unregister_turn(self, turn_id: str) -> None:
        """Stop routing future turn events to the stream queue."""

        with self._lock:
            self._turn_notifications.pop(turn_id, None)

    def next_turn_notification(self, turn_id: str) -> Notification:
        """Block until the next notification for a registered turn.

        Raises:
            RuntimeError: if *turn_id* has no registered queue.
            BaseException: the exception given to ``fail_all``; it is put back
                on the queue so later calls raise it again.
        """

        with self._lock:
            turn_queue = self._turn_notifications.get(turn_id)
        if turn_queue is None:
            raise RuntimeError(f"turn {turn_id!r} is not registered for streaming")
        # Wait outside the lock so the reader thread can keep routing.
        item = turn_queue.get()
        if isinstance(item, BaseException):
            turn_queue.put(item)
            raise item
        return item

    def route_response(self, msg: dict[str, JsonValue]) -> None:
        """Deliver a JSON-RPC response or error to its request waiter.

        Messages whose id has no registered waiter are ignored.
        """

        request_id = msg.get("id")
        with self._lock:
            waiter = self._response_waiters.pop(str(request_id), None)
        if waiter is None:
            return

        if "error" not in msg:
            waiter.put(msg.get("result"))
            return

        err = msg["error"]
        if not isinstance(err, dict):
            waiter.put(AppServerError("Malformed JSON-RPC error response"))
            return

        # The waiter is already unregistered, so nothing else can wake it:
        # a non-numeric code must not raise here before something is delivered.
        try:
            code = int(err.get("code", -32000))
        except (TypeError, ValueError):
            code = -32000
        waiter.put(
            map_jsonrpc_error(
                code,
                str(err.get("message", "unknown")),
                err.get("data"),
            )
        )

    def route_notification(self, notification: Notification) -> None:
        """Deliver a notification to a turn queue or the global queue.

        Notifications without a turn ID go to the global queue. Notifications
        for a turn with no registered queue are buffered until the turn is
        registered, except ``turn/completed``, which discards that buffer.
        """

        turn_id = self._notification_turn_id(notification)
        if turn_id is None:
            self._global_notifications.put(notification)
            return

        with self._lock:
            turn_queue = self._turn_notifications.get(turn_id)
            if turn_queue is None:
                if notification.method == "turn/completed":
                    # Nobody is streaming this turn; drop its buffer so
                    # unconsumed turns do not accumulate.
                    # NOTE(review): a caller that registers only after this
                    # point gets an empty queue and would block in
                    # next_turn_notification(); confirm callers always
                    # register before completion can be observed.
                    self._pending_turn_notifications.pop(turn_id, None)
                    return
                self._pending_turn_notifications.setdefault(turn_id, deque()).append(
                    notification
                )
                return
            # Enqueue while still holding the lock so delivery stays ordered
            # with respect to register_turn()'s replay.
            turn_queue.put(notification)

    def fail_all(self, exc: BaseException) -> None:
        """Wake every blocked waiter when the reader thread exits."""

        with self._lock:
            if self._failure is None:
                self._failure = exc
            response_waiters = list(self._response_waiters.values())
            self._response_waiters.clear()
            turn_queues = list(self._turn_notifications.values())
            self._pending_turn_notifications.clear()
        # Put the same transport failure into every queue so no SDK call blocks
        # forever waiting for a response that cannot arrive.
        for waiter in response_waiters:
            waiter.put(exc)
        for turn_queue in turn_queues:
            turn_queue.put(exc)
        self._global_notifications.put(exc)

    def _notification_turn_id(self, notification: Notification) -> str | None:
        """Return the turn ID carried by *notification*, or None if it has none.

        Unparsed payloads are probed for a string ``turnId`` and then for a
        string ``turn.id``; typed payloads are resolved by the generated
        ``notification_turn_id`` helper.
        """

        payload = notification.payload
        if not isinstance(payload, UnknownNotification):
            return notification_turn_id(payload)

        raw_turn_id = payload.params.get("turnId")
        if isinstance(raw_turn_id, str):
            return raw_turn_id
        raw_turn = payload.params.get("turn")
        if isinstance(raw_turn, dict):
            raw_nested_turn_id = raw_turn.get("id")
            if isinstance(raw_nested_turn_id, str):
                return raw_nested_turn_id
        return None
def models(self, *, include_hidden: bool = False) -> ModelListResponse: @@ -476,6 +477,7 @@ class AsyncCodex: await self._ensure_initialized() unarchived = await self._client.thread_unarchive(thread_id) return AsyncThread(self, unarchived.thread.id) + # END GENERATED: AsyncCodex.flat_methods async def models(self, *, include_hidden: bool = False) -> ModelListResponse: @@ -555,6 +557,7 @@ class Thread: ) turn = self._client.turn_start(self.id, wire_input, params=params) return TurnHandle(self._client, self.id, turn.turn.id) + # END GENERATED: Thread.flat_methods def read(self, *, include_turns: bool = False) -> ThreadReadResponse: @@ -644,6 +647,7 @@ class AsyncThread: params=params, ) return AsyncTurnHandle(self._codex, self.id, turn.turn.id) + # END GENERATED: AsyncThread.flat_methods async def read(self, *, include_turns: bool = False) -> ThreadReadResponse: @@ -674,11 +678,10 @@ class TurnHandle: return self._client.turn_interrupt(self.thread_id, self.id) def stream(self) -> Iterator[Notification]: - # TODO: replace this client-wide experimental guard with per-turn event demux. - self._client.acquire_turn_consumer(self.id) + self._client.register_turn_notifications(self.id) try: while True: - event = self._client.next_notification() + event = self._client.next_turn_notification(self.id) yield event if ( event.method == "turn/completed" @@ -687,7 +690,7 @@ class TurnHandle: ): break finally: - self._client.release_turn_consumer(self.id) + self._client.unregister_turn_notifications(self.id) def run(self) -> AppServerTurn: completed: TurnCompletedNotification | None = None @@ -728,11 +731,10 @@ class AsyncTurnHandle: async def stream(self) -> AsyncIterator[Notification]: await self._codex._ensure_initialized() - # TODO: replace this client-wide experimental guard with per-turn event demux. 
- self._codex._client.acquire_turn_consumer(self.id) + self._codex._client.register_turn_notifications(self.id) try: while True: - event = await self._codex._client.next_notification() + event = await self._codex._client.next_turn_notification(self.id) yield event if ( event.method == "turn/completed" @@ -741,7 +743,7 @@ class AsyncTurnHandle: ): break finally: - self._codex._client.release_turn_consumer(self.id) + self._codex._client.unregister_turn_notifications(self.id) async def run(self) -> AppServerTurn: completed: TurnCompletedNotification | None = None diff --git a/sdk/python/src/codex_app_server/async_client.py b/sdk/python/src/codex_app_server/async_client.py index 6ca0c42a78..6768c7d9fc 100644 --- a/sdk/python/src/codex_app_server/async_client.py +++ b/sdk/python/src/codex_app_server/async_client.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio from collections.abc import Iterator -from typing import AsyncIterator, Callable, Iterable, ParamSpec, TypeVar +from typing import AsyncIterator, Callable, ParamSpec, TypeVar from pydantic import BaseModel @@ -41,8 +41,6 @@ class AsyncAppServerClient: def __init__(self, config: AppServerConfig | None = None) -> None: self._sync = AppServerClient(config=config) - # Single stdio transport cannot be read safely from multiple threads. 
- self._transport_lock = asyncio.Lock() async def __aenter__(self) -> "AsyncAppServerClient": await self.start() @@ -58,8 +56,7 @@ class AsyncAppServerClient: *args: ParamsT.args, **kwargs: ParamsT.kwargs, ) -> ReturnT: - async with self._transport_lock: - return await asyncio.to_thread(fn, *args, **kwargs) + return await asyncio.to_thread(fn, *args, **kwargs) @staticmethod def _next_from_iterator( @@ -79,11 +76,11 @@ class AsyncAppServerClient: async def initialize(self) -> InitializeResponse: return await self._call_sync(self._sync.initialize) - def acquire_turn_consumer(self, turn_id: str) -> None: - self._sync.acquire_turn_consumer(turn_id) + def register_turn_notifications(self, turn_id: str) -> None: + self._sync.register_turn_notifications(turn_id) - def release_turn_consumer(self, turn_id: str) -> None: - self._sync.release_turn_consumer(turn_id) + def unregister_turn_notifications(self, turn_id: str) -> None: + self._sync.unregister_turn_notifications(turn_id) async def request( self, @@ -99,7 +96,9 @@ class AsyncAppServerClient: response_model=response_model, ) - async def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse: + async def thread_start( + self, params: V2ThreadStartParams | JsonObject | None = None + ) -> ThreadStartResponse: return await self._call_sync(self._sync.thread_start, params) async def thread_resume( @@ -109,10 +108,14 @@ class AsyncAppServerClient: ) -> ThreadResumeResponse: return await self._call_sync(self._sync.thread_resume, thread_id, params) - async def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse: + async def thread_list( + self, params: V2ThreadListParams | JsonObject | None = None + ) -> ThreadListResponse: return await self._call_sync(self._sync.thread_list, params) - async def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse: + async def thread_read( + self, thread_id: str, include_turns: 
bool = False + ) -> ThreadReadResponse: return await self._call_sync(self._sync.thread_read, thread_id, include_turns) async def thread_fork( @@ -140,9 +143,13 @@ class AsyncAppServerClient: input_items: list[JsonObject] | JsonObject | str, params: V2TurnStartParams | JsonObject | None = None, ) -> TurnStartResponse: - return await self._call_sync(self._sync.turn_start, thread_id, input_items, params) + return await self._call_sync( + self._sync.turn_start, thread_id, input_items, params + ) - async def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse: + async def turn_interrupt( + self, thread_id: str, turn_id: str + ) -> TurnInterruptResponse: return await self._call_sync(self._sync.turn_interrupt, thread_id, turn_id) async def turn_steer( @@ -184,25 +191,24 @@ class AsyncAppServerClient: async def next_notification(self) -> Notification: return await self._call_sync(self._sync.next_notification) + async def next_turn_notification(self, turn_id: str) -> Notification: + return await self._call_sync(self._sync.next_turn_notification, turn_id) + async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification: return await self._call_sync(self._sync.wait_for_turn_completed, turn_id) - async def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]: - return await self._call_sync(self._sync.stream_until_methods, methods) - async def stream_text( self, thread_id: str, text: str, params: V2TurnStartParams | JsonObject | None = None, ) -> AsyncIterator[AgentMessageDeltaNotification]: - async with self._transport_lock: - iterator = self._sync.stream_text(thread_id, text, params) - while True: - has_value, chunk = await asyncio.to_thread( - self._next_from_iterator, - iterator, - ) - if not has_value: - break - yield chunk + iterator = self._sync.stream_text(thread_id, text, params) + while True: + has_value, chunk = await asyncio.to_thread( + self._next_from_iterator, + iterator, + ) + if not 
has_value: + break + yield chunk diff --git a/sdk/python/src/codex_app_server/client.py b/sdk/python/src/codex_app_server/client.py index 665e1c6725..ce3df4e416 100644 --- a/sdk/python/src/codex_app_server/client.py +++ b/sdk/python/src/codex_app_server/client.py @@ -8,11 +8,11 @@ import uuid from collections import deque from dataclasses import dataclass from pathlib import Path -from typing import Callable, Iterable, Iterator, TypeVar +from typing import Callable, Iterator, TypeVar from pydantic import BaseModel -from .errors import AppServerError, TransportClosedError, map_jsonrpc_error +from .errors import AppServerError, TransportClosedError from .generated.notification_registry import NOTIFICATION_MODELS from .generated.v2_all import ( AgentMessageDeltaNotification, @@ -43,6 +43,7 @@ from .models import ( Notification, UnknownNotification, ) +from ._message_router import MessageRouter from .retry import retry_on_overload from ._version import __version__ as SDK_VERSION @@ -75,7 +76,9 @@ def _params_dict( return dumped if isinstance(params, dict): return params - raise TypeError(f"Expected generated params model or dict, got {type(params).__name__}") + raise TypeError( + f"Expected generated params model or dict, got {type(params).__name__}" + ) def _installed_codex_path() -> Path: @@ -146,11 +149,10 @@ class AppServerClient: self._approval_handler = approval_handler or self._default_approval_handler self._proc: subprocess.Popen[str] | None = None self._lock = threading.Lock() - self._turn_consumer_lock = threading.Lock() - self._active_turn_consumer: str | None = None - self._pending_notifications: deque[Notification] = deque() + self._router = MessageRouter() self._stderr_lines: deque[str] = deque(maxlen=400) self._stderr_thread: threading.Thread | None = None + self._reader_thread: threading.Thread | None = None def __enter__(self) -> "AppServerClient": self.start() @@ -189,13 +191,13 @@ class AppServerClient: ) self._start_stderr_drain_thread() + 
self._start_reader_thread() def close(self) -> None: if self._proc is None: return proc = self._proc self._proc = None - self._active_turn_consumer = None if proc.stdin: proc.stdin.close() @@ -207,6 +209,8 @@ class AppServerClient: if self._stderr_thread and self._stderr_thread.is_alive(): self._stderr_thread.join(timeout=0.5) + if self._reader_thread and self._reader_thread.is_alive(): + self._reader_thread.join(timeout=0.5) def initialize(self) -> InitializeResponse: result = self.request( @@ -240,70 +244,42 @@ class AppServerClient: def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue: request_id = str(uuid.uuid4()) - self._write_message({"id": request_id, "method": method, "params": params or {}}) + waiter = self._router.create_response_waiter(request_id) - while True: - msg = self._read_message() + try: + self._write_message( + {"id": request_id, "method": method, "params": params or {}} + ) + except BaseException: + self._router.discard_response_waiter(request_id) + raise - if "method" in msg and "id" in msg: - response = self._handle_server_request(msg) - self._write_message({"id": msg["id"], "result": response}) - continue - - if "method" in msg and "id" not in msg: - self._pending_notifications.append( - self._coerce_notification(msg["method"], msg.get("params")) - ) - continue - - if msg.get("id") != request_id: - continue - - if "error" in msg: - err = msg["error"] - if isinstance(err, dict): - raise map_jsonrpc_error( - int(err.get("code", -32000)), - str(err.get("message", "unknown")), - err.get("data"), - ) - raise AppServerError("Malformed JSON-RPC error response") - - return msg.get("result") + item = waiter.get() + if isinstance(item, BaseException): + raise item + return item def notify(self, method: str, params: JsonObject | None = None) -> None: self._write_message({"method": method, "params": params or {}}) def next_notification(self) -> Notification: - if self._pending_notifications: - return 
self._pending_notifications.popleft() + return self._router.next_global_notification() - while True: - msg = self._read_message() - if "method" in msg and "id" in msg: - response = self._handle_server_request(msg) - self._write_message({"id": msg["id"], "result": response}) - continue - if "method" in msg and "id" not in msg: - return self._coerce_notification(msg["method"], msg.get("params")) + def register_turn_notifications(self, turn_id: str) -> None: + self._router.register_turn(turn_id) - def acquire_turn_consumer(self, turn_id: str) -> None: - with self._turn_consumer_lock: - if self._active_turn_consumer is not None: - raise RuntimeError( - "Concurrent turn consumers are not yet supported in the experimental SDK. " - f"Client is already streaming turn {self._active_turn_consumer!r}; " - f"cannot start turn {turn_id!r} until the active consumer finishes." - ) - self._active_turn_consumer = turn_id + def unregister_turn_notifications(self, turn_id: str) -> None: + self._router.unregister_turn(turn_id) - def release_turn_consumer(self, turn_id: str) -> None: - with self._turn_consumer_lock: - if self._active_turn_consumer == turn_id: - self._active_turn_consumer = None + def next_turn_notification(self, turn_id: str) -> Notification: + return self._router.next_turn_notification(turn_id) - def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse: - return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse) + def thread_start( + self, params: V2ThreadStartParams | JsonObject | None = None + ) -> ThreadStartResponse: + return self.request( + "thread/start", _params_dict(params), response_model=ThreadStartResponse + ) def thread_resume( self, @@ -311,12 +287,20 @@ class AppServerClient: params: V2ThreadResumeParams | JsonObject | None = None, ) -> ThreadResumeResponse: payload = {"threadId": thread_id, **_params_dict(params)} - return self.request("thread/resume", payload, 
response_model=ThreadResumeResponse) + return self.request( + "thread/resume", payload, response_model=ThreadResumeResponse + ) - def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse: - return self.request("thread/list", _params_dict(params), response_model=ThreadListResponse) + def thread_list( + self, params: V2ThreadListParams | JsonObject | None = None + ) -> ThreadListResponse: + return self.request( + "thread/list", _params_dict(params), response_model=ThreadListResponse + ) - def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse: + def thread_read( + self, thread_id: str, include_turns: bool = False + ) -> ThreadReadResponse: return self.request( "thread/read", {"threadId": thread_id, "includeTurns": include_turns}, @@ -332,10 +316,18 @@ class AppServerClient: return self.request("thread/fork", payload, response_model=ThreadForkResponse) def thread_archive(self, thread_id: str) -> ThreadArchiveResponse: - return self.request("thread/archive", {"threadId": thread_id}, response_model=ThreadArchiveResponse) + return self.request( + "thread/archive", + {"threadId": thread_id}, + response_model=ThreadArchiveResponse, + ) def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse: - return self.request("thread/unarchive", {"threadId": thread_id}, response_model=ThreadUnarchiveResponse) + return self.request( + "thread/unarchive", + {"threadId": thread_id}, + response_model=ThreadUnarchiveResponse, + ) def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse: return self.request( @@ -362,7 +354,9 @@ class AppServerClient: "threadId": thread_id, "input": self._normalize_input_items(input_items), } - return self.request("turn/start", payload, response_model=TurnStartResponse) + started = self.request("turn/start", payload, response_model=TurnStartResponse) + self.register_turn_notifications(started.turn.id) + return started def turn_interrupt(self, 
thread_id: str, turn_id: str) -> TurnInterruptResponse: return self.request( @@ -412,23 +406,18 @@ class AppServerClient: ) def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification: - while True: - notification = self.next_notification() - if ( - notification.method == "turn/completed" - and isinstance(notification.payload, TurnCompletedNotification) - and notification.payload.turn.id == turn_id - ): - return notification.payload - - def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]: - target_methods = {methods} if isinstance(methods, str) else set(methods) - out: list[Notification] = [] - while True: - notification = self.next_notification() - out.append(notification) - if notification.method in target_methods: - return out + self.register_turn_notifications(turn_id) + try: + while True: + notification = self.next_turn_notification(turn_id) + if ( + notification.method == "turn/completed" + and isinstance(notification.payload, TurnCompletedNotification) + and notification.payload.turn.id == turn_id + ): + return notification.payload + finally: + self.unregister_turn_notifications(turn_id) def stream_text( self, @@ -438,33 +427,41 @@ class AppServerClient: ) -> Iterator[AgentMessageDeltaNotification]: started = self.turn_start(thread_id, text, params=params) turn_id = started.turn.id - while True: - notification = self.next_notification() - if ( - notification.method == "item/agentMessage/delta" - and isinstance(notification.payload, AgentMessageDeltaNotification) - and notification.payload.turn_id == turn_id - ): - yield notification.payload - continue - if ( - notification.method == "turn/completed" - and isinstance(notification.payload, TurnCompletedNotification) - and notification.payload.turn.id == turn_id - ): - break + self.register_turn_notifications(turn_id) + try: + while True: + notification = self.next_turn_notification(turn_id) + if ( + notification.method == "item/agentMessage/delta" + and 
isinstance(notification.payload, AgentMessageDeltaNotification) + and notification.payload.turn_id == turn_id + ): + yield notification.payload + continue + if ( + notification.method == "turn/completed" + and isinstance(notification.payload, TurnCompletedNotification) + and notification.payload.turn.id == turn_id + ): + break + finally: + self.unregister_turn_notifications(turn_id) def _coerce_notification(self, method: str, params: object) -> Notification: params_dict = params if isinstance(params, dict) else {} model = NOTIFICATION_MODELS.get(method) if model is None: - return Notification(method=method, payload=UnknownNotification(params=params_dict)) + return Notification( + method=method, payload=UnknownNotification(params=params_dict) + ) try: payload = model.model_validate(params_dict) except Exception: # noqa: BLE001 - return Notification(method=method, payload=UnknownNotification(params=params_dict)) + return Notification( + method=method, payload=UnknownNotification(params=params_dict) + ) return Notification(method=method, payload=payload) def _normalize_input_items( @@ -477,7 +474,9 @@ class AppServerClient: return [input_items] return input_items - def _default_approval_handler(self, method: str, params: JsonObject | None) -> JsonObject: + def _default_approval_handler( + self, method: str, params: JsonObject | None + ) -> JsonObject: if method == "item/commandExecution/requestApproval": return {"decision": "accept"} if method == "item/fileChange/requestApproval": @@ -498,6 +497,32 @@ class AppServerClient: self._stderr_thread = threading.Thread(target=_drain, daemon=True) self._stderr_thread.start() + def _start_reader_thread(self) -> None: + if self._proc is None or self._proc.stdout is None: + return + + self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True) + self._reader_thread.start() + + def _reader_loop(self) -> None: + try: + while True: + msg = self._read_message() + if "method" in msg and "id" in msg: + response = 
self._handle_server_request(msg) + self._write_message({"id": msg["id"], "result": response}) + continue + if "method" in msg and "id" not in msg: + method = msg["method"] + if isinstance(method, str): + self._router.route_notification( + self._coerce_notification(method, msg.get("params")) + ) + continue + self._router.route_response(msg) + except BaseException as exc: + self._router.fail_all(exc) + def _stderr_tail(self, limit: int = 40) -> str: return "\n".join(list(self._stderr_lines)[-limit:]) diff --git a/sdk/python/src/codex_app_server/generated/notification_registry.py b/sdk/python/src/codex_app_server/generated/notification_registry.py index a97dc98f34..b44ca2a436 100644 --- a/sdk/python/src/codex_app_server/generated/notification_registry.py +++ b/sdk/python/src/codex_app_server/generated/notification_registry.py @@ -130,3 +130,43 @@ NOTIFICATION_MODELS: dict[str, type[BaseModel]] = { "windows/worldWritableWarning": WindowsWorldWritableWarningNotification, "windowsSandbox/setupCompleted": WindowsSandboxSetupCompletedNotification, } + +DIRECT_TURN_ID_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] 
= ( + AgentMessageDeltaNotification, + CommandExecutionOutputDeltaNotification, + ContextCompactedNotification, + ErrorNotification, + FileChangeOutputDeltaNotification, + FileChangePatchUpdatedNotification, + HookCompletedNotification, + HookStartedNotification, + ItemCompletedNotification, + ItemGuardianApprovalReviewCompletedNotification, + ItemGuardianApprovalReviewStartedNotification, + ItemStartedNotification, + McpToolCallProgressNotification, + ModelReroutedNotification, + ModelVerificationNotification, + PlanDeltaNotification, + ReasoningSummaryPartAddedNotification, + ReasoningSummaryTextDeltaNotification, + ReasoningTextDeltaNotification, + TerminalInteractionNotification, + ThreadGoalUpdatedNotification, + ThreadTokenUsageUpdatedNotification, + TurnDiffUpdatedNotification, + TurnPlanUpdatedNotification, +) + +NESTED_TURN_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = ( + TurnCompletedNotification, + TurnStartedNotification, +) + + +def notification_turn_id(payload: BaseModel) -> str | None: + if isinstance(payload, DIRECT_TURN_ID_NOTIFICATION_TYPES): + return payload.turn_id if isinstance(payload.turn_id, str) else None + if isinstance(payload, NESTED_TURN_NOTIFICATION_TYPES): + return payload.turn.id + return None diff --git a/sdk/python/tests/test_async_client_behavior.py b/sdk/python/tests/test_async_client_behavior.py index 580ff2a93b..0c4e8096fb 100644 --- a/sdk/python/tests/test_async_client_behavior.py +++ b/sdk/python/tests/test_async_client_behavior.py @@ -2,11 +2,17 @@ from __future__ import annotations import asyncio import time +from types import SimpleNamespace from codex_app_server.async_client import AsyncAppServerClient +from codex_app_server.generated.v2_all import ( + AgentMessageDeltaNotification, + TurnCompletedNotification, +) +from codex_app_server.models import Notification, UnknownNotification -def test_async_client_serializes_transport_calls() -> None: +def test_async_client_allows_concurrent_transport_calls() -> None: async 
def scenario() -> int: client = AsyncAppServerClient() active = 0 @@ -24,10 +30,10 @@ def test_async_client_serializes_transport_calls() -> None: await asyncio.gather(client.model_list(), client.model_list()) return max_active - assert asyncio.run(scenario()) == 1 + assert asyncio.run(scenario()) == 2 -def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None: +def test_async_stream_text_is_incremental_without_blocking_parallel_calls() -> None: async def scenario() -> tuple[str, list[str], bool]: client = AsyncAppServerClient() @@ -46,19 +52,155 @@ def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None: stream = client.stream_text("thread-1", "hello") first = await anext(stream) - blocked_before_stream_done = False competing_call = asyncio.create_task(client.model_list()) await asyncio.sleep(0.01) - blocked_before_stream_done = not competing_call.done() + competing_call_done_before_stream_done = competing_call.done() remaining: list[str] = [] async for item in stream: remaining.append(item) await competing_call - return first, remaining, blocked_before_stream_done + return first, remaining, competing_call_done_before_stream_done - first, remaining, blocked = asyncio.run(scenario()) + first, remaining, was_unblocked = asyncio.run(scenario()) assert first == "first" assert remaining == ["second", "third"] - assert blocked + assert was_unblocked + + +def test_async_client_turn_notification_methods_delegate_to_sync_client() -> None: + async def scenario() -> tuple[list[tuple[str, str]], Notification, str]: + client = AsyncAppServerClient() + event = Notification( + method="unknown/direct", + payload=UnknownNotification(params={"turnId": "turn-1"}), + ) + completed = TurnCompletedNotification.model_validate( + { + "threadId": "thread-1", + "turn": {"id": "turn-1", "items": [], "status": "completed"}, + } + ) + calls: list[tuple[str, str]] = [] + + def fake_register(turn_id: str) -> None: + calls.append(("register", turn_id)) + 
+ def fake_unregister(turn_id: str) -> None: + calls.append(("unregister", turn_id)) + + def fake_next(turn_id: str) -> Notification: + calls.append(("next", turn_id)) + return event + + def fake_wait(turn_id: str) -> TurnCompletedNotification: + calls.append(("wait", turn_id)) + return completed + + client._sync.register_turn_notifications = fake_register # type: ignore[method-assign] + client._sync.unregister_turn_notifications = fake_unregister # type: ignore[method-assign] + client._sync.next_turn_notification = fake_next # type: ignore[method-assign] + client._sync.wait_for_turn_completed = fake_wait # type: ignore[method-assign] + + client.register_turn_notifications("turn-1") + next_event = await client.next_turn_notification("turn-1") + completed_event = await client.wait_for_turn_completed("turn-1") + client.unregister_turn_notifications("turn-1") + + return calls, next_event, completed_event.turn.id + + calls, next_event, completed_turn_id = asyncio.run(scenario()) + + assert ( + calls, + next_event, + completed_turn_id, + ) == ( + [ + ("register", "turn-1"), + ("next", "turn-1"), + ("wait", "turn-1"), + ("unregister", "turn-1"), + ], + Notification( + method="unknown/direct", + payload=UnknownNotification(params={"turnId": "turn-1"}), + ), + "turn-1", + ) + + +def test_async_stream_text_uses_sync_turn_routing() -> None: + async def scenario() -> tuple[list[tuple[str, str]], list[str]]: + client = AsyncAppServerClient() + notifications = [ + Notification( + method="item/agentMessage/delta", + payload=AgentMessageDeltaNotification.model_validate( + { + "delta": "first", + "itemId": "item-1", + "threadId": "thread-1", + "turnId": "turn-1", + } + ), + ), + Notification( + method="item/agentMessage/delta", + payload=AgentMessageDeltaNotification.model_validate( + { + "delta": "second", + "itemId": "item-2", + "threadId": "thread-1", + "turnId": "turn-1", + } + ), + ), + Notification( + method="turn/completed", + 
payload=TurnCompletedNotification.model_validate( + { + "threadId": "thread-1", + "turn": {"id": "turn-1", "items": [], "status": "completed"}, + } + ), + ), + ] + calls: list[tuple[str, str]] = [] + + def fake_turn_start(thread_id: str, text: str, *, params=None): # type: ignore[no-untyped-def] + calls.append(("turn_start", thread_id)) + return SimpleNamespace(turn=SimpleNamespace(id="turn-1")) + + def fake_register(turn_id: str) -> None: + calls.append(("register", turn_id)) + + def fake_next(turn_id: str) -> Notification: + calls.append(("next", turn_id)) + return notifications.pop(0) + + def fake_unregister(turn_id: str) -> None: + calls.append(("unregister", turn_id)) + + client._sync.turn_start = fake_turn_start # type: ignore[method-assign] + client._sync.register_turn_notifications = fake_register # type: ignore[method-assign] + client._sync.next_turn_notification = fake_next # type: ignore[method-assign] + client._sync.unregister_turn_notifications = fake_unregister # type: ignore[method-assign] + + chunks = [chunk async for chunk in client.stream_text("thread-1", "hello")] + return calls, [chunk.delta for chunk in chunks] + + calls, deltas = asyncio.run(scenario()) + + assert (calls, deltas) == ( + [ + ("turn_start", "thread-1"), + ("register", "turn-1"), + ("next", "turn-1"), + ("next", "turn-1"), + ("next", "turn-1"), + ("unregister", "turn-1"), + ], + ["first", "second"], + ) diff --git a/sdk/python/tests/test_client_rpc_methods.py b/sdk/python/tests/test_client_rpc_methods.py index f1dd353b1f..07b88215a8 100644 --- a/sdk/python/tests/test_client_rpc_methods.py +++ b/sdk/python/tests/test_client_rpc_methods.py @@ -4,13 +4,17 @@ from pathlib import Path from typing import Any from codex_app_server.client import AppServerClient, _params_dict +from codex_app_server.generated.notification_registry import notification_turn_id from codex_app_server.generated.v2_all import ( + AgentMessageDeltaNotification, ApprovalsReviewer, ThreadListParams, 
ThreadResumeResponse, ThreadTokenUsageUpdatedNotification, + TurnCompletedNotification, + WarningNotification, ) -from codex_app_server.models import UnknownNotification +from codex_app_server.models import Notification, UnknownNotification ROOT = Path(__file__).resolve().parents[1] @@ -128,3 +132,220 @@ def test_invalid_notification_payload_falls_back_to_unknown() -> None: assert event.method == "thread/tokenUsage/updated" assert isinstance(event.payload, UnknownNotification) + + +def test_generated_notification_turn_id_handles_known_payload_shapes() -> None: + direct = AgentMessageDeltaNotification.model_validate( + { + "delta": "hello", + "itemId": "item-1", + "threadId": "thread-1", + "turnId": "turn-1", + } + ) + nested = TurnCompletedNotification.model_validate( + { + "threadId": "thread-1", + "turn": {"id": "turn-2", "items": [], "status": "completed"}, + } + ) + unscoped = WarningNotification(message="heads up") + + assert [ + notification_turn_id(direct), + notification_turn_id(nested), + notification_turn_id(unscoped), + ] == ["turn-1", "turn-2", None] + + +def test_turn_notification_router_demuxes_registered_turns() -> None: + client = AppServerClient() + client.register_turn_notifications("turn-1") + client.register_turn_notifications("turn-2") + + client._router.route_notification( + client._coerce_notification( + "item/agentMessage/delta", + { + "delta": "two", + "itemId": "item-2", + "threadId": "thread-1", + "turnId": "turn-2", + }, + ) + ) + client._router.route_notification( + client._coerce_notification( + "item/agentMessage/delta", + { + "delta": "one", + "itemId": "item-1", + "threadId": "thread-1", + "turnId": "turn-1", + }, + ) + ) + + first = client.next_turn_notification("turn-1") + second = client.next_turn_notification("turn-2") + + assert isinstance(first.payload, AgentMessageDeltaNotification) + assert isinstance(second.payload, AgentMessageDeltaNotification) + assert [ + (first.method, first.payload.delta), + (second.method, 
second.payload.delta), + ] == [ + ("item/agentMessage/delta", "one"), + ("item/agentMessage/delta", "two"), + ] + + +def test_client_reader_routes_interleaved_turn_notifications_by_turn_id() -> None: + client = AppServerClient() + client.register_turn_notifications("turn-1") + client.register_turn_notifications("turn-2") + + messages: list[dict[str, object]] = [ + { + "method": "item/agentMessage/delta", + "params": { + "delta": "one-a", + "itemId": "item-1", + "threadId": "thread-1", + "turnId": "turn-1", + }, + }, + { + "method": "item/agentMessage/delta", + "params": { + "delta": "two-a", + "itemId": "item-2", + "threadId": "thread-1", + "turnId": "turn-2", + }, + }, + { + "method": "item/agentMessage/delta", + "params": { + "delta": "one-b", + "itemId": "item-3", + "threadId": "thread-1", + "turnId": "turn-1", + }, + }, + { + "method": "item/agentMessage/delta", + "params": { + "delta": "two-b", + "itemId": "item-4", + "threadId": "thread-1", + "turnId": "turn-2", + }, + }, + ] + + def fake_read_message() -> dict[str, object]: + if messages: + return messages.pop(0) + raise EOFError + + client._read_message = fake_read_message # type: ignore[method-assign] + client._reader_loop() + + first_turn_events = [ + client.next_turn_notification("turn-1"), + client.next_turn_notification("turn-1"), + ] + second_turn_events = [ + client.next_turn_notification("turn-2"), + client.next_turn_notification("turn-2"), + ] + + first_turn_deltas = [ + event.payload.delta + for event in first_turn_events + if isinstance(event.payload, AgentMessageDeltaNotification) + ] + second_turn_deltas = [ + event.payload.delta + for event in second_turn_events + if isinstance(event.payload, AgentMessageDeltaNotification) + ] + assert (first_turn_deltas, second_turn_deltas) == ( + ["one-a", "one-b"], + ["two-a", "two-b"], + ) + + +def test_turn_notification_router_buffers_events_before_registration() -> None: + client = AppServerClient() + client._router.route_notification( + 
client._coerce_notification( + "item/agentMessage/delta", + { + "delta": "early", + "itemId": "item-1", + "threadId": "thread-1", + "turnId": "turn-1", + }, + ) + ) + + client.register_turn_notifications("turn-1") + event = client.next_turn_notification("turn-1") + + assert isinstance(event.payload, AgentMessageDeltaNotification) + assert (event.method, event.payload.delta) == ( + "item/agentMessage/delta", + "early", + ) + + +def test_turn_notification_router_clears_unregistered_turn_when_completed() -> None: + client = AppServerClient() + client._router.route_notification( + client._coerce_notification( + "item/agentMessage/delta", + { + "delta": "early", + "itemId": "item-1", + "threadId": "thread-1", + "turnId": "turn-1", + }, + ) + ) + client._router.route_notification( + client._coerce_notification( + "turn/completed", + { + "threadId": "thread-1", + "turn": {"id": "turn-1", "items": [], "status": "completed"}, + }, + ) + ) + + assert client._router._pending_turn_notifications == {} + + +def test_turn_notification_router_routes_unknown_turn_notifications() -> None: + client = AppServerClient() + client.register_turn_notifications("turn-1") + client.register_turn_notifications("turn-2") + + client._router.route_notification( + Notification( + method="unknown/direct", + payload=UnknownNotification(params={"turnId": "turn-1"}), + ) + ) + client._router.route_notification( + Notification( + method="unknown/nested", + payload=UnknownNotification(params={"turn": {"id": "turn-2"}}), + ) + ) + + first = client.next_turn_notification("turn-1") + second = client.next_turn_notification("turn-2") + + assert [first.method, second.method] == ["unknown/direct", "unknown/nested"] diff --git a/sdk/python/tests/test_public_api_runtime_behavior.py b/sdk/python/tests/test_public_api_runtime_behavior.py index 10865cf879..4f0fc45a7c 100644 --- a/sdk/python/tests/test_public_api_runtime_behavior.py +++ b/sdk/python/tests/test_public_api_runtime_behavior.py @@ -226,54 +226,74 @@ def 
test_async_codex_initializes_only_once_under_concurrency() -> None: asyncio.run(scenario()) -def test_turn_stream_rejects_second_active_consumer() -> None: +def test_turn_streams_can_consume_multiple_turns_on_one_client() -> None: client = AppServerClient() - notifications: deque[Notification] = deque( - [ - _delta_notification(turn_id="turn-1"), - _completed_notification(turn_id="turn-1"), - ] - ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + notifications: dict[str, deque[Notification]] = { + "turn-1": deque( + [ + _delta_notification(turn_id="turn-1", text="one"), + _completed_notification(turn_id="turn-1"), + ] + ), + "turn-2": deque( + [ + _delta_notification(turn_id="turn-2", text="two"), + _completed_notification(turn_id="turn-2"), + ] + ), + } + client.next_turn_notification = lambda turn_id: notifications[turn_id].popleft() # type: ignore[method-assign] first_stream = TurnHandle(client, "thread-1", "turn-1").stream() assert next(first_stream).method == "item/agentMessage/delta" second_stream = TurnHandle(client, "thread-1", "turn-2").stream() - with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"): - next(second_stream) + assert next(second_stream).method == "item/agentMessage/delta" + assert next(first_stream).method == "turn/completed" + assert next(second_stream).method == "turn/completed" first_stream.close() + second_stream.close() -def test_async_turn_stream_rejects_second_active_consumer() -> None: +def test_async_turn_streams_can_consume_multiple_turns_on_one_client() -> None: async def scenario() -> None: codex = AsyncCodex() async def fake_ensure_initialized() -> None: return None - notifications: deque[Notification] = deque( - [ - _delta_notification(turn_id="turn-1"), - _completed_notification(turn_id="turn-1"), - ] - ) + notifications: dict[str, deque[Notification]] = { + "turn-1": deque( + [ + _delta_notification(turn_id="turn-1", text="one"), + 
_completed_notification(turn_id="turn-1"), + ] + ), + "turn-2": deque( + [ + _delta_notification(turn_id="turn-2", text="two"), + _completed_notification(turn_id="turn-2"), + ] + ), + } - async def fake_next_notification() -> Notification: - return notifications.popleft() + async def fake_next_notification(turn_id: str) -> Notification: + return notifications[turn_id].popleft() codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] - codex._client.next_notification = fake_next_notification # type: ignore[method-assign] + codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign] first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream() assert (await anext(first_stream)).method == "item/agentMessage/delta" second_stream = AsyncTurnHandle(codex, "thread-1", "turn-2").stream() - with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"): - await anext(second_stream) + assert (await anext(second_stream)).method == "item/agentMessage/delta" + assert (await anext(first_stream)).method == "turn/completed" + assert (await anext(second_stream)).method == "turn/completed" await first_stream.aclose() + await second_stream.aclose() asyncio.run(scenario()) @@ -285,7 +305,7 @@ def test_turn_run_returns_completed_turn_payload() -> None: _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] result = TurnHandle(client, "thread-1", "turn-1").run() @@ -305,7 +325,7 @@ def test_thread_run_accepts_string_input_and_returns_run_result() -> None: _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] seen: dict[str, object] = {} def fake_turn_start(thread_id: str, 
wire_input: object, *, params=None): # noqa: ANN001,ANN202 @@ -338,7 +358,7 @@ def test_thread_run_uses_last_completed_assistant_message_as_final_response() -> _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 turn=SimpleNamespace(id="turn-1") ) @@ -363,7 +383,7 @@ def test_thread_run_preserves_empty_last_assistant_message() -> None: _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 turn=SimpleNamespace(id="turn-1") ) @@ -394,7 +414,7 @@ def test_thread_run_prefers_explicit_final_answer_over_later_commentary() -> Non _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 turn=SimpleNamespace(id="turn-1") ) @@ -420,7 +440,7 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 turn=SimpleNamespace(id="turn-1") ) @@ -438,7 +458,7 @@ def test_thread_run_raises_on_failed_turn() -> None: _completed_notification(status="failed", 
error_message="boom"), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 turn=SimpleNamespace(id="turn-1") ) @@ -447,6 +467,48 @@ def test_thread_run_raises_on_failed_turn() -> None: Thread(client, "thread-1").run("hello") +def test_stream_text_registers_and_consumes_turn_notifications() -> None: + client = AppServerClient() + notifications: deque[Notification] = deque( + [ + _delta_notification(text="first"), + _delta_notification(text="second"), + _completed_notification(), + ] + ) + calls: list[tuple[str, str]] = [] + client.turn_start = lambda thread_id, input_items, *, params=None: SimpleNamespace( # noqa: ARG005,E731 + turn=SimpleNamespace(id="turn-1") + ) + + def fake_register(turn_id: str) -> None: + calls.append(("register", turn_id)) + + def fake_next(turn_id: str) -> Notification: + calls.append(("next", turn_id)) + return notifications.popleft() + + def fake_unregister(turn_id: str) -> None: + calls.append(("unregister", turn_id)) + + client.register_turn_notifications = fake_register # type: ignore[method-assign] + client.next_turn_notification = fake_next # type: ignore[method-assign] + client.unregister_turn_notifications = fake_unregister # type: ignore[method-assign] + + chunks = list(client.stream_text("thread-1", "hello")) + + assert ([chunk.delta for chunk in chunks], calls) == ( + ["first", "second"], + [ + ("register", "turn-1"), + ("next", "turn-1"), + ("next", "turn-1"), + ("next", "turn-1"), + ("unregister", "turn-1"), + ], + ) + + def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None: async def scenario() -> None: codex = AsyncCodex() @@ -471,12 +533,12 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None: seen["params"] = params return 
SimpleNamespace(turn=SimpleNamespace(id="turn-1")) - async def fake_next_notification() -> Notification: + async def fake_next_notification(_turn_id: str) -> Notification: return notifications.popleft() codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] codex._client.turn_start = fake_turn_start # type: ignore[method-assign] - codex._client.next_notification = fake_next_notification # type: ignore[method-assign] + codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign] result = await AsyncThread(codex, "thread-1").run("hello") @@ -491,15 +553,21 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None: asyncio.run(scenario()) -def test_async_thread_run_uses_last_completed_assistant_message_as_final_response() -> None: +def test_async_thread_run_uses_last_completed_assistant_message_as_final_response() -> ( + None +): async def scenario() -> None: codex = AsyncCodex() async def fake_ensure_initialized() -> None: return None - first_item_notification = _item_completed_notification(text="First async message") - second_item_notification = _item_completed_notification(text="Second async message") + first_item_notification = _item_completed_notification( + text="First async message" + ) + second_item_notification = _item_completed_notification( + text="Second async message" + ) notifications: deque[Notification] = deque( [ first_item_notification, @@ -511,12 +579,12 @@ def test_async_thread_run_uses_last_completed_assistant_message_as_final_respons async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001 return SimpleNamespace(turn=SimpleNamespace(id="turn-1")) - async def fake_next_notification() -> Notification: + async def fake_next_notification(_turn_id: str) -> Notification: return notifications.popleft() codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] codex._client.turn_start = 
fake_turn_start # type: ignore[method-assign] - codex._client.next_notification = fake_next_notification # type: ignore[method-assign] + codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign] result = await AsyncThread(codex, "thread-1").run("hello") @@ -550,12 +618,12 @@ def test_async_thread_run_returns_none_when_only_commentary_messages_complete() async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001 return SimpleNamespace(turn=SimpleNamespace(id="turn-1")) - async def fake_next_notification() -> Notification: + async def fake_next_notification(_turn_id: str) -> Notification: return notifications.popleft() codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] codex._client.turn_start = fake_turn_start # type: ignore[method-assign] - codex._client.next_notification = fake_next_notification # type: ignore[method-assign] + codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign] result = await AsyncThread(codex, "thread-1").run("hello")