mirror of
https://github.com/openai/codex.git
synced 2026-05-18 18:22:39 +00:00
Route Python SDK turn notifications by ID (#21778)
## Why The Python SDK previously protected the stdio transport with a single active turn-consumer guard. That avoided competing reads from stdout, but it also meant one `Codex`/`AsyncCodex` client could not stream multiple active turns at the same time. Notifications could also arrive before the caller received a `TurnHandle` and registered for streaming, so the SDK needed an explicit routing layer instead of letting individual API calls read directly from the shared transport. ## What Changed - Added a private `MessageRouter` that owns per-request response queues, per-turn notification queues, pending turn-notification replay, and global notification delivery behind a single stdout reader thread. - Generated typed notification routing metadata so turn IDs come from known payload shapes instead of router-side attribute guessing, with explicit fallback handling for unknown notification payloads. - Updated sync and async turn streaming so `TurnHandle.stream()`/`run()` and `stream_text()` consume only notifications for their own turn ID, while `AsyncAppServerClient` no longer serializes all transport calls behind one async lock. - Cleared pending turn-notification buffers when unregistered turns complete so never-consumed turn handles do not leave stale queues behind. - Removed the internal stream-until helper now that turn completion waiting can register directly with routed turn notifications. - Updated Python SDK docs and focused tests for concurrent transport calls, interleaved turn routing, buffered early notifications, unknown notification routing, async delegation, and routed turn completion behavior. ## Validation - `uv run --extra dev ruff format scripts/update_sdk_artifacts.py src/codex_app_server/_message_router.py src/codex_app_server/client.py src/codex_app_server/generated/notification_registry.py tests/test_client_rpc_methods.py tests/test_public_api_runtime_behavior.py tests/test_async_client_behavior.py` - `uv run --extra dev ruff check scripts/update_sdk_artifacts.py src/codex_app_server/_message_router.py src/codex_app_server/client.py src/codex_app_server/generated/notification_registry.py tests/test_client_rpc_methods.py tests/test_public_api_runtime_behavior.py tests/test_async_client_behavior.py` - `uv run --extra dev pytest tests/test_client_rpc_methods.py tests/test_public_api_runtime_behavior.py tests/test_async_client_behavior.py` - `git diff --check` --------- Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -8,11 +8,11 @@ import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Iterator, TypeVar
|
||||
from typing import Callable, Iterator, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .errors import AppServerError, TransportClosedError, map_jsonrpc_error
|
||||
from .errors import AppServerError, TransportClosedError
|
||||
from .generated.notification_registry import NOTIFICATION_MODELS
|
||||
from .generated.v2_all import (
|
||||
AgentMessageDeltaNotification,
|
||||
@@ -43,6 +43,7 @@ from .models import (
|
||||
Notification,
|
||||
UnknownNotification,
|
||||
)
|
||||
from ._message_router import MessageRouter
|
||||
from .retry import retry_on_overload
|
||||
from ._version import __version__ as SDK_VERSION
|
||||
|
||||
@@ -75,7 +76,9 @@ def _params_dict(
|
||||
return dumped
|
||||
if isinstance(params, dict):
|
||||
return params
|
||||
raise TypeError(f"Expected generated params model or dict, got {type(params).__name__}")
|
||||
raise TypeError(
|
||||
f"Expected generated params model or dict, got {type(params).__name__}"
|
||||
)
|
||||
|
||||
|
||||
def _installed_codex_path() -> Path:
|
||||
@@ -146,11 +149,10 @@ class AppServerClient:
|
||||
self._approval_handler = approval_handler or self._default_approval_handler
|
||||
self._proc: subprocess.Popen[str] | None = None
|
||||
self._lock = threading.Lock()
|
||||
self._turn_consumer_lock = threading.Lock()
|
||||
self._active_turn_consumer: str | None = None
|
||||
self._pending_notifications: deque[Notification] = deque()
|
||||
self._router = MessageRouter()
|
||||
self._stderr_lines: deque[str] = deque(maxlen=400)
|
||||
self._stderr_thread: threading.Thread | None = None
|
||||
self._reader_thread: threading.Thread | None = None
|
||||
|
||||
def __enter__(self) -> "AppServerClient":
|
||||
self.start()
|
||||
@@ -189,13 +191,13 @@ class AppServerClient:
|
||||
)
|
||||
|
||||
self._start_stderr_drain_thread()
|
||||
self._start_reader_thread()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._proc is None:
|
||||
return
|
||||
proc = self._proc
|
||||
self._proc = None
|
||||
self._active_turn_consumer = None
|
||||
|
||||
if proc.stdin:
|
||||
proc.stdin.close()
|
||||
@@ -207,6 +209,8 @@ class AppServerClient:
|
||||
|
||||
if self._stderr_thread and self._stderr_thread.is_alive():
|
||||
self._stderr_thread.join(timeout=0.5)
|
||||
if self._reader_thread and self._reader_thread.is_alive():
|
||||
self._reader_thread.join(timeout=0.5)
|
||||
|
||||
def initialize(self) -> InitializeResponse:
|
||||
result = self.request(
|
||||
@@ -240,70 +244,42 @@ class AppServerClient:
|
||||
|
||||
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
|
||||
request_id = str(uuid.uuid4())
|
||||
self._write_message({"id": request_id, "method": method, "params": params or {}})
|
||||
waiter = self._router.create_response_waiter(request_id)
|
||||
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
try:
|
||||
self._write_message(
|
||||
{"id": request_id, "method": method, "params": params or {}}
|
||||
)
|
||||
except BaseException:
|
||||
self._router.discard_response_waiter(request_id)
|
||||
raise
|
||||
|
||||
if "method" in msg and "id" in msg:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
continue
|
||||
|
||||
if "method" in msg and "id" not in msg:
|
||||
self._pending_notifications.append(
|
||||
self._coerce_notification(msg["method"], msg.get("params"))
|
||||
)
|
||||
continue
|
||||
|
||||
if msg.get("id") != request_id:
|
||||
continue
|
||||
|
||||
if "error" in msg:
|
||||
err = msg["error"]
|
||||
if isinstance(err, dict):
|
||||
raise map_jsonrpc_error(
|
||||
int(err.get("code", -32000)),
|
||||
str(err.get("message", "unknown")),
|
||||
err.get("data"),
|
||||
)
|
||||
raise AppServerError("Malformed JSON-RPC error response")
|
||||
|
||||
return msg.get("result")
|
||||
item = waiter.get()
|
||||
if isinstance(item, BaseException):
|
||||
raise item
|
||||
return item
|
||||
|
||||
def notify(self, method: str, params: JsonObject | None = None) -> None:
|
||||
self._write_message({"method": method, "params": params or {}})
|
||||
|
||||
def next_notification(self) -> Notification:
|
||||
if self._pending_notifications:
|
||||
return self._pending_notifications.popleft()
|
||||
return self._router.next_global_notification()
|
||||
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
if "method" in msg and "id" in msg:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
continue
|
||||
if "method" in msg and "id" not in msg:
|
||||
return self._coerce_notification(msg["method"], msg.get("params"))
|
||||
def register_turn_notifications(self, turn_id: str) -> None:
|
||||
self._router.register_turn(turn_id)
|
||||
|
||||
def acquire_turn_consumer(self, turn_id: str) -> None:
|
||||
with self._turn_consumer_lock:
|
||||
if self._active_turn_consumer is not None:
|
||||
raise RuntimeError(
|
||||
"Concurrent turn consumers are not yet supported in the experimental SDK. "
|
||||
f"Client is already streaming turn {self._active_turn_consumer!r}; "
|
||||
f"cannot start turn {turn_id!r} until the active consumer finishes."
|
||||
)
|
||||
self._active_turn_consumer = turn_id
|
||||
def unregister_turn_notifications(self, turn_id: str) -> None:
|
||||
self._router.unregister_turn(turn_id)
|
||||
|
||||
def release_turn_consumer(self, turn_id: str) -> None:
|
||||
with self._turn_consumer_lock:
|
||||
if self._active_turn_consumer == turn_id:
|
||||
self._active_turn_consumer = None
|
||||
def next_turn_notification(self, turn_id: str) -> Notification:
|
||||
return self._router.next_turn_notification(turn_id)
|
||||
|
||||
def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
|
||||
return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse)
|
||||
def thread_start(
|
||||
self, params: V2ThreadStartParams | JsonObject | None = None
|
||||
) -> ThreadStartResponse:
|
||||
return self.request(
|
||||
"thread/start", _params_dict(params), response_model=ThreadStartResponse
|
||||
)
|
||||
|
||||
def thread_resume(
|
||||
self,
|
||||
@@ -311,12 +287,20 @@ class AppServerClient:
|
||||
params: V2ThreadResumeParams | JsonObject | None = None,
|
||||
) -> ThreadResumeResponse:
|
||||
payload = {"threadId": thread_id, **_params_dict(params)}
|
||||
return self.request("thread/resume", payload, response_model=ThreadResumeResponse)
|
||||
return self.request(
|
||||
"thread/resume", payload, response_model=ThreadResumeResponse
|
||||
)
|
||||
|
||||
def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse:
|
||||
return self.request("thread/list", _params_dict(params), response_model=ThreadListResponse)
|
||||
def thread_list(
|
||||
self, params: V2ThreadListParams | JsonObject | None = None
|
||||
) -> ThreadListResponse:
|
||||
return self.request(
|
||||
"thread/list", _params_dict(params), response_model=ThreadListResponse
|
||||
)
|
||||
|
||||
def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
|
||||
def thread_read(
|
||||
self, thread_id: str, include_turns: bool = False
|
||||
) -> ThreadReadResponse:
|
||||
return self.request(
|
||||
"thread/read",
|
||||
{"threadId": thread_id, "includeTurns": include_turns},
|
||||
@@ -332,10 +316,18 @@ class AppServerClient:
|
||||
return self.request("thread/fork", payload, response_model=ThreadForkResponse)
|
||||
|
||||
def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:
|
||||
return self.request("thread/archive", {"threadId": thread_id}, response_model=ThreadArchiveResponse)
|
||||
return self.request(
|
||||
"thread/archive",
|
||||
{"threadId": thread_id},
|
||||
response_model=ThreadArchiveResponse,
|
||||
)
|
||||
|
||||
def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse:
|
||||
return self.request("thread/unarchive", {"threadId": thread_id}, response_model=ThreadUnarchiveResponse)
|
||||
return self.request(
|
||||
"thread/unarchive",
|
||||
{"threadId": thread_id},
|
||||
response_model=ThreadUnarchiveResponse,
|
||||
)
|
||||
|
||||
def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse:
|
||||
return self.request(
|
||||
@@ -362,7 +354,9 @@ class AppServerClient:
|
||||
"threadId": thread_id,
|
||||
"input": self._normalize_input_items(input_items),
|
||||
}
|
||||
return self.request("turn/start", payload, response_model=TurnStartResponse)
|
||||
started = self.request("turn/start", payload, response_model=TurnStartResponse)
|
||||
self.register_turn_notifications(started.turn.id)
|
||||
return started
|
||||
|
||||
def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
|
||||
return self.request(
|
||||
@@ -412,23 +406,18 @@ class AppServerClient:
|
||||
)
|
||||
|
||||
def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
||||
while True:
|
||||
notification = self.next_notification()
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
return notification.payload
|
||||
|
||||
def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]:
|
||||
target_methods = {methods} if isinstance(methods, str) else set(methods)
|
||||
out: list[Notification] = []
|
||||
while True:
|
||||
notification = self.next_notification()
|
||||
out.append(notification)
|
||||
if notification.method in target_methods:
|
||||
return out
|
||||
self.register_turn_notifications(turn_id)
|
||||
try:
|
||||
while True:
|
||||
notification = self.next_turn_notification(turn_id)
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
return notification.payload
|
||||
finally:
|
||||
self.unregister_turn_notifications(turn_id)
|
||||
|
||||
def stream_text(
|
||||
self,
|
||||
@@ -438,33 +427,41 @@ class AppServerClient:
|
||||
) -> Iterator[AgentMessageDeltaNotification]:
|
||||
started = self.turn_start(thread_id, text, params=params)
|
||||
turn_id = started.turn.id
|
||||
while True:
|
||||
notification = self.next_notification()
|
||||
if (
|
||||
notification.method == "item/agentMessage/delta"
|
||||
and isinstance(notification.payload, AgentMessageDeltaNotification)
|
||||
and notification.payload.turn_id == turn_id
|
||||
):
|
||||
yield notification.payload
|
||||
continue
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
break
|
||||
self.register_turn_notifications(turn_id)
|
||||
try:
|
||||
while True:
|
||||
notification = self.next_turn_notification(turn_id)
|
||||
if (
|
||||
notification.method == "item/agentMessage/delta"
|
||||
and isinstance(notification.payload, AgentMessageDeltaNotification)
|
||||
and notification.payload.turn_id == turn_id
|
||||
):
|
||||
yield notification.payload
|
||||
continue
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
break
|
||||
finally:
|
||||
self.unregister_turn_notifications(turn_id)
|
||||
|
||||
def _coerce_notification(self, method: str, params: object) -> Notification:
|
||||
params_dict = params if isinstance(params, dict) else {}
|
||||
|
||||
model = NOTIFICATION_MODELS.get(method)
|
||||
if model is None:
|
||||
return Notification(method=method, payload=UnknownNotification(params=params_dict))
|
||||
return Notification(
|
||||
method=method, payload=UnknownNotification(params=params_dict)
|
||||
)
|
||||
|
||||
try:
|
||||
payload = model.model_validate(params_dict)
|
||||
except Exception: # noqa: BLE001
|
||||
return Notification(method=method, payload=UnknownNotification(params=params_dict))
|
||||
return Notification(
|
||||
method=method, payload=UnknownNotification(params=params_dict)
|
||||
)
|
||||
return Notification(method=method, payload=payload)
|
||||
|
||||
def _normalize_input_items(
|
||||
@@ -477,7 +474,9 @@ class AppServerClient:
|
||||
return [input_items]
|
||||
return input_items
|
||||
|
||||
def _default_approval_handler(self, method: str, params: JsonObject | None) -> JsonObject:
|
||||
def _default_approval_handler(
|
||||
self, method: str, params: JsonObject | None
|
||||
) -> JsonObject:
|
||||
if method == "item/commandExecution/requestApproval":
|
||||
return {"decision": "accept"}
|
||||
if method == "item/fileChange/requestApproval":
|
||||
@@ -498,6 +497,32 @@ class AppServerClient:
|
||||
self._stderr_thread = threading.Thread(target=_drain, daemon=True)
|
||||
self._stderr_thread.start()
|
||||
|
||||
def _start_reader_thread(self) -> None:
|
||||
if self._proc is None or self._proc.stdout is None:
|
||||
return
|
||||
|
||||
self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True)
|
||||
self._reader_thread.start()
|
||||
|
||||
def _reader_loop(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
if "method" in msg and "id" in msg:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
continue
|
||||
if "method" in msg and "id" not in msg:
|
||||
method = msg["method"]
|
||||
if isinstance(method, str):
|
||||
self._router.route_notification(
|
||||
self._coerce_notification(method, msg.get("params"))
|
||||
)
|
||||
continue
|
||||
self._router.route_response(msg)
|
||||
except BaseException as exc:
|
||||
self._router.fail_all(exc)
|
||||
|
||||
def _stderr_tail(self, limit: int = 40) -> str:
|
||||
return "\n".join(list(self._stderr_lines)[-limit:])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user