diff --git a/sdk/python/docs/api-reference.md b/sdk/python/docs/api-reference.md index ddeaf39cd0..c7a763498d 100644 --- a/sdk/python/docs/api-reference.md +++ b/sdk/python/docs/api-reference.md @@ -2,7 +2,7 @@ Public surface of `codex_app_server` for app-server v2. -This SDK surface is experimental. The current implementation intentionally allows only one active turn consumer (`Thread.run()`, `TurnHandle.stream()`, or `TurnHandle.run()`) per client instance at a time. +This SDK surface is experimental. Turn streams are routed by turn ID so one client can consume multiple active turns concurrently. ## Package Entry @@ -137,8 +137,8 @@ Use `turn(...)` when you need low-level turn control (`stream()`, `steer()`, Behavior notes: -- `stream()` and `run()` are exclusive per client instance in the current experimental build -- starting a second turn consumer on the same `Codex` instance raises `RuntimeError` +- `stream()` and `run()` consume only notifications for their own turn ID +- one `Codex` instance can stream multiple active turns concurrently ### AsyncTurnHandle @@ -149,8 +149,8 @@ Behavior notes: Behavior notes: -- `stream()` and `run()` are exclusive per client instance in the current experimental build -- starting a second turn consumer on the same `AsyncCodex` instance raises `RuntimeError` +- `stream()` and `run()` consume only notifications for their own turn ID +- one `AsyncCodex` instance can stream multiple active turns concurrently ## Inputs diff --git a/sdk/python/docs/getting-started.md b/sdk/python/docs/getting-started.md index 45ad1eb51f..1794d39f70 100644 --- a/sdk/python/docs/getting-started.md +++ b/sdk/python/docs/getting-started.md @@ -45,7 +45,7 @@ What happened: - `thread.run("...")` started a turn, consumed events until completion, and returned the final assistant response plus collected items and usage. - `result.final_response` is `None` when no final-answer or phase-less assistant message item completes for the turn. - use `thread.turn(...)` when you need a `TurnHandle` for streaming, steering, interrupting, or turn IDs/status -- one client can have only one active turn consumer (`thread.run(...)`, `TurnHandle.stream()`, or `TurnHandle.run()`) at a time in the current experimental build +- one client can consume multiple active turns concurrently; turn streams are routed by turn ID ## 3) Continue the same thread (multi-turn) diff --git a/sdk/python/src/codex_app_server/api.py b/sdk/python/src/codex_app_server/api.py index 2c71859cc8..31ec0685a7 100644 --- a/sdk/python/src/codex_app_server/api.py +++ b/sdk/python/src/codex_app_server/api.py @@ -38,14 +38,14 @@ from .generated.v2_all import ( ) from .models import InitializeResponse, JsonObject, Notification, ServerInfo from ._inputs import ( - ImageInput, + ImageInput, # noqa: F401 Input, - InputItem, - LocalImageInput, - MentionInput, + InputItem, # noqa: F401 + LocalImageInput, # noqa: F401 + MentionInput, # noqa: F401 RunInput, - SkillInput, - TextInput, + SkillInput, # noqa: F401 + TextInput, # noqa: F401 _normalize_run_input, _to_wire_input, ) @@ -274,6 +274,7 @@ class Codex: def thread_unarchive(self, thread_id: str) -> Thread: unarchived = self._client.thread_unarchive(thread_id) return Thread(self._client, unarchived.thread.id) + # END GENERATED: Codex.flat_methods def models(self, *, include_hidden: bool = False) -> ModelListResponse: @@ -476,6 +477,7 @@ class AsyncCodex: await self._ensure_initialized() unarchived = await self._client.thread_unarchive(thread_id) return AsyncThread(self, unarchived.thread.id) + # END GENERATED: AsyncCodex.flat_methods async def models(self, *, include_hidden: bool = False) -> ModelListResponse: @@ -555,6 +557,7 @@ class Thread: ) turn = self._client.turn_start(self.id, wire_input, params=params) return TurnHandle(self._client, self.id, turn.turn.id) + # END GENERATED: Thread.flat_methods def read(self, *, include_turns: bool = False) -> ThreadReadResponse: @@ -644,6 +647,7 @@ class AsyncThread: params=params, ) return AsyncTurnHandle(self._codex, self.id, turn.turn.id) + # END GENERATED: AsyncThread.flat_methods async def read(self, *, include_turns: bool = False) -> ThreadReadResponse: @@ -674,11 +678,10 @@ class TurnHandle: return self._client.turn_interrupt(self.thread_id, self.id) def stream(self) -> Iterator[Notification]: - # TODO: replace this client-wide experimental guard with per-turn event demux. - self._client.acquire_turn_consumer(self.id) + self._client.register_turn_notifications(self.id) try: while True: - event = self._client.next_notification() + event = self._client.next_turn_notification(self.id) yield event if ( event.method == "turn/completed" @@ -687,7 +690,7 @@ class TurnHandle: ): break finally: - self._client.release_turn_consumer(self.id) + self._client.unregister_turn_notifications(self.id) def run(self) -> AppServerTurn: completed: TurnCompletedNotification | None = None @@ -728,11 +731,10 @@ class AsyncTurnHandle: async def stream(self) -> AsyncIterator[Notification]: await self._codex._ensure_initialized() - # TODO: replace this client-wide experimental guard with per-turn event demux. - self._codex._client.acquire_turn_consumer(self.id) + self._codex._client.register_turn_notifications(self.id) try: while True: - event = await self._codex._client.next_notification() + event = await self._codex._client.next_turn_notification(self.id) yield event if ( event.method == "turn/completed" @@ -741,7 +743,7 @@ class AsyncTurnHandle: ): break finally: - self._codex._client.release_turn_consumer(self.id) + self._codex._client.unregister_turn_notifications(self.id) async def run(self) -> AppServerTurn: completed: TurnCompletedNotification | None = None diff --git a/sdk/python/src/codex_app_server/async_client.py b/sdk/python/src/codex_app_server/async_client.py index 6ca0c42a78..0fd289c95d 100644 --- a/sdk/python/src/codex_app_server/async_client.py +++ b/sdk/python/src/codex_app_server/async_client.py @@ -79,11 +79,11 @@ class AsyncAppServerClient: async def initialize(self) -> InitializeResponse: return await self._call_sync(self._sync.initialize) - def acquire_turn_consumer(self, turn_id: str) -> None: - self._sync.acquire_turn_consumer(turn_id) + def register_turn_notifications(self, turn_id: str) -> None: + self._sync.register_turn_notifications(turn_id) - def release_turn_consumer(self, turn_id: str) -> None: - self._sync.release_turn_consumer(turn_id) + def unregister_turn_notifications(self, turn_id: str) -> None: + self._sync.unregister_turn_notifications(turn_id) async def request( self, @@ -99,7 +99,9 @@ class AsyncAppServerClient: response_model=response_model, ) - async def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse: + async def thread_start( + self, params: V2ThreadStartParams | JsonObject | None = None + ) -> ThreadStartResponse: return await self._call_sync(self._sync.thread_start, params) async def thread_resume( @@ -109,10 +111,14 @@ class AsyncAppServerClient: ) -> ThreadResumeResponse: return await self._call_sync(self._sync.thread_resume, thread_id, params) - async def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse: + async def thread_list( + self, params: V2ThreadListParams | JsonObject | None = None + ) -> ThreadListResponse: return await self._call_sync(self._sync.thread_list, params) - async def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse: + async def thread_read( + self, thread_id: str, include_turns: bool = False + ) -> ThreadReadResponse: return await self._call_sync(self._sync.thread_read, thread_id, include_turns) async def thread_fork( @@ -140,9 +146,13 @@ class AsyncAppServerClient: input_items: list[JsonObject] | JsonObject | str, params: V2TurnStartParams | JsonObject | None = None, ) -> TurnStartResponse: - return await self._call_sync(self._sync.turn_start, thread_id, input_items, params) + return await self._call_sync( + self._sync.turn_start, thread_id, input_items, params + ) - async def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse: + async def turn_interrupt( + self, thread_id: str, turn_id: str + ) -> TurnInterruptResponse: return await self._call_sync(self._sync.turn_interrupt, thread_id, turn_id) async def turn_steer( @@ -184,10 +194,15 @@ class AsyncAppServerClient: async def next_notification(self) -> Notification: return await self._call_sync(self._sync.next_notification) + async def next_turn_notification(self, turn_id: str) -> Notification: + return await self._call_sync(self._sync.next_turn_notification, turn_id) + async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification: return await self._call_sync(self._sync.wait_for_turn_completed, turn_id) - async def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]: + async def stream_until_methods( + self, methods: Iterable[str] | str + ) -> list[Notification]: return await self._call_sync(self._sync.stream_until_methods, methods) async def stream_text( diff --git a/sdk/python/src/codex_app_server/client.py b/sdk/python/src/codex_app_server/client.py index 665e1c6725..9b46e6cd84 100644 --- a/sdk/python/src/codex_app_server/client.py +++ b/sdk/python/src/codex_app_server/client.py @@ -2,6 +2,7 @@ from __future__ import annotations import json import os +import queue import subprocess import threading import uuid @@ -49,6 +50,8 @@ from ._version import __version__ as SDK_VERSION ModelT = TypeVar("ModelT", bound=BaseModel) ApprovalHandler = Callable[[str, JsonObject | None], JsonObject] RUNTIME_PKG_NAME = "openai-codex-cli-bin" +ResponseQueueItem = JsonValue | BaseException +NotificationQueueItem = Notification | BaseException def _params_dict( @@ -75,7 +78,9 @@ def _params_dict( return dumped if isinstance(params, dict): return params - raise TypeError(f"Expected generated params model or dict, got {type(params).__name__}") + raise TypeError( + f"Expected generated params model or dict, got {type(params).__name__}" + ) def _installed_codex_path() -> Path: @@ -146,11 +151,14 @@ class AppServerClient: self._approval_handler = approval_handler or self._default_approval_handler self._proc: subprocess.Popen[str] | None = None self._lock = threading.Lock() - self._turn_consumer_lock = threading.Lock() - self._active_turn_consumer: str | None = None - self._pending_notifications: deque[Notification] = deque() + self._router_lock = threading.Lock() + self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {} + self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {} + self._pending_turn_notifications: dict[str, deque[Notification]] = {} + self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue() self._stderr_lines: deque[str] = deque(maxlen=400) self._stderr_thread: threading.Thread | None = None + self._reader_thread: threading.Thread | None = None def __enter__(self) -> "AppServerClient": self.start() @@ -189,13 +197,13 @@ class AppServerClient: ) self._start_stderr_drain_thread() + self._start_reader_thread() def close(self) -> None: if self._proc is None: return proc = self._proc self._proc = None - self._active_turn_consumer = None if proc.stdin: proc.stdin.close() @@ -207,6 +215,8 @@ class AppServerClient: if self._stderr_thread and self._stderr_thread.is_alive(): self._stderr_thread.join(timeout=0.5) + if self._reader_thread and self._reader_thread.is_alive(): + self._reader_thread.join(timeout=0.5) def initialize(self) -> InitializeResponse: result = self.request( @@ -240,70 +250,63 @@ class AppServerClient: def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue: request_id = str(uuid.uuid4()) - self._write_message({"id": request_id, "method": method, "params": params or {}}) + waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1) + with self._router_lock: + self._response_waiters[request_id] = waiter - while True: - msg = self._read_message() + try: + self._write_message( + {"id": request_id, "method": method, "params": params or {}} + ) + except BaseException: + with self._router_lock: + self._response_waiters.pop(request_id, None) + raise - if "method" in msg and "id" in msg: - response = self._handle_server_request(msg) - self._write_message({"id": msg["id"], "result": response}) - continue - - if "method" in msg and "id" not in msg: - self._pending_notifications.append( - self._coerce_notification(msg["method"], msg.get("params")) - ) - continue - - if msg.get("id") != request_id: - continue - - if "error" in msg: - err = msg["error"] - if isinstance(err, dict): - raise map_jsonrpc_error( - int(err.get("code", -32000)), - str(err.get("message", "unknown")), - err.get("data"), - ) - raise AppServerError("Malformed JSON-RPC error response") - - return msg.get("result") + item = waiter.get() + if isinstance(item, BaseException): + raise item + return item def notify(self, method: str, params: JsonObject | None = None) -> None: self._write_message({"method": method, "params": params or {}}) def next_notification(self) -> Notification: - if self._pending_notifications: - return self._pending_notifications.popleft() + item = self._global_notifications.get() + if isinstance(item, BaseException): + raise item + return item - while True: - msg = self._read_message() - if "method" in msg and "id" in msg: - response = self._handle_server_request(msg) - self._write_message({"id": msg["id"], "result": response}) - continue - if "method" in msg and "id" not in msg: - return self._coerce_notification(msg["method"], msg.get("params")) + def register_turn_notifications(self, turn_id: str) -> None: + turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue() + with self._router_lock: + if turn_id in self._turn_notifications: + return + pending = self._pending_turn_notifications.pop(turn_id, deque()) + self._turn_notifications[turn_id] = turn_queue + for notification in pending: + turn_queue.put(notification) - def acquire_turn_consumer(self, turn_id: str) -> None: - with self._turn_consumer_lock: - if self._active_turn_consumer is not None: - raise RuntimeError( - "Concurrent turn consumers are not yet supported in the experimental SDK. " - f"Client is already streaming turn {self._active_turn_consumer!r}; " - f"cannot start turn {turn_id!r} until the active consumer finishes." - ) - self._active_turn_consumer = turn_id + def unregister_turn_notifications(self, turn_id: str) -> None: + with self._router_lock: + self._turn_notifications.pop(turn_id, None) - def release_turn_consumer(self, turn_id: str) -> None: - with self._turn_consumer_lock: - if self._active_turn_consumer == turn_id: - self._active_turn_consumer = None + def next_turn_notification(self, turn_id: str) -> Notification: + with self._router_lock: + turn_queue = self._turn_notifications.get(turn_id) + if turn_queue is None: + raise RuntimeError(f"turn {turn_id!r} is not registered for streaming") + item = turn_queue.get() + if isinstance(item, BaseException): + raise item + return item - def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse: - return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse) + def thread_start( + self, params: V2ThreadStartParams | JsonObject | None = None + ) -> ThreadStartResponse: + return self.request( + "thread/start", _params_dict(params), response_model=ThreadStartResponse + ) def thread_resume( self, @@ -311,12 +314,20 @@ class AppServerClient: params: V2ThreadResumeParams | JsonObject | None = None, ) -> ThreadResumeResponse: payload = {"threadId": thread_id, **_params_dict(params)} - return self.request("thread/resume", payload, response_model=ThreadResumeResponse) + return self.request( + "thread/resume", payload, response_model=ThreadResumeResponse + ) - def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse: - return self.request("thread/list", _params_dict(params), response_model=ThreadListResponse) + def thread_list( + self, params: V2ThreadListParams | JsonObject | None = None + ) -> ThreadListResponse: + return self.request( + "thread/list", _params_dict(params), response_model=ThreadListResponse + ) - def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse: + def thread_read( + self, thread_id: str, include_turns: bool = False + ) -> ThreadReadResponse: return self.request( "thread/read", {"threadId": thread_id, "includeTurns": include_turns}, @@ -332,10 +343,18 @@ class AppServerClient: return self.request("thread/fork", payload, response_model=ThreadForkResponse) def thread_archive(self, thread_id: str) -> ThreadArchiveResponse: - return self.request("thread/archive", {"threadId": thread_id}, response_model=ThreadArchiveResponse) + return self.request( + "thread/archive", + {"threadId": thread_id}, + response_model=ThreadArchiveResponse, + ) def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse: - return self.request("thread/unarchive", {"threadId": thread_id}, response_model=ThreadUnarchiveResponse) + return self.request( + "thread/unarchive", + {"threadId": thread_id}, + response_model=ThreadUnarchiveResponse, + ) def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse: return self.request( @@ -362,7 +381,9 @@ class AppServerClient: "threadId": thread_id, "input": self._normalize_input_items(input_items), } - return self.request("turn/start", payload, response_model=TurnStartResponse) + started = self.request("turn/start", payload, response_model=TurnStartResponse) + self.register_turn_notifications(started.turn.id) + return started def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse: return self.request( @@ -412,14 +433,18 @@ class AppServerClient: ) def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification: - while True: - notification = self.next_notification() - if ( - notification.method == "turn/completed" - and isinstance(notification.payload, TurnCompletedNotification) - and notification.payload.turn.id == turn_id - ): - return notification.payload + self.register_turn_notifications(turn_id) + try: + while True: + notification = self.next_turn_notification(turn_id) + if ( + notification.method == "turn/completed" + and isinstance(notification.payload, TurnCompletedNotification) + and notification.payload.turn.id == turn_id + ): + return notification.payload + finally: + self.unregister_turn_notifications(turn_id) def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]: target_methods = {methods} if isinstance(methods, str) else set(methods) @@ -438,33 +463,40 @@ class AppServerClient: ) -> Iterator[AgentMessageDeltaNotification]: started = self.turn_start(thread_id, text, params=params) turn_id = started.turn.id - while True: - notification = self.next_notification() - if ( - notification.method == "item/agentMessage/delta" - and isinstance(notification.payload, AgentMessageDeltaNotification) - and notification.payload.turn_id == turn_id - ): - yield notification.payload - continue - if ( - notification.method == "turn/completed" - and isinstance(notification.payload, TurnCompletedNotification) - and notification.payload.turn.id == turn_id - ): - break + try: + while True: + notification = self.next_turn_notification(turn_id) + if ( + notification.method == "item/agentMessage/delta" + and isinstance(notification.payload, AgentMessageDeltaNotification) + and notification.payload.turn_id == turn_id + ): + yield notification.payload + continue + if ( + notification.method == "turn/completed" + and isinstance(notification.payload, TurnCompletedNotification) + and notification.payload.turn.id == turn_id + ): + break + finally: + self.unregister_turn_notifications(turn_id) def _coerce_notification(self, method: str, params: object) -> Notification: params_dict = params if isinstance(params, dict) else {} model = NOTIFICATION_MODELS.get(method) if model is None: - return Notification(method=method, payload=UnknownNotification(params=params_dict)) + return Notification( + method=method, payload=UnknownNotification(params=params_dict) + ) try: payload = model.model_validate(params_dict) except Exception: # noqa: BLE001 - return Notification(method=method, payload=UnknownNotification(params=params_dict)) + return Notification( + method=method, payload=UnknownNotification(params=params_dict) + ) return Notification(method=method, payload=payload) def _normalize_input_items( @@ -477,7 +509,9 @@ class AppServerClient: return [input_items] return input_items - def _default_approval_handler(self, method: str, params: JsonObject | None) -> JsonObject: + def _default_approval_handler( + self, method: str, params: JsonObject | None + ) -> JsonObject: if method == "item/commandExecution/requestApproval": return {"decision": "accept"} if method == "item/fileChange/requestApproval": @@ -498,9 +532,95 @@ class AppServerClient: self._stderr_thread = threading.Thread(target=_drain, daemon=True) self._stderr_thread.start() + def _start_reader_thread(self) -> None: + if self._proc is None or self._proc.stdout is None: + return + + self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True) + self._reader_thread.start() + + def _reader_loop(self) -> None: + try: + while True: + msg = self._read_message() + if "method" in msg and "id" in msg: + response = self._handle_server_request(msg) + self._write_message({"id": msg["id"], "result": response}) + continue + if "method" in msg and "id" not in msg: + method = msg["method"] + if isinstance(method, str): + self._route_notification( + self._coerce_notification(method, msg.get("params")) + ) + continue + self._route_response(msg) + except BaseException as exc: + self._fail_pending_queues(exc) + def _stderr_tail(self, limit: int = 40) -> str: return "\n".join(list(self._stderr_lines)[-limit:]) + def _route_response(self, msg: dict[str, JsonValue]) -> None: + request_id = msg.get("id") + with self._router_lock: + waiter = self._response_waiters.pop(str(request_id), None) + if waiter is None: + return + + if "error" in msg: + err = msg["error"] + if isinstance(err, dict): + waiter.put( + map_jsonrpc_error( + int(err.get("code", -32000)), + str(err.get("message", "unknown")), + err.get("data"), + ) + ) + else: + waiter.put(AppServerError("Malformed JSON-RPC error response")) + return + + waiter.put(msg.get("result")) + + def _route_notification(self, notification: Notification) -> None: + turn_id = self._notification_turn_id(notification) + if turn_id is None: + self._global_notifications.put(notification) + return + + with self._router_lock: + turn_queue = self._turn_notifications.get(turn_id) + if turn_queue is None: + self._pending_turn_notifications.setdefault(turn_id, deque()).append( + notification + ) + return + turn_queue.put(notification) + + def _notification_turn_id(self, notification: Notification) -> str | None: + payload = notification.payload + turn_id = getattr(payload, "turn_id", None) + if isinstance(turn_id, str): + return turn_id + turn = getattr(payload, "turn", None) + nested_turn_id = getattr(turn, "id", None) + if isinstance(nested_turn_id, str): + return nested_turn_id + return None + + def _fail_pending_queues(self, exc: BaseException) -> None: + with self._router_lock: + response_waiters = list(self._response_waiters.values()) + self._response_waiters.clear() + turn_queues = list(self._turn_notifications.values()) + for waiter in response_waiters: + waiter.put(exc) + for turn_queue in turn_queues: + turn_queue.put(exc) + self._global_notifications.put(exc) + def _handle_server_request(self, msg: dict[str, JsonValue]) -> JsonObject: method = msg["method"] params = msg.get("params") diff --git a/sdk/python/tests/test_client_rpc_methods.py b/sdk/python/tests/test_client_rpc_methods.py index f1dd353b1f..693cd4077a 100644 --- a/sdk/python/tests/test_client_rpc_methods.py +++ b/sdk/python/tests/test_client_rpc_methods.py @@ -128,3 +128,67 @@ def test_invalid_notification_payload_falls_back_to_unknown() -> None: assert event.method == "thread/tokenUsage/updated" assert isinstance(event.payload, UnknownNotification) + + +def test_turn_notification_router_demuxes_registered_turns() -> None: + client = AppServerClient() + client.register_turn_notifications("turn-1") + client.register_turn_notifications("turn-2") + + client._route_notification( + client._coerce_notification( + "item/agentMessage/delta", + { + "delta": "two", + "itemId": "item-2", + "threadId": "thread-1", + "turnId": "turn-2", + }, + ) + ) + client._route_notification( + client._coerce_notification( + "item/agentMessage/delta", + { + "delta": "one", + "itemId": "item-1", + "threadId": "thread-1", + "turnId": "turn-1", + }, + ) + ) + + first = client.next_turn_notification("turn-1") + second = client.next_turn_notification("turn-2") + + assert [ + (first.method, getattr(first.payload, "delta", None)), + (second.method, getattr(second.payload, "delta", None)), + ] == [ + ("item/agentMessage/delta", "one"), + ("item/agentMessage/delta", "two"), + ] + + +def test_turn_notification_router_buffers_events_before_registration() -> None: + client = AppServerClient() + client._route_notification( + client._coerce_notification( + "turn/completed", + { + "threadId": "thread-1", + "turn": {"id": "turn-1", "items": [], "status": "completed"}, + }, + ) + ) + + client.register_turn_notifications("turn-1") + event = client.next_turn_notification("turn-1") + + assert ( + event.method, + getattr(getattr(event.payload, "turn", None), "id", None), + ) == ( + "turn/completed", + "turn-1", + ) diff --git a/sdk/python/tests/test_public_api_runtime_behavior.py b/sdk/python/tests/test_public_api_runtime_behavior.py index 10865cf879..a73b906aac 100644 --- a/sdk/python/tests/test_public_api_runtime_behavior.py +++ b/sdk/python/tests/test_public_api_runtime_behavior.py @@ -226,54 +226,74 @@ def test_async_codex_initializes_only_once_under_concurrency() -> None: asyncio.run(scenario()) -def test_turn_stream_rejects_second_active_consumer() -> None: +def test_turn_streams_can_consume_multiple_turns_on_one_client() -> None: client = AppServerClient() - notifications: deque[Notification] = deque( - [ - _delta_notification(turn_id="turn-1"), - _completed_notification(turn_id="turn-1"), - ] - ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + notifications: dict[str, deque[Notification]] = { + "turn-1": deque( + [ + _delta_notification(turn_id="turn-1", text="one"), + _completed_notification(turn_id="turn-1"), + ] + ), + "turn-2": deque( + [ + _delta_notification(turn_id="turn-2", text="two"), + _completed_notification(turn_id="turn-2"), + ] + ), + } + client.next_turn_notification = lambda turn_id: notifications[turn_id].popleft() # type: ignore[method-assign] first_stream = TurnHandle(client, "thread-1", "turn-1").stream() assert next(first_stream).method == "item/agentMessage/delta" second_stream = TurnHandle(client, "thread-1", "turn-2").stream() - with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"): - next(second_stream) + assert next(second_stream).method == "item/agentMessage/delta" + assert next(first_stream).method == "turn/completed" + assert next(second_stream).method == "turn/completed" first_stream.close() + second_stream.close() -def test_async_turn_stream_rejects_second_active_consumer() -> None: +def test_async_turn_streams_can_consume_multiple_turns_on_one_client() -> None: async def scenario() -> None: codex = AsyncCodex() async def fake_ensure_initialized() -> None: return None - notifications: deque[Notification] = deque( - [ - _delta_notification(turn_id="turn-1"), - _completed_notification(turn_id="turn-1"), - ] - ) + notifications: dict[str, deque[Notification]] = { + "turn-1": deque( + [ + _delta_notification(turn_id="turn-1", text="one"), + _completed_notification(turn_id="turn-1"), + ] + ), + "turn-2": deque( + [ + _delta_notification(turn_id="turn-2", text="two"), + _completed_notification(turn_id="turn-2"), + ] + ), + } - async def fake_next_notification() -> Notification: - return notifications.popleft() + async def fake_next_notification(turn_id: str) -> Notification: + return notifications[turn_id].popleft() codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] - codex._client.next_notification = fake_next_notification # type: ignore[method-assign] + codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign] first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream() assert (await anext(first_stream)).method == "item/agentMessage/delta" second_stream = AsyncTurnHandle(codex, "thread-1", "turn-2").stream() - with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"): - await anext(second_stream) + assert (await anext(second_stream)).method == "item/agentMessage/delta" + assert (await anext(first_stream)).method == "turn/completed" + assert (await anext(second_stream)).method == "turn/completed" await first_stream.aclose() + await second_stream.aclose() asyncio.run(scenario()) @@ -285,7 +305,7 @@ def test_turn_run_returns_completed_turn_payload() -> None: _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] result = TurnHandle(client, "thread-1", "turn-1").run() @@ -305,7 +325,7 @@ def test_thread_run_accepts_string_input_and_returns_run_result() -> None: _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] seen: dict[str, object] = {} def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202 @@ -338,7 +358,7 @@ def test_thread_run_uses_last_completed_assistant_message_as_final_response() -> _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 turn=SimpleNamespace(id="turn-1") ) @@ -363,7 +383,7 @@ def test_thread_run_preserves_empty_last_assistant_message() -> None: _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 turn=SimpleNamespace(id="turn-1") ) @@ -394,7 +414,7 @@ def test_thread_run_prefers_explicit_final_answer_over_later_commentary() -> Non _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 turn=SimpleNamespace(id="turn-1") ) @@ -420,7 +440,7 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non _completed_notification(), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 turn=SimpleNamespace(id="turn-1") ) @@ -438,7 +458,7 @@ def test_thread_run_raises_on_failed_turn() -> None: _completed_notification(status="failed", error_message="boom"), ] ) - client.next_notification = notifications.popleft # type: ignore[method-assign] + client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign] client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 turn=SimpleNamespace(id="turn-1") ) @@ -471,12 +491,12 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None: seen["params"] = params return SimpleNamespace(turn=SimpleNamespace(id="turn-1")) - async def fake_next_notification() -> Notification: + async def fake_next_notification(_turn_id: str) -> Notification: return notifications.popleft() codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] codex._client.turn_start = fake_turn_start # type: ignore[method-assign] - codex._client.next_notification = fake_next_notification # type: ignore[method-assign] + codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign] result = await AsyncThread(codex, "thread-1").run("hello") @@ -491,15 +511,21 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None: asyncio.run(scenario()) -def test_async_thread_run_uses_last_completed_assistant_message_as_final_response() -> None: +def test_async_thread_run_uses_last_completed_assistant_message_as_final_response() -> ( + None +): async def scenario() -> None: codex = AsyncCodex() async def fake_ensure_initialized() -> None: return None - first_item_notification = _item_completed_notification(text="First async message") - second_item_notification = _item_completed_notification(text="Second async message") + first_item_notification = _item_completed_notification( + text="First async message" + ) + second_item_notification = _item_completed_notification( + text="Second async message" + ) notifications: deque[Notification] = deque( [ first_item_notification, @@ -511,12 +537,12 @@ def test_async_thread_run_uses_last_completed_assistant_message_as_final_respons async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001 return SimpleNamespace(turn=SimpleNamespace(id="turn-1")) - async def fake_next_notification() -> Notification: + async def fake_next_notification(_turn_id: str) -> Notification: return notifications.popleft() codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] codex._client.turn_start = fake_turn_start # type: ignore[method-assign] - codex._client.next_notification = fake_next_notification # type: ignore[method-assign] + codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign] result = await AsyncThread(codex, "thread-1").run("hello") @@ -550,12 +576,12 @@ def test_async_thread_run_returns_none_when_only_commentary_messages_complete() async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001 return SimpleNamespace(turn=SimpleNamespace(id="turn-1")) - async def fake_next_notification() -> Notification: + async def fake_next_notification(_turn_id: str) -> Notification: return notifications.popleft() codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] codex._client.turn_start = fake_turn_start # type: ignore[method-assign] - codex._client.next_notification = fake_next_notification # type: ignore[method-assign] + codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign] result = await AsyncThread(codex, "thread-1").run("hello")