Compare commits

...

1 Commits

Author SHA1 Message Date
Shaqayeq
083243dca1 Support concurrent Python SDK turns across threads 2026-03-19 16:16:02 -07:00
8 changed files with 584 additions and 189 deletions

View File

@@ -2,7 +2,7 @@
Public surface of `codex_app_server` for app-server v2. Public surface of `codex_app_server` for app-server v2.
This SDK surface is experimental. The current implementation intentionally allows only one active turn consumer (`Thread.run()`, `TurnHandle.stream()`, or `TurnHandle.run()`) per client instance at a time. This SDK surface is experimental. The current implementation allows concurrent turn consumers on one client only when they belong to different thread IDs. Each client still supports only one active turn per thread ID at a time.
## Package Entry ## Package Entry
@@ -137,8 +137,9 @@ Use `turn(...)` when you need low-level turn control (`stream()`, `steer()`,
Behavior notes: Behavior notes:
- `stream()` and `run()` are exclusive per client instance in the current experimental build - `stream()` and `run()` may run concurrently on one client when the turns belong to different thread IDs
- starting a second turn consumer on the same `Codex` instance raises `RuntimeError` - starting a second turn on the same thread raises `RuntimeError`; use `steer()` or `interrupt()` on the existing handle instead
- low-level global notification APIs such as `next_notification()` are incompatible with active turn streaming on the same client
### AsyncTurnHandle ### AsyncTurnHandle
@@ -149,8 +150,9 @@ Behavior notes:
Behavior notes: Behavior notes:
- `stream()` and `run()` are exclusive per client instance in the current experimental build - `stream()` and `run()` may run concurrently on one client when the turns belong to different thread IDs
- starting a second turn consumer on the same `AsyncCodex` instance raises `RuntimeError` - starting a second turn on the same thread raises `RuntimeError`; use `steer()` or `interrupt()` on the existing handle instead
- low-level global notification APIs such as `next_notification()` are incompatible with active turn streaming on the same client
## Inputs ## Inputs

View File

@@ -43,7 +43,8 @@ What happened:
- `thread.run("...")` started a turn, consumed events until completion, and returned the final assistant response plus collected items and usage. - `thread.run("...")` started a turn, consumed events until completion, and returned the final assistant response plus collected items and usage.
- `result.final_response` is `None` when no final-answer or phase-less assistant message item completes for the turn. - `result.final_response` is `None` when no final-answer or phase-less assistant message item completes for the turn.
- use `thread.turn(...)` when you need a `TurnHandle` for streaming, steering, interrupting, or turn IDs/status - use `thread.turn(...)` when you need a `TurnHandle` for streaming, steering, interrupting, or turn IDs/status
- one client can have only one active turn consumer (`thread.run(...)`, `TurnHandle.stream()`, or `TurnHandle.run()`) at a time in the current experimental build - one client can run turns concurrently across different thread IDs in the current experimental build
- one thread can have only one active turn at a time on a given client; start a second same-thread turn only after the first completes, or use `steer()` on the existing `TurnHandle`
## 3) Continue the same thread (multi-turn) ## 3) Continue the same thread (multi-turn)

View File

@@ -653,11 +653,10 @@ class TurnHandle:
return self._client.turn_interrupt(self.thread_id, self.id) return self._client.turn_interrupt(self.thread_id, self.id)
def stream(self) -> Iterator[Notification]: def stream(self) -> Iterator[Notification]:
# TODO: replace this client-wide experimental guard with per-turn event demux. self._client.acquire_turn_consumer(self.thread_id, self.id)
self._client.acquire_turn_consumer(self.id)
try: try:
while True: while True:
event = self._client.next_notification() event = self._client.next_turn_notification(self.thread_id, self.id)
yield event yield event
if ( if (
event.method == "turn/completed" event.method == "turn/completed"
@@ -666,7 +665,7 @@ class TurnHandle:
): ):
break break
finally: finally:
self._client.release_turn_consumer(self.id) self._client.release_turn_consumer(self.thread_id, self.id)
def run(self) -> AppServerTurn: def run(self) -> AppServerTurn:
completed: TurnCompletedNotification | None = None completed: TurnCompletedNotification | None = None
@@ -704,11 +703,10 @@ class AsyncTurnHandle:
async def stream(self) -> AsyncIterator[Notification]: async def stream(self) -> AsyncIterator[Notification]:
await self._codex._ensure_initialized() await self._codex._ensure_initialized()
# TODO: replace this client-wide experimental guard with per-turn event demux. self._codex._client.acquire_turn_consumer(self.thread_id, self.id)
self._codex._client.acquire_turn_consumer(self.id)
try: try:
while True: while True:
event = await self._codex._client.next_notification() event = await self._codex._client.next_turn_notification(self.thread_id, self.id)
yield event yield event
if ( if (
event.method == "turn/completed" event.method == "turn/completed"
@@ -717,7 +715,7 @@ class AsyncTurnHandle:
): ):
break break
finally: finally:
self._codex._client.release_turn_consumer(self.id) self._codex._client.release_turn_consumer(self.thread_id, self.id)
async def run(self) -> AppServerTurn: async def run(self) -> AppServerTurn:
completed: TurnCompletedNotification | None = None completed: TurnCompletedNotification | None = None

View File

@@ -41,8 +41,6 @@ class AsyncAppServerClient:
def __init__(self, config: AppServerConfig | None = None) -> None: def __init__(self, config: AppServerConfig | None = None) -> None:
self._sync = AppServerClient(config=config) self._sync = AppServerClient(config=config)
# Single stdio transport cannot be read safely from multiple threads.
self._transport_lock = asyncio.Lock()
async def __aenter__(self) -> "AsyncAppServerClient": async def __aenter__(self) -> "AsyncAppServerClient":
await self.start() await self.start()
@@ -58,8 +56,7 @@ class AsyncAppServerClient:
*args: ParamsT.args, *args: ParamsT.args,
**kwargs: ParamsT.kwargs, **kwargs: ParamsT.kwargs,
) -> ReturnT: ) -> ReturnT:
async with self._transport_lock: return await asyncio.to_thread(fn, *args, **kwargs)
return await asyncio.to_thread(fn, *args, **kwargs)
@staticmethod @staticmethod
def _next_from_iterator( def _next_from_iterator(
@@ -79,11 +76,11 @@ class AsyncAppServerClient:
async def initialize(self) -> InitializeResponse: async def initialize(self) -> InitializeResponse:
return await self._call_sync(self._sync.initialize) return await self._call_sync(self._sync.initialize)
def acquire_turn_consumer(self, turn_id: str) -> None: def acquire_turn_consumer(self, thread_id: str, turn_id: str) -> None:
self._sync.acquire_turn_consumer(turn_id) self._sync.acquire_turn_consumer(thread_id, turn_id)
def release_turn_consumer(self, turn_id: str) -> None: def release_turn_consumer(self, thread_id: str, turn_id: str) -> None:
self._sync.release_turn_consumer(turn_id) self._sync.release_turn_consumer(thread_id, turn_id)
async def request( async def request(
self, self,
@@ -184,6 +181,9 @@ class AsyncAppServerClient:
async def next_notification(self) -> Notification: async def next_notification(self) -> Notification:
return await self._call_sync(self._sync.next_notification) return await self._call_sync(self._sync.next_notification)
async def next_turn_notification(self, thread_id: str, turn_id: str) -> Notification:
return await self._call_sync(self._sync.next_turn_notification, thread_id, turn_id)
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification: async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id) return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
@@ -196,13 +196,12 @@ class AsyncAppServerClient:
text: str, text: str,
params: V2TurnStartParams | JsonObject | None = None, params: V2TurnStartParams | JsonObject | None = None,
) -> AsyncIterator[AgentMessageDeltaNotification]: ) -> AsyncIterator[AgentMessageDeltaNotification]:
async with self._transport_lock: iterator = self._sync.stream_text(thread_id, text, params)
iterator = self._sync.stream_text(thread_id, text, params) while True:
while True: has_value, chunk = await asyncio.to_thread(
has_value, chunk = await asyncio.to_thread( self._next_from_iterator,
self._next_from_iterator, iterator,
iterator, )
) if not has_value:
if not has_value: break
break yield chunk
yield chunk

View File

@@ -6,7 +6,7 @@ import subprocess
import threading import threading
import uuid import uuid
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Callable, Iterable, Iterator, TypeVar from typing import Callable, Iterable, Iterator, TypeVar
@@ -48,6 +48,58 @@ from .retry import retry_on_overload
ModelT = TypeVar("ModelT", bound=BaseModel) ModelT = TypeVar("ModelT", bound=BaseModel)
ApprovalHandler = Callable[[str, JsonObject | None], JsonObject] ApprovalHandler = Callable[[str, JsonObject | None], JsonObject]
RUNTIME_PKG_NAME = "codex-cli-bin" RUNTIME_PKG_NAME = "codex-cli-bin"
GLOBAL_NOTIFICATION_BACKLOG_LIMIT = 512
@dataclass(slots=True)
class _PendingRequest:
event: threading.Event = field(default_factory=threading.Event)
result: JsonValue | None = None
error: BaseException | None = None
class _BufferedNotificationStream:
def __init__(self, *, maxlen: int | None = None) -> None:
self._condition = threading.Condition()
self._items: deque[Notification] = (
deque(maxlen=maxlen) if maxlen is not None else deque()
)
self._closed = False
self._error: BaseException | None = None
def push(self, notification: Notification) -> None:
with self._condition:
if self._closed:
return
self._items.append(notification)
self._condition.notify_all()
def pop(self) -> Notification:
with self._condition:
while not self._items and not self._closed:
self._condition.wait()
if self._items:
return self._items.popleft()
if self._error is not None:
raise self._error
raise TransportClosedError("notification stream is closed")
def close(self, error: BaseException | None = None) -> None:
with self._condition:
self._closed = True
self._error = error
self._condition.notify_all()
def is_closed(self) -> bool:
with self._condition:
return self._closed
def is_drained(self) -> bool:
with self._condition:
return self._closed and not self._items
def _params_dict( def _params_dict(
@@ -144,12 +196,21 @@ class AppServerClient:
self.config = config or AppServerConfig() self.config = config or AppServerConfig()
self._approval_handler = approval_handler or self._default_approval_handler self._approval_handler = approval_handler or self._default_approval_handler
self._proc: subprocess.Popen[str] | None = None self._proc: subprocess.Popen[str] | None = None
self._lock = threading.Lock() self._write_lock = threading.Lock()
self._turn_consumer_lock = threading.Lock() self._state_lock = threading.Lock()
self._active_turn_consumer: str | None = None self._pending_notifications = _BufferedNotificationStream(
self._pending_notifications: deque[Notification] = deque() maxlen=GLOBAL_NOTIFICATION_BACKLOG_LIMIT
)
self._pending_requests: dict[str, _PendingRequest] = {}
self._turn_streams: dict[tuple[str, str], _BufferedNotificationStream] = {}
self._turn_starting_by_thread_id: set[str] = set()
self._active_turn_by_thread_id: dict[str, str] = {}
self._active_turn_consumers: set[tuple[str, str]] = set()
self._active_turn_stream_count = 0
self._transport_error: BaseException | None = None
self._stderr_lines: deque[str] = deque(maxlen=400) self._stderr_lines: deque[str] = deque(maxlen=400)
self._stderr_thread: threading.Thread | None = None self._stderr_thread: threading.Thread | None = None
self._reader_thread: threading.Thread | None = None
def __enter__(self) -> "AppServerClient": def __enter__(self) -> "AppServerClient":
self.start() self.start()
@@ -161,6 +222,7 @@ class AppServerClient:
def start(self) -> None: def start(self) -> None:
if self._proc is not None: if self._proc is not None:
return return
self._reset_transport_state()
if self.config.launch_args_override is not None: if self.config.launch_args_override is not None:
args = list(self.config.launch_args_override) args = list(self.config.launch_args_override)
@@ -187,13 +249,14 @@ class AppServerClient:
) )
self._start_stderr_drain_thread() self._start_stderr_drain_thread()
self._start_reader_thread()
def close(self) -> None: def close(self) -> None:
if self._proc is None: if self._proc is None:
return return
proc = self._proc proc = self._proc
self._proc = None self._proc = None
self._active_turn_consumer = None self._finish_transport(TransportClosedError("app-server closed"))
if proc.stdin: if proc.stdin:
proc.stdin.close() proc.stdin.close()
@@ -205,6 +268,8 @@ class AppServerClient:
if self._stderr_thread and self._stderr_thread.is_alive(): if self._stderr_thread and self._stderr_thread.is_alive():
self._stderr_thread.join(timeout=0.5) self._stderr_thread.join(timeout=0.5)
if self._reader_thread and self._reader_thread.is_alive():
self._reader_thread.join(timeout=0.5)
def initialize(self) -> InitializeResponse: def initialize(self) -> InitializeResponse:
result = self.request( result = self.request(
@@ -238,67 +303,76 @@ class AppServerClient:
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue: def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
self._write_message({"id": request_id, "method": method, "params": params or {}}) waiter = _PendingRequest()
with self._state_lock:
if self._transport_error is not None:
raise self._transport_error
self._pending_requests[request_id] = waiter
while True: try:
msg = self._read_message() self._write_message({"id": request_id, "method": method, "params": params or {}})
except BaseException:
with self._state_lock:
self._pending_requests.pop(request_id, None)
raise
if "method" in msg and "id" in msg: waiter.event.wait()
response = self._handle_server_request(msg) if waiter.error is not None:
self._write_message({"id": msg["id"], "result": response}) raise waiter.error
continue return waiter.result
if "method" in msg and "id" not in msg:
self._pending_notifications.append(
self._coerce_notification(msg["method"], msg.get("params"))
)
continue
if msg.get("id") != request_id:
continue
if "error" in msg:
err = msg["error"]
if isinstance(err, dict):
raise map_jsonrpc_error(
int(err.get("code", -32000)),
str(err.get("message", "unknown")),
err.get("data"),
)
raise AppServerError("Malformed JSON-RPC error response")
return msg.get("result")
def notify(self, method: str, params: JsonObject | None = None) -> None: def notify(self, method: str, params: JsonObject | None = None) -> None:
self._write_message({"method": method, "params": params or {}}) self._write_message({"method": method, "params": params or {}})
def next_notification(self) -> Notification: def next_notification(self) -> Notification:
if self._pending_notifications: with self._state_lock:
return self._pending_notifications.popleft() if self._active_turn_stream_count > 0:
while True:
msg = self._read_message()
if "method" in msg and "id" in msg:
response = self._handle_server_request(msg)
self._write_message({"id": msg["id"], "result": response})
continue
if "method" in msg and "id" not in msg:
return self._coerce_notification(msg["method"], msg.get("params"))
def acquire_turn_consumer(self, turn_id: str) -> None:
with self._turn_consumer_lock:
if self._active_turn_consumer is not None:
raise RuntimeError( raise RuntimeError(
"Concurrent turn consumers are not yet supported in the experimental SDK. " "next_notification() is incompatible with active turn streaming on the same "
f"Client is already streaming turn {self._active_turn_consumer!r}; " "client. Consume notifications from TurnHandle.stream()/run() instead."
f"cannot start turn {turn_id!r} until the active consumer finishes."
) )
self._active_turn_consumer = turn_id return self._pending_notifications.pop()
def release_turn_consumer(self, turn_id: str) -> None: def acquire_turn_consumer(self, thread_id: str, turn_id: str) -> None:
with self._turn_consumer_lock: turn_key = (thread_id, turn_id)
if self._active_turn_consumer == turn_id: with self._state_lock:
self._active_turn_consumer = None if turn_key in self._active_turn_consumers:
raise RuntimeError(
f"Turn {turn_id!r} is already being streamed on thread {thread_id!r}."
)
self._active_turn_consumers.add(turn_key)
self._active_turn_stream_count += 1
self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
def release_turn_consumer(self, thread_id: str, turn_id: str) -> None:
turn_key = (thread_id, turn_id)
with self._state_lock:
if turn_key in self._active_turn_consumers:
self._active_turn_consumers.remove(turn_key)
self._active_turn_stream_count -= 1
stream = self._turn_streams.get(turn_key)
if stream is not None and stream.is_drained():
self._turn_streams.pop(turn_key, None)
def next_turn_notification(self, thread_id: str, turn_id: str) -> Notification:
turn_key = (thread_id, turn_id)
with self._state_lock:
stream = self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
return stream.pop()
def assert_can_start_turn(self, thread_id: str) -> None:
with self._state_lock:
if thread_id in self._turn_starting_by_thread_id:
raise RuntimeError(
f"Thread {thread_id!r} is already starting a turn on this client."
)
active_turn_id = self._active_turn_by_thread_id.get(thread_id)
if active_turn_id is not None:
raise RuntimeError(
f"Thread {thread_id!r} already has active turn {active_turn_id!r}. "
"Use TurnHandle.steer() or TurnHandle.interrupt() instead of starting "
"another turn on the same thread."
)
def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse: def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse) return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse)
@@ -355,12 +429,19 @@ class AppServerClient:
input_items: list[JsonObject] | JsonObject | str, input_items: list[JsonObject] | JsonObject | str,
params: V2TurnStartParams | JsonObject | None = None, params: V2TurnStartParams | JsonObject | None = None,
) -> TurnStartResponse: ) -> TurnStartResponse:
self._begin_turn_start(thread_id)
payload = { payload = {
**_params_dict(params), **_params_dict(params),
"threadId": thread_id, "threadId": thread_id,
"input": self._normalize_input_items(input_items), "input": self._normalize_input_items(input_items),
} }
return self.request("turn/start", payload, response_model=TurnStartResponse) try:
started = self.request("turn/start", payload, response_model=TurnStartResponse)
except BaseException:
self._cancel_turn_start(thread_id)
raise
self._finish_turn_start(thread_id, started.turn.id)
return started
def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse: def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
return self.request( return self.request(
@@ -436,21 +517,25 @@ class AppServerClient:
) -> Iterator[AgentMessageDeltaNotification]: ) -> Iterator[AgentMessageDeltaNotification]:
started = self.turn_start(thread_id, text, params=params) started = self.turn_start(thread_id, text, params=params)
turn_id = started.turn.id turn_id = started.turn.id
while True: self.acquire_turn_consumer(thread_id, turn_id)
notification = self.next_notification() try:
if ( while True:
notification.method == "item/agentMessage/delta" notification = self.next_turn_notification(thread_id, turn_id)
and isinstance(notification.payload, AgentMessageDeltaNotification) if (
and notification.payload.turn_id == turn_id notification.method == "item/agentMessage/delta"
): and isinstance(notification.payload, AgentMessageDeltaNotification)
yield notification.payload and notification.payload.turn_id == turn_id
continue ):
if ( yield notification.payload
notification.method == "turn/completed" continue
and isinstance(notification.payload, TurnCompletedNotification) if (
and notification.payload.turn.id == turn_id notification.method == "turn/completed"
): and isinstance(notification.payload, TurnCompletedNotification)
break and notification.payload.turn.id == turn_id
):
break
finally:
self.release_turn_consumer(thread_id, turn_id)
def _coerce_notification(self, method: str, params: object) -> Notification: def _coerce_notification(self, method: str, params: object) -> Notification:
params_dict = params if isinstance(params, dict) else {} params_dict = params if isinstance(params, dict) else {}
@@ -512,7 +597,7 @@ class AppServerClient:
def _write_message(self, payload: JsonObject) -> None: def _write_message(self, payload: JsonObject) -> None:
if self._proc is None or self._proc.stdin is None: if self._proc is None or self._proc.stdin is None:
raise TransportClosedError("app-server is not running") raise TransportClosedError("app-server is not running")
with self._lock: with self._write_lock:
self._proc.stdin.write(json.dumps(payload) + "\n") self._proc.stdin.write(json.dumps(payload) + "\n")
self._proc.stdin.flush() self._proc.stdin.flush()
@@ -535,6 +620,162 @@ class AppServerClient:
raise AppServerError(f"Invalid JSON-RPC payload: {message!r}") raise AppServerError(f"Invalid JSON-RPC payload: {message!r}")
return message return message
def _reset_transport_state(self) -> None:
self._pending_notifications = _BufferedNotificationStream(
maxlen=GLOBAL_NOTIFICATION_BACKLOG_LIMIT
)
self._pending_requests = {}
self._turn_streams = {}
self._turn_starting_by_thread_id = set()
self._active_turn_by_thread_id = {}
self._active_turn_consumers = set()
self._active_turn_stream_count = 0
self._transport_error = None
def _start_reader_thread(self) -> None:
def _reader() -> None:
try:
while True:
msg = self._read_message()
if "method" in msg and "id" in msg:
self._start_server_request_worker(msg)
continue
if "method" in msg and "id" not in msg:
method = msg["method"]
if isinstance(method, str):
self._dispatch_notification(
self._coerce_notification(method, msg.get("params"))
)
continue
self._handle_response_message(msg)
except BaseException as exc: # noqa: BLE001
self._finish_transport(exc)
self._reader_thread = threading.Thread(target=_reader, daemon=True)
self._reader_thread.start()
def _start_server_request_worker(self, msg: dict[str, JsonValue]) -> None:
def _resolve() -> None:
try:
response = self._handle_server_request(msg)
self._write_message({"id": msg["id"], "result": response})
except BaseException:
return
threading.Thread(target=_resolve, daemon=True).start()
def _handle_response_message(self, msg: dict[str, JsonValue]) -> None:
request_id = msg.get("id")
if not isinstance(request_id, str):
return
with self._state_lock:
waiter = self._pending_requests.pop(request_id, None)
if waiter is None:
return
if "error" in msg:
err = msg["error"]
if isinstance(err, dict):
waiter.error = map_jsonrpc_error(
int(err.get("code", -32000)),
str(err.get("message", "unknown")),
err.get("data"),
)
else:
waiter.error = AppServerError("Malformed JSON-RPC error response")
else:
waiter.result = msg.get("result")
waiter.event.set()
def _dispatch_notification(self, notification: Notification) -> None:
self._pending_notifications.push(notification)
turn_key = self._turn_key_for_notification(notification)
if turn_key is None:
return
thread_id, turn_id = turn_key
close_stream = False
with self._state_lock:
stream = self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
if notification.method == "turn/started":
self._turn_starting_by_thread_id.discard(thread_id)
self._active_turn_by_thread_id[thread_id] = turn_id
elif notification.method == "turn/completed":
self._turn_starting_by_thread_id.discard(thread_id)
if self._active_turn_by_thread_id.get(thread_id) == turn_id:
self._active_turn_by_thread_id.pop(thread_id, None)
close_stream = True
stream.push(notification)
if close_stream:
stream.close()
def _turn_key_for_notification(self, notification: Notification) -> tuple[str, str] | None:
payload = notification.payload
thread_id = getattr(payload, "thread_id", None)
turn_id = getattr(payload, "turn_id", None)
if isinstance(thread_id, str) and isinstance(turn_id, str):
return thread_id, turn_id
turn = getattr(payload, "turn", None)
nested_turn_id = getattr(turn, "id", None)
if isinstance(thread_id, str) and isinstance(nested_turn_id, str):
return thread_id, nested_turn_id
return None
def _begin_turn_start(self, thread_id: str) -> None:
with self._state_lock:
active_turn_id = self._active_turn_by_thread_id.get(thread_id)
if active_turn_id is not None:
raise RuntimeError(
f"Thread {thread_id!r} already has active turn {active_turn_id!r}. "
"Use TurnHandle.steer() or TurnHandle.interrupt() instead of starting "
"another turn on the same thread."
)
if thread_id in self._turn_starting_by_thread_id:
raise RuntimeError(
f"Thread {thread_id!r} is already starting a turn on this client."
)
self._turn_starting_by_thread_id.add(thread_id)
def _cancel_turn_start(self, thread_id: str) -> None:
with self._state_lock:
self._turn_starting_by_thread_id.discard(thread_id)
def _finish_turn_start(self, thread_id: str, turn_id: str) -> None:
turn_key = (thread_id, turn_id)
with self._state_lock:
self._turn_starting_by_thread_id.discard(thread_id)
stream = self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
if not stream.is_closed():
self._active_turn_by_thread_id[thread_id] = turn_id
def _finish_transport(self, error: BaseException) -> None:
with self._state_lock:
if self._transport_error is not None:
return
self._transport_error = error
pending_requests = list(self._pending_requests.values())
self._pending_requests.clear()
turn_streams = list(self._turn_streams.values())
self._turn_streams.clear()
self._turn_starting_by_thread_id.clear()
self._active_turn_by_thread_id.clear()
self._active_turn_consumers.clear()
self._active_turn_stream_count = 0
for waiter in pending_requests:
waiter.error = error
waiter.event.set()
self._pending_notifications.close(error)
for stream in turn_streams:
stream.close(error)
def default_codex_home() -> str: def default_codex_home() -> str:
return str(Path.home() / ".codex") return str(Path.home() / ".codex")

View File

@@ -6,7 +6,7 @@ import time
from codex_app_server.async_client import AsyncAppServerClient from codex_app_server.async_client import AsyncAppServerClient
def test_async_client_serializes_transport_calls() -> None: def test_async_client_allows_parallel_transport_calls() -> None:
async def scenario() -> int: async def scenario() -> int:
client = AsyncAppServerClient() client = AsyncAppServerClient()
active = 0 active = 0
@@ -24,10 +24,10 @@ def test_async_client_serializes_transport_calls() -> None:
await asyncio.gather(client.model_list(), client.model_list()) await asyncio.gather(client.model_list(), client.model_list())
return max_active return max_active
assert asyncio.run(scenario()) == 1 assert asyncio.run(scenario()) == 2
def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None: def test_async_stream_text_is_incremental_and_allows_parallel_calls() -> None:
async def scenario() -> tuple[str, list[str], bool]: async def scenario() -> tuple[str, list[str], bool]:
client = AsyncAppServerClient() client = AsyncAppServerClient()
@@ -46,19 +46,19 @@ def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
stream = client.stream_text("thread-1", "hello") stream = client.stream_text("thread-1", "hello")
first = await anext(stream) first = await anext(stream)
blocked_before_stream_done = False completed_before_stream_done = False
competing_call = asyncio.create_task(client.model_list()) competing_call = asyncio.create_task(client.model_list())
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
blocked_before_stream_done = not competing_call.done() completed_before_stream_done = competing_call.done()
remaining: list[str] = [] remaining: list[str] = []
async for item in stream: async for item in stream:
remaining.append(item) remaining.append(item)
await competing_call await competing_call
return first, remaining, blocked_before_stream_done return first, remaining, completed_before_stream_done
first, remaining, blocked = asyncio.run(scenario()) first, remaining, completed = asyncio.run(scenario())
assert first == "first" assert first == "first"
assert remaining == ["second", "third"] assert remaining == ["second", "third"]
assert blocked assert completed

View File

@@ -4,7 +4,12 @@ from pathlib import Path
from typing import Any from typing import Any
from codex_app_server.client import AppServerClient, _params_dict from codex_app_server.client import AppServerClient, _params_dict
from codex_app_server.generated.v2_all import ThreadListParams, ThreadTokenUsageUpdatedNotification from codex_app_server.generated.v2_all import (
AgentMessageDeltaNotification,
ThreadListParams,
ThreadTokenUsageUpdatedNotification,
TurnCompletedNotification,
)
from codex_app_server.models import UnknownNotification from codex_app_server.models import UnknownNotification
ROOT = Path(__file__).resolve().parents[1] ROOT = Path(__file__).resolve().parents[1]
@@ -93,3 +98,58 @@ def test_invalid_notification_payload_falls_back_to_unknown() -> None:
assert event.method == "thread/tokenUsage/updated" assert event.method == "thread/tokenUsage/updated"
assert isinstance(event.payload, UnknownNotification) assert isinstance(event.payload, UnknownNotification)
def test_client_routes_interleaved_turn_notifications_to_matching_turn_queues() -> None:
client = AppServerClient()
first = client._coerce_notification(
"item/agentMessage/delta",
{
"delta": "first",
"itemId": "item-1",
"threadId": "thread-1",
"turnId": "turn-1",
},
)
second = client._coerce_notification(
"item/agentMessage/delta",
{
"delta": "second",
"itemId": "item-2",
"threadId": "thread-2",
"turnId": "turn-2",
},
)
client._dispatch_notification(first) # type: ignore[attr-defined]
client._dispatch_notification(second) # type: ignore[attr-defined]
first_turn = client.next_turn_notification("thread-1", "turn-1")
second_turn = client.next_turn_notification("thread-2", "turn-2")
assert isinstance(first_turn.payload, AgentMessageDeltaNotification)
assert first_turn.payload.delta == "first"
assert isinstance(second_turn.payload, AgentMessageDeltaNotification)
assert second_turn.payload.delta == "second"
def test_next_notification_still_returns_turn_notifications_without_active_streams() -> None:
client = AppServerClient()
completed = client._coerce_notification(
"turn/completed",
{
"threadId": "thread-1",
"turn": {
"id": "turn-1",
"items": [],
"status": "completed",
},
},
)
client._dispatch_notification(completed) # type: ignore[attr-defined]
event = client.next_notification()
assert event.method == "turn/completed"
assert isinstance(event.payload, TurnCompletedNotification)
assert event.payload.turn.id == "turn-1"

View File

@@ -133,6 +133,15 @@ def _token_usage_notification(
) )
def _turn_notification_source(
notifications_by_turn: dict[tuple[str, str], deque[Notification]],
):
def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
return notifications_by_turn[(thread_id, turn_id)].popleft()
return fake_next_turn_notification
def test_codex_init_failure_closes_client(monkeypatch: pytest.MonkeyPatch) -> None: def test_codex_init_failure_closes_client(monkeypatch: pytest.MonkeyPatch) -> None:
closed: list[bool] = [] closed: list[bool] = []
@@ -226,66 +235,132 @@ def test_async_codex_initializes_only_once_under_concurrency() -> None:
asyncio.run(scenario()) asyncio.run(scenario())
def test_turn_stream_rejects_second_active_consumer() -> None: def test_turn_stream_allows_different_active_threads() -> None:
client = AppServerClient() client = AppServerClient()
notifications: deque[Notification] = deque( client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
[ {
_delta_notification(turn_id="turn-1"), ("thread-1", "turn-1"): deque(
_completed_notification(turn_id="turn-1"), [
] _delta_notification(thread_id="thread-1", turn_id="turn-1"),
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
]
),
("thread-2", "turn-2"): deque(
[
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
]
),
}
)
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
second_stream = TurnHandle(client, "thread-2", "turn-2").stream()
assert next(first_stream).method == "item/agentMessage/delta"
assert next(second_stream).method == "item/agentMessage/delta"
first_stream.close()
second_stream.close()
def test_turn_stream_blocks_next_notification_while_active() -> None:
client = AppServerClient()
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
]
),
}
) )
client.next_notification = notifications.popleft # type: ignore[method-assign]
first_stream = TurnHandle(client, "thread-1", "turn-1").stream() first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
assert next(first_stream).method == "item/agentMessage/delta" assert next(first_stream).method == "item/agentMessage/delta"
second_stream = TurnHandle(client, "thread-1", "turn-2").stream() with pytest.raises(RuntimeError, match="next_notification\\(\\) is incompatible"):
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"): client.next_notification()
next(second_stream)
first_stream.close() first_stream.close()
def test_async_turn_stream_rejects_second_active_consumer() -> None: def test_turn_start_rejects_same_thread_overlap_and_allows_after_completion() -> None:
client = AppServerClient()
turn_ids = iter(["turn-1", "turn-2"])
def fake_request(method: str, params, *, response_model): # type: ignore[no-untyped-def]
assert method == "turn/start"
return response_model.model_validate(
{
"turn": {
"id": next(turn_ids),
"items": [],
"status": TurnStatus.in_progress.value,
}
}
)
client.request = fake_request # type: ignore[method-assign]
first = client.turn_start("thread-1", "first turn")
assert first.turn.id == "turn-1"
with pytest.raises(RuntimeError, match="already has active turn"):
client.turn_start("thread-1", "second turn")
client._dispatch_notification( # type: ignore[attr-defined]
_completed_notification(thread_id="thread-1", turn_id="turn-1")
)
second = client.turn_start("thread-1", "second turn")
assert second.turn.id == "turn-2"
def test_async_turn_stream_allows_different_active_threads() -> None:
async def scenario() -> None: async def scenario() -> None:
codex = AsyncCodex() codex = AsyncCodex()
async def fake_ensure_initialized() -> None: async def fake_ensure_initialized() -> None:
return None return None
notifications: deque[Notification] = deque( notifications_by_turn = {
[ ("thread-1", "turn-1"): deque(
_delta_notification(turn_id="turn-1"), [
_completed_notification(turn_id="turn-1"), _delta_notification(thread_id="thread-1", turn_id="turn-1"),
] _completed_notification(thread_id="thread-1", turn_id="turn-1"),
) ]
),
("thread-2", "turn-2"): deque(
[
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
]
),
}
async def fake_next_notification() -> Notification: async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
return notifications.popleft() return notifications_by_turn[(thread_id, turn_id)].popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign] codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream() first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream()
second_stream = AsyncTurnHandle(codex, "thread-2", "turn-2").stream()
assert (await anext(first_stream)).method == "item/agentMessage/delta" assert (await anext(first_stream)).method == "item/agentMessage/delta"
assert (await anext(second_stream)).method == "item/agentMessage/delta"
second_stream = AsyncTurnHandle(codex, "thread-1", "turn-2").stream()
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
await anext(second_stream)
await first_stream.aclose() await first_stream.aclose()
await second_stream.aclose()
asyncio.run(scenario()) asyncio.run(scenario())
def test_turn_run_returns_completed_turn_payload() -> None: def test_turn_run_returns_completed_turn_payload() -> None:
client = AppServerClient() client = AppServerClient()
notifications: deque[Notification] = deque( client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
[ {("thread-1", "turn-1"): deque([_completed_notification()])}
_completed_notification(),
]
) )
client.next_notification = notifications.popleft # type: ignore[method-assign]
result = TurnHandle(client, "thread-1", "turn-1").run() result = TurnHandle(client, "thread-1", "turn-1").run()
@@ -298,14 +373,17 @@ def test_thread_run_accepts_string_input_and_returns_run_result() -> None:
client = AppServerClient() client = AppServerClient()
item_notification = _item_completed_notification(text="Hello.") item_notification = _item_completed_notification(text="Hello.")
usage_notification = _token_usage_notification() usage_notification = _token_usage_notification()
notifications: deque[Notification] = deque( client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
[ {
item_notification, ("thread-1", "turn-1"): deque(
usage_notification, [
_completed_notification(), item_notification,
] usage_notification,
_completed_notification(),
]
),
}
) )
client.next_notification = notifications.popleft # type: ignore[method-assign]
seen: dict[str, object] = {} seen: dict[str, object] = {}
def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202 def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202
@@ -331,14 +409,17 @@ def test_thread_run_uses_last_completed_assistant_message_as_final_response() ->
client = AppServerClient() client = AppServerClient()
first_item_notification = _item_completed_notification(text="First message") first_item_notification = _item_completed_notification(text="First message")
second_item_notification = _item_completed_notification(text="Second message") second_item_notification = _item_completed_notification(text="Second message")
notifications: deque[Notification] = deque( client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
[ {
first_item_notification, ("thread-1", "turn-1"): deque(
second_item_notification, [
_completed_notification(), first_item_notification,
] second_item_notification,
_completed_notification(),
]
),
}
) )
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1") turn=SimpleNamespace(id="turn-1")
) )
@@ -356,14 +437,17 @@ def test_thread_run_preserves_empty_last_assistant_message() -> None:
client = AppServerClient() client = AppServerClient()
first_item_notification = _item_completed_notification(text="First message") first_item_notification = _item_completed_notification(text="First message")
second_item_notification = _item_completed_notification(text="") second_item_notification = _item_completed_notification(text="")
notifications: deque[Notification] = deque( client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
[ {
first_item_notification, ("thread-1", "turn-1"): deque(
second_item_notification, [
_completed_notification(), first_item_notification,
] second_item_notification,
_completed_notification(),
]
),
}
) )
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1") turn=SimpleNamespace(id="turn-1")
) )
@@ -387,14 +471,17 @@ def test_thread_run_prefers_explicit_final_answer_over_later_commentary() -> Non
text="Commentary", text="Commentary",
phase=MessagePhase.commentary, phase=MessagePhase.commentary,
) )
notifications: deque[Notification] = deque( client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
[ {
final_answer_notification, ("thread-1", "turn-1"): deque(
commentary_notification, [
_completed_notification(), final_answer_notification,
] commentary_notification,
_completed_notification(),
]
),
}
) )
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1") turn=SimpleNamespace(id="turn-1")
) )
@@ -414,13 +501,16 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
text="Commentary", text="Commentary",
phase=MessagePhase.commentary, phase=MessagePhase.commentary,
) )
notifications: deque[Notification] = deque( client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
[ {
commentary_notification, ("thread-1", "turn-1"): deque(
_completed_notification(), [
] commentary_notification,
_completed_notification(),
]
),
}
) )
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1") turn=SimpleNamespace(id="turn-1")
) )
@@ -433,12 +523,13 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
def test_thread_run_raises_on_failed_turn() -> None: def test_thread_run_raises_on_failed_turn() -> None:
client = AppServerClient() client = AppServerClient()
notifications: deque[Notification] = deque( client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
[ {
_completed_notification(status="failed", error_message="boom"), ("thread-1", "turn-1"): deque(
] [_completed_notification(status="failed", error_message="boom")]
),
}
) )
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731 client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1") turn=SimpleNamespace(id="turn-1")
) )
@@ -471,12 +562,13 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
seen["params"] = params seen["params"] = params
return SimpleNamespace(turn=SimpleNamespace(id="turn-1")) return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification: async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
assert (thread_id, turn_id) == ("thread-1", "turn-1")
return notifications.popleft() return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign] codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign] codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello") result = await AsyncThread(codex, "thread-1").run("hello")
@@ -511,12 +603,13 @@ def test_async_thread_run_uses_last_completed_assistant_message_as_final_respons
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001 async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
return SimpleNamespace(turn=SimpleNamespace(id="turn-1")) return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification: async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
assert (thread_id, turn_id) == ("thread-1", "turn-1")
return notifications.popleft() return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign] codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign] codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello") result = await AsyncThread(codex, "thread-1").run("hello")
@@ -550,12 +643,13 @@ def test_async_thread_run_returns_none_when_only_commentary_messages_complete()
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001 async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
return SimpleNamespace(turn=SimpleNamespace(id="turn-1")) return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification: async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
assert (thread_id, turn_id) == ("thread-1", "turn-1")
return notifications.popleft() return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign] codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign] codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello") result = await AsyncThread(codex, "thread-1").run("hello")