mirror of
https://github.com/openai/codex.git
synced 2026-05-22 20:14:17 +00:00
Compare commits
1 Commits
pr16267
...
dev/shaqay
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
083243dca1 |
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
Public surface of `codex_app_server` for app-server v2.
|
Public surface of `codex_app_server` for app-server v2.
|
||||||
|
|
||||||
This SDK surface is experimental. The current implementation intentionally allows only one active turn consumer (`Thread.run()`, `TurnHandle.stream()`, or `TurnHandle.run()`) per client instance at a time.
|
This SDK surface is experimental. The current implementation allows concurrent turn consumers on one client only when they belong to different thread IDs. Each client still supports only one active turn per thread ID at a time.
|
||||||
|
|
||||||
## Package Entry
|
## Package Entry
|
||||||
|
|
||||||
@@ -137,8 +137,9 @@ Use `turn(...)` when you need low-level turn control (`stream()`, `steer()`,
|
|||||||
|
|
||||||
Behavior notes:
|
Behavior notes:
|
||||||
|
|
||||||
- `stream()` and `run()` are exclusive per client instance in the current experimental build
|
- `stream()` and `run()` may run concurrently on one client when the turns belong to different thread IDs
|
||||||
- starting a second turn consumer on the same `Codex` instance raises `RuntimeError`
|
- starting a second turn on the same thread raises `RuntimeError`; use `steer()` or `interrupt()` on the existing handle instead
|
||||||
|
- low-level global notification APIs such as `next_notification()` are incompatible with active turn streaming on the same client
|
||||||
|
|
||||||
### AsyncTurnHandle
|
### AsyncTurnHandle
|
||||||
|
|
||||||
@@ -149,8 +150,9 @@ Behavior notes:
|
|||||||
|
|
||||||
Behavior notes:
|
Behavior notes:
|
||||||
|
|
||||||
- `stream()` and `run()` are exclusive per client instance in the current experimental build
|
- `stream()` and `run()` may run concurrently on one client when the turns belong to different thread IDs
|
||||||
- starting a second turn consumer on the same `AsyncCodex` instance raises `RuntimeError`
|
- starting a second turn on the same thread raises `RuntimeError`; use `steer()` or `interrupt()` on the existing handle instead
|
||||||
|
- low-level global notification APIs such as `next_notification()` are incompatible with active turn streaming on the same client
|
||||||
|
|
||||||
## Inputs
|
## Inputs
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,8 @@ What happened:
|
|||||||
- `thread.run("...")` started a turn, consumed events until completion, and returned the final assistant response plus collected items and usage.
|
- `thread.run("...")` started a turn, consumed events until completion, and returned the final assistant response plus collected items and usage.
|
||||||
- `result.final_response` is `None` when no final-answer or phase-less assistant message item completes for the turn.
|
- `result.final_response` is `None` when no final-answer or phase-less assistant message item completes for the turn.
|
||||||
- use `thread.turn(...)` when you need a `TurnHandle` for streaming, steering, interrupting, or turn IDs/status
|
- use `thread.turn(...)` when you need a `TurnHandle` for streaming, steering, interrupting, or turn IDs/status
|
||||||
- one client can have only one active turn consumer (`thread.run(...)`, `TurnHandle.stream()`, or `TurnHandle.run()`) at a time in the current experimental build
|
- one client can run turns concurrently across different thread IDs in the current experimental build
|
||||||
|
- one thread can have only one active turn at a time on a given client; start a second same-thread turn only after the first completes, or use `steer()` on the existing `TurnHandle`
|
||||||
|
|
||||||
## 3) Continue the same thread (multi-turn)
|
## 3) Continue the same thread (multi-turn)
|
||||||
|
|
||||||
|
|||||||
@@ -653,11 +653,10 @@ class TurnHandle:
|
|||||||
return self._client.turn_interrupt(self.thread_id, self.id)
|
return self._client.turn_interrupt(self.thread_id, self.id)
|
||||||
|
|
||||||
def stream(self) -> Iterator[Notification]:
|
def stream(self) -> Iterator[Notification]:
|
||||||
# TODO: replace this client-wide experimental guard with per-turn event demux.
|
self._client.acquire_turn_consumer(self.thread_id, self.id)
|
||||||
self._client.acquire_turn_consumer(self.id)
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
event = self._client.next_notification()
|
event = self._client.next_turn_notification(self.thread_id, self.id)
|
||||||
yield event
|
yield event
|
||||||
if (
|
if (
|
||||||
event.method == "turn/completed"
|
event.method == "turn/completed"
|
||||||
@@ -666,7 +665,7 @@ class TurnHandle:
|
|||||||
):
|
):
|
||||||
break
|
break
|
||||||
finally:
|
finally:
|
||||||
self._client.release_turn_consumer(self.id)
|
self._client.release_turn_consumer(self.thread_id, self.id)
|
||||||
|
|
||||||
def run(self) -> AppServerTurn:
|
def run(self) -> AppServerTurn:
|
||||||
completed: TurnCompletedNotification | None = None
|
completed: TurnCompletedNotification | None = None
|
||||||
@@ -704,11 +703,10 @@ class AsyncTurnHandle:
|
|||||||
|
|
||||||
async def stream(self) -> AsyncIterator[Notification]:
|
async def stream(self) -> AsyncIterator[Notification]:
|
||||||
await self._codex._ensure_initialized()
|
await self._codex._ensure_initialized()
|
||||||
# TODO: replace this client-wide experimental guard with per-turn event demux.
|
self._codex._client.acquire_turn_consumer(self.thread_id, self.id)
|
||||||
self._codex._client.acquire_turn_consumer(self.id)
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
event = await self._codex._client.next_notification()
|
event = await self._codex._client.next_turn_notification(self.thread_id, self.id)
|
||||||
yield event
|
yield event
|
||||||
if (
|
if (
|
||||||
event.method == "turn/completed"
|
event.method == "turn/completed"
|
||||||
@@ -717,7 +715,7 @@ class AsyncTurnHandle:
|
|||||||
):
|
):
|
||||||
break
|
break
|
||||||
finally:
|
finally:
|
||||||
self._codex._client.release_turn_consumer(self.id)
|
self._codex._client.release_turn_consumer(self.thread_id, self.id)
|
||||||
|
|
||||||
async def run(self) -> AppServerTurn:
|
async def run(self) -> AppServerTurn:
|
||||||
completed: TurnCompletedNotification | None = None
|
completed: TurnCompletedNotification | None = None
|
||||||
|
|||||||
@@ -41,8 +41,6 @@ class AsyncAppServerClient:
|
|||||||
|
|
||||||
def __init__(self, config: AppServerConfig | None = None) -> None:
|
def __init__(self, config: AppServerConfig | None = None) -> None:
|
||||||
self._sync = AppServerClient(config=config)
|
self._sync = AppServerClient(config=config)
|
||||||
# Single stdio transport cannot be read safely from multiple threads.
|
|
||||||
self._transport_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
async def __aenter__(self) -> "AsyncAppServerClient":
|
async def __aenter__(self) -> "AsyncAppServerClient":
|
||||||
await self.start()
|
await self.start()
|
||||||
@@ -58,8 +56,7 @@ class AsyncAppServerClient:
|
|||||||
*args: ParamsT.args,
|
*args: ParamsT.args,
|
||||||
**kwargs: ParamsT.kwargs,
|
**kwargs: ParamsT.kwargs,
|
||||||
) -> ReturnT:
|
) -> ReturnT:
|
||||||
async with self._transport_lock:
|
return await asyncio.to_thread(fn, *args, **kwargs)
|
||||||
return await asyncio.to_thread(fn, *args, **kwargs)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _next_from_iterator(
|
def _next_from_iterator(
|
||||||
@@ -79,11 +76,11 @@ class AsyncAppServerClient:
|
|||||||
async def initialize(self) -> InitializeResponse:
|
async def initialize(self) -> InitializeResponse:
|
||||||
return await self._call_sync(self._sync.initialize)
|
return await self._call_sync(self._sync.initialize)
|
||||||
|
|
||||||
def acquire_turn_consumer(self, turn_id: str) -> None:
|
def acquire_turn_consumer(self, thread_id: str, turn_id: str) -> None:
|
||||||
self._sync.acquire_turn_consumer(turn_id)
|
self._sync.acquire_turn_consumer(thread_id, turn_id)
|
||||||
|
|
||||||
def release_turn_consumer(self, turn_id: str) -> None:
|
def release_turn_consumer(self, thread_id: str, turn_id: str) -> None:
|
||||||
self._sync.release_turn_consumer(turn_id)
|
self._sync.release_turn_consumer(thread_id, turn_id)
|
||||||
|
|
||||||
async def request(
|
async def request(
|
||||||
self,
|
self,
|
||||||
@@ -184,6 +181,9 @@ class AsyncAppServerClient:
|
|||||||
async def next_notification(self) -> Notification:
|
async def next_notification(self) -> Notification:
|
||||||
return await self._call_sync(self._sync.next_notification)
|
return await self._call_sync(self._sync.next_notification)
|
||||||
|
|
||||||
|
async def next_turn_notification(self, thread_id: str, turn_id: str) -> Notification:
|
||||||
|
return await self._call_sync(self._sync.next_turn_notification, thread_id, turn_id)
|
||||||
|
|
||||||
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
||||||
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
|
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
|
||||||
|
|
||||||
@@ -196,13 +196,12 @@ class AsyncAppServerClient:
|
|||||||
text: str,
|
text: str,
|
||||||
params: V2TurnStartParams | JsonObject | None = None,
|
params: V2TurnStartParams | JsonObject | None = None,
|
||||||
) -> AsyncIterator[AgentMessageDeltaNotification]:
|
) -> AsyncIterator[AgentMessageDeltaNotification]:
|
||||||
async with self._transport_lock:
|
iterator = self._sync.stream_text(thread_id, text, params)
|
||||||
iterator = self._sync.stream_text(thread_id, text, params)
|
while True:
|
||||||
while True:
|
has_value, chunk = await asyncio.to_thread(
|
||||||
has_value, chunk = await asyncio.to_thread(
|
self._next_from_iterator,
|
||||||
self._next_from_iterator,
|
iterator,
|
||||||
iterator,
|
)
|
||||||
)
|
if not has_value:
|
||||||
if not has_value:
|
break
|
||||||
break
|
yield chunk
|
||||||
yield chunk
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import subprocess
|
|||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Iterable, Iterator, TypeVar
|
from typing import Callable, Iterable, Iterator, TypeVar
|
||||||
|
|
||||||
@@ -48,6 +48,58 @@ from .retry import retry_on_overload
|
|||||||
ModelT = TypeVar("ModelT", bound=BaseModel)
|
ModelT = TypeVar("ModelT", bound=BaseModel)
|
||||||
ApprovalHandler = Callable[[str, JsonObject | None], JsonObject]
|
ApprovalHandler = Callable[[str, JsonObject | None], JsonObject]
|
||||||
RUNTIME_PKG_NAME = "codex-cli-bin"
|
RUNTIME_PKG_NAME = "codex-cli-bin"
|
||||||
|
GLOBAL_NOTIFICATION_BACKLOG_LIMIT = 512
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _PendingRequest:
|
||||||
|
event: threading.Event = field(default_factory=threading.Event)
|
||||||
|
result: JsonValue | None = None
|
||||||
|
error: BaseException | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _BufferedNotificationStream:
|
||||||
|
def __init__(self, *, maxlen: int | None = None) -> None:
|
||||||
|
self._condition = threading.Condition()
|
||||||
|
self._items: deque[Notification] = (
|
||||||
|
deque(maxlen=maxlen) if maxlen is not None else deque()
|
||||||
|
)
|
||||||
|
self._closed = False
|
||||||
|
self._error: BaseException | None = None
|
||||||
|
|
||||||
|
def push(self, notification: Notification) -> None:
|
||||||
|
with self._condition:
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
self._items.append(notification)
|
||||||
|
self._condition.notify_all()
|
||||||
|
|
||||||
|
def pop(self) -> Notification:
|
||||||
|
with self._condition:
|
||||||
|
while not self._items and not self._closed:
|
||||||
|
self._condition.wait()
|
||||||
|
|
||||||
|
if self._items:
|
||||||
|
return self._items.popleft()
|
||||||
|
|
||||||
|
if self._error is not None:
|
||||||
|
raise self._error
|
||||||
|
|
||||||
|
raise TransportClosedError("notification stream is closed")
|
||||||
|
|
||||||
|
def close(self, error: BaseException | None = None) -> None:
|
||||||
|
with self._condition:
|
||||||
|
self._closed = True
|
||||||
|
self._error = error
|
||||||
|
self._condition.notify_all()
|
||||||
|
|
||||||
|
def is_closed(self) -> bool:
|
||||||
|
with self._condition:
|
||||||
|
return self._closed
|
||||||
|
|
||||||
|
def is_drained(self) -> bool:
|
||||||
|
with self._condition:
|
||||||
|
return self._closed and not self._items
|
||||||
|
|
||||||
|
|
||||||
def _params_dict(
|
def _params_dict(
|
||||||
@@ -144,12 +196,21 @@ class AppServerClient:
|
|||||||
self.config = config or AppServerConfig()
|
self.config = config or AppServerConfig()
|
||||||
self._approval_handler = approval_handler or self._default_approval_handler
|
self._approval_handler = approval_handler or self._default_approval_handler
|
||||||
self._proc: subprocess.Popen[str] | None = None
|
self._proc: subprocess.Popen[str] | None = None
|
||||||
self._lock = threading.Lock()
|
self._write_lock = threading.Lock()
|
||||||
self._turn_consumer_lock = threading.Lock()
|
self._state_lock = threading.Lock()
|
||||||
self._active_turn_consumer: str | None = None
|
self._pending_notifications = _BufferedNotificationStream(
|
||||||
self._pending_notifications: deque[Notification] = deque()
|
maxlen=GLOBAL_NOTIFICATION_BACKLOG_LIMIT
|
||||||
|
)
|
||||||
|
self._pending_requests: dict[str, _PendingRequest] = {}
|
||||||
|
self._turn_streams: dict[tuple[str, str], _BufferedNotificationStream] = {}
|
||||||
|
self._turn_starting_by_thread_id: set[str] = set()
|
||||||
|
self._active_turn_by_thread_id: dict[str, str] = {}
|
||||||
|
self._active_turn_consumers: set[tuple[str, str]] = set()
|
||||||
|
self._active_turn_stream_count = 0
|
||||||
|
self._transport_error: BaseException | None = None
|
||||||
self._stderr_lines: deque[str] = deque(maxlen=400)
|
self._stderr_lines: deque[str] = deque(maxlen=400)
|
||||||
self._stderr_thread: threading.Thread | None = None
|
self._stderr_thread: threading.Thread | None = None
|
||||||
|
self._reader_thread: threading.Thread | None = None
|
||||||
|
|
||||||
def __enter__(self) -> "AppServerClient":
|
def __enter__(self) -> "AppServerClient":
|
||||||
self.start()
|
self.start()
|
||||||
@@ -161,6 +222,7 @@ class AppServerClient:
|
|||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
if self._proc is not None:
|
if self._proc is not None:
|
||||||
return
|
return
|
||||||
|
self._reset_transport_state()
|
||||||
|
|
||||||
if self.config.launch_args_override is not None:
|
if self.config.launch_args_override is not None:
|
||||||
args = list(self.config.launch_args_override)
|
args = list(self.config.launch_args_override)
|
||||||
@@ -187,13 +249,14 @@ class AppServerClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._start_stderr_drain_thread()
|
self._start_stderr_drain_thread()
|
||||||
|
self._start_reader_thread()
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
if self._proc is None:
|
if self._proc is None:
|
||||||
return
|
return
|
||||||
proc = self._proc
|
proc = self._proc
|
||||||
self._proc = None
|
self._proc = None
|
||||||
self._active_turn_consumer = None
|
self._finish_transport(TransportClosedError("app-server closed"))
|
||||||
|
|
||||||
if proc.stdin:
|
if proc.stdin:
|
||||||
proc.stdin.close()
|
proc.stdin.close()
|
||||||
@@ -205,6 +268,8 @@ class AppServerClient:
|
|||||||
|
|
||||||
if self._stderr_thread and self._stderr_thread.is_alive():
|
if self._stderr_thread and self._stderr_thread.is_alive():
|
||||||
self._stderr_thread.join(timeout=0.5)
|
self._stderr_thread.join(timeout=0.5)
|
||||||
|
if self._reader_thread and self._reader_thread.is_alive():
|
||||||
|
self._reader_thread.join(timeout=0.5)
|
||||||
|
|
||||||
def initialize(self) -> InitializeResponse:
|
def initialize(self) -> InitializeResponse:
|
||||||
result = self.request(
|
result = self.request(
|
||||||
@@ -238,67 +303,76 @@ class AppServerClient:
|
|||||||
|
|
||||||
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
|
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
|
||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
self._write_message({"id": request_id, "method": method, "params": params or {}})
|
waiter = _PendingRequest()
|
||||||
|
with self._state_lock:
|
||||||
|
if self._transport_error is not None:
|
||||||
|
raise self._transport_error
|
||||||
|
self._pending_requests[request_id] = waiter
|
||||||
|
|
||||||
while True:
|
try:
|
||||||
msg = self._read_message()
|
self._write_message({"id": request_id, "method": method, "params": params or {}})
|
||||||
|
except BaseException:
|
||||||
|
with self._state_lock:
|
||||||
|
self._pending_requests.pop(request_id, None)
|
||||||
|
raise
|
||||||
|
|
||||||
if "method" in msg and "id" in msg:
|
waiter.event.wait()
|
||||||
response = self._handle_server_request(msg)
|
if waiter.error is not None:
|
||||||
self._write_message({"id": msg["id"], "result": response})
|
raise waiter.error
|
||||||
continue
|
return waiter.result
|
||||||
|
|
||||||
if "method" in msg and "id" not in msg:
|
|
||||||
self._pending_notifications.append(
|
|
||||||
self._coerce_notification(msg["method"], msg.get("params"))
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if msg.get("id") != request_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "error" in msg:
|
|
||||||
err = msg["error"]
|
|
||||||
if isinstance(err, dict):
|
|
||||||
raise map_jsonrpc_error(
|
|
||||||
int(err.get("code", -32000)),
|
|
||||||
str(err.get("message", "unknown")),
|
|
||||||
err.get("data"),
|
|
||||||
)
|
|
||||||
raise AppServerError("Malformed JSON-RPC error response")
|
|
||||||
|
|
||||||
return msg.get("result")
|
|
||||||
|
|
||||||
def notify(self, method: str, params: JsonObject | None = None) -> None:
|
def notify(self, method: str, params: JsonObject | None = None) -> None:
|
||||||
self._write_message({"method": method, "params": params or {}})
|
self._write_message({"method": method, "params": params or {}})
|
||||||
|
|
||||||
def next_notification(self) -> Notification:
|
def next_notification(self) -> Notification:
|
||||||
if self._pending_notifications:
|
with self._state_lock:
|
||||||
return self._pending_notifications.popleft()
|
if self._active_turn_stream_count > 0:
|
||||||
|
|
||||||
while True:
|
|
||||||
msg = self._read_message()
|
|
||||||
if "method" in msg and "id" in msg:
|
|
||||||
response = self._handle_server_request(msg)
|
|
||||||
self._write_message({"id": msg["id"], "result": response})
|
|
||||||
continue
|
|
||||||
if "method" in msg and "id" not in msg:
|
|
||||||
return self._coerce_notification(msg["method"], msg.get("params"))
|
|
||||||
|
|
||||||
def acquire_turn_consumer(self, turn_id: str) -> None:
|
|
||||||
with self._turn_consumer_lock:
|
|
||||||
if self._active_turn_consumer is not None:
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Concurrent turn consumers are not yet supported in the experimental SDK. "
|
"next_notification() is incompatible with active turn streaming on the same "
|
||||||
f"Client is already streaming turn {self._active_turn_consumer!r}; "
|
"client. Consume notifications from TurnHandle.stream()/run() instead."
|
||||||
f"cannot start turn {turn_id!r} until the active consumer finishes."
|
|
||||||
)
|
)
|
||||||
self._active_turn_consumer = turn_id
|
return self._pending_notifications.pop()
|
||||||
|
|
||||||
def release_turn_consumer(self, turn_id: str) -> None:
|
def acquire_turn_consumer(self, thread_id: str, turn_id: str) -> None:
|
||||||
with self._turn_consumer_lock:
|
turn_key = (thread_id, turn_id)
|
||||||
if self._active_turn_consumer == turn_id:
|
with self._state_lock:
|
||||||
self._active_turn_consumer = None
|
if turn_key in self._active_turn_consumers:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Turn {turn_id!r} is already being streamed on thread {thread_id!r}."
|
||||||
|
)
|
||||||
|
self._active_turn_consumers.add(turn_key)
|
||||||
|
self._active_turn_stream_count += 1
|
||||||
|
self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
|
||||||
|
|
||||||
|
def release_turn_consumer(self, thread_id: str, turn_id: str) -> None:
|
||||||
|
turn_key = (thread_id, turn_id)
|
||||||
|
with self._state_lock:
|
||||||
|
if turn_key in self._active_turn_consumers:
|
||||||
|
self._active_turn_consumers.remove(turn_key)
|
||||||
|
self._active_turn_stream_count -= 1
|
||||||
|
stream = self._turn_streams.get(turn_key)
|
||||||
|
if stream is not None and stream.is_drained():
|
||||||
|
self._turn_streams.pop(turn_key, None)
|
||||||
|
|
||||||
|
def next_turn_notification(self, thread_id: str, turn_id: str) -> Notification:
|
||||||
|
turn_key = (thread_id, turn_id)
|
||||||
|
with self._state_lock:
|
||||||
|
stream = self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
|
||||||
|
return stream.pop()
|
||||||
|
|
||||||
|
def assert_can_start_turn(self, thread_id: str) -> None:
|
||||||
|
with self._state_lock:
|
||||||
|
if thread_id in self._turn_starting_by_thread_id:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Thread {thread_id!r} is already starting a turn on this client."
|
||||||
|
)
|
||||||
|
active_turn_id = self._active_turn_by_thread_id.get(thread_id)
|
||||||
|
if active_turn_id is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Thread {thread_id!r} already has active turn {active_turn_id!r}. "
|
||||||
|
"Use TurnHandle.steer() or TurnHandle.interrupt() instead of starting "
|
||||||
|
"another turn on the same thread."
|
||||||
|
)
|
||||||
|
|
||||||
def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
|
def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
|
||||||
return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse)
|
return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse)
|
||||||
@@ -355,12 +429,19 @@ class AppServerClient:
|
|||||||
input_items: list[JsonObject] | JsonObject | str,
|
input_items: list[JsonObject] | JsonObject | str,
|
||||||
params: V2TurnStartParams | JsonObject | None = None,
|
params: V2TurnStartParams | JsonObject | None = None,
|
||||||
) -> TurnStartResponse:
|
) -> TurnStartResponse:
|
||||||
|
self._begin_turn_start(thread_id)
|
||||||
payload = {
|
payload = {
|
||||||
**_params_dict(params),
|
**_params_dict(params),
|
||||||
"threadId": thread_id,
|
"threadId": thread_id,
|
||||||
"input": self._normalize_input_items(input_items),
|
"input": self._normalize_input_items(input_items),
|
||||||
}
|
}
|
||||||
return self.request("turn/start", payload, response_model=TurnStartResponse)
|
try:
|
||||||
|
started = self.request("turn/start", payload, response_model=TurnStartResponse)
|
||||||
|
except BaseException:
|
||||||
|
self._cancel_turn_start(thread_id)
|
||||||
|
raise
|
||||||
|
self._finish_turn_start(thread_id, started.turn.id)
|
||||||
|
return started
|
||||||
|
|
||||||
def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
|
def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
|
||||||
return self.request(
|
return self.request(
|
||||||
@@ -436,21 +517,25 @@ class AppServerClient:
|
|||||||
) -> Iterator[AgentMessageDeltaNotification]:
|
) -> Iterator[AgentMessageDeltaNotification]:
|
||||||
started = self.turn_start(thread_id, text, params=params)
|
started = self.turn_start(thread_id, text, params=params)
|
||||||
turn_id = started.turn.id
|
turn_id = started.turn.id
|
||||||
while True:
|
self.acquire_turn_consumer(thread_id, turn_id)
|
||||||
notification = self.next_notification()
|
try:
|
||||||
if (
|
while True:
|
||||||
notification.method == "item/agentMessage/delta"
|
notification = self.next_turn_notification(thread_id, turn_id)
|
||||||
and isinstance(notification.payload, AgentMessageDeltaNotification)
|
if (
|
||||||
and notification.payload.turn_id == turn_id
|
notification.method == "item/agentMessage/delta"
|
||||||
):
|
and isinstance(notification.payload, AgentMessageDeltaNotification)
|
||||||
yield notification.payload
|
and notification.payload.turn_id == turn_id
|
||||||
continue
|
):
|
||||||
if (
|
yield notification.payload
|
||||||
notification.method == "turn/completed"
|
continue
|
||||||
and isinstance(notification.payload, TurnCompletedNotification)
|
if (
|
||||||
and notification.payload.turn.id == turn_id
|
notification.method == "turn/completed"
|
||||||
):
|
and isinstance(notification.payload, TurnCompletedNotification)
|
||||||
break
|
and notification.payload.turn.id == turn_id
|
||||||
|
):
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
self.release_turn_consumer(thread_id, turn_id)
|
||||||
|
|
||||||
def _coerce_notification(self, method: str, params: object) -> Notification:
|
def _coerce_notification(self, method: str, params: object) -> Notification:
|
||||||
params_dict = params if isinstance(params, dict) else {}
|
params_dict = params if isinstance(params, dict) else {}
|
||||||
@@ -512,7 +597,7 @@ class AppServerClient:
|
|||||||
def _write_message(self, payload: JsonObject) -> None:
|
def _write_message(self, payload: JsonObject) -> None:
|
||||||
if self._proc is None or self._proc.stdin is None:
|
if self._proc is None or self._proc.stdin is None:
|
||||||
raise TransportClosedError("app-server is not running")
|
raise TransportClosedError("app-server is not running")
|
||||||
with self._lock:
|
with self._write_lock:
|
||||||
self._proc.stdin.write(json.dumps(payload) + "\n")
|
self._proc.stdin.write(json.dumps(payload) + "\n")
|
||||||
self._proc.stdin.flush()
|
self._proc.stdin.flush()
|
||||||
|
|
||||||
@@ -535,6 +620,162 @@ class AppServerClient:
|
|||||||
raise AppServerError(f"Invalid JSON-RPC payload: {message!r}")
|
raise AppServerError(f"Invalid JSON-RPC payload: {message!r}")
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
def _reset_transport_state(self) -> None:
|
||||||
|
self._pending_notifications = _BufferedNotificationStream(
|
||||||
|
maxlen=GLOBAL_NOTIFICATION_BACKLOG_LIMIT
|
||||||
|
)
|
||||||
|
self._pending_requests = {}
|
||||||
|
self._turn_streams = {}
|
||||||
|
self._turn_starting_by_thread_id = set()
|
||||||
|
self._active_turn_by_thread_id = {}
|
||||||
|
self._active_turn_consumers = set()
|
||||||
|
self._active_turn_stream_count = 0
|
||||||
|
self._transport_error = None
|
||||||
|
|
||||||
|
def _start_reader_thread(self) -> None:
|
||||||
|
def _reader() -> None:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
msg = self._read_message()
|
||||||
|
if "method" in msg and "id" in msg:
|
||||||
|
self._start_server_request_worker(msg)
|
||||||
|
continue
|
||||||
|
if "method" in msg and "id" not in msg:
|
||||||
|
method = msg["method"]
|
||||||
|
if isinstance(method, str):
|
||||||
|
self._dispatch_notification(
|
||||||
|
self._coerce_notification(method, msg.get("params"))
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
self._handle_response_message(msg)
|
||||||
|
except BaseException as exc: # noqa: BLE001
|
||||||
|
self._finish_transport(exc)
|
||||||
|
|
||||||
|
self._reader_thread = threading.Thread(target=_reader, daemon=True)
|
||||||
|
self._reader_thread.start()
|
||||||
|
|
||||||
|
def _start_server_request_worker(self, msg: dict[str, JsonValue]) -> None:
|
||||||
|
def _resolve() -> None:
|
||||||
|
try:
|
||||||
|
response = self._handle_server_request(msg)
|
||||||
|
self._write_message({"id": msg["id"], "result": response})
|
||||||
|
except BaseException:
|
||||||
|
return
|
||||||
|
|
||||||
|
threading.Thread(target=_resolve, daemon=True).start()
|
||||||
|
|
||||||
|
def _handle_response_message(self, msg: dict[str, JsonValue]) -> None:
|
||||||
|
request_id = msg.get("id")
|
||||||
|
if not isinstance(request_id, str):
|
||||||
|
return
|
||||||
|
|
||||||
|
with self._state_lock:
|
||||||
|
waiter = self._pending_requests.pop(request_id, None)
|
||||||
|
|
||||||
|
if waiter is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if "error" in msg:
|
||||||
|
err = msg["error"]
|
||||||
|
if isinstance(err, dict):
|
||||||
|
waiter.error = map_jsonrpc_error(
|
||||||
|
int(err.get("code", -32000)),
|
||||||
|
str(err.get("message", "unknown")),
|
||||||
|
err.get("data"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
waiter.error = AppServerError("Malformed JSON-RPC error response")
|
||||||
|
else:
|
||||||
|
waiter.result = msg.get("result")
|
||||||
|
waiter.event.set()
|
||||||
|
|
||||||
|
def _dispatch_notification(self, notification: Notification) -> None:
|
||||||
|
self._pending_notifications.push(notification)
|
||||||
|
|
||||||
|
turn_key = self._turn_key_for_notification(notification)
|
||||||
|
if turn_key is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
thread_id, turn_id = turn_key
|
||||||
|
close_stream = False
|
||||||
|
with self._state_lock:
|
||||||
|
stream = self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
|
||||||
|
if notification.method == "turn/started":
|
||||||
|
self._turn_starting_by_thread_id.discard(thread_id)
|
||||||
|
self._active_turn_by_thread_id[thread_id] = turn_id
|
||||||
|
elif notification.method == "turn/completed":
|
||||||
|
self._turn_starting_by_thread_id.discard(thread_id)
|
||||||
|
if self._active_turn_by_thread_id.get(thread_id) == turn_id:
|
||||||
|
self._active_turn_by_thread_id.pop(thread_id, None)
|
||||||
|
close_stream = True
|
||||||
|
|
||||||
|
stream.push(notification)
|
||||||
|
if close_stream:
|
||||||
|
stream.close()
|
||||||
|
|
||||||
|
def _turn_key_for_notification(self, notification: Notification) -> tuple[str, str] | None:
|
||||||
|
payload = notification.payload
|
||||||
|
thread_id = getattr(payload, "thread_id", None)
|
||||||
|
turn_id = getattr(payload, "turn_id", None)
|
||||||
|
if isinstance(thread_id, str) and isinstance(turn_id, str):
|
||||||
|
return thread_id, turn_id
|
||||||
|
|
||||||
|
turn = getattr(payload, "turn", None)
|
||||||
|
nested_turn_id = getattr(turn, "id", None)
|
||||||
|
if isinstance(thread_id, str) and isinstance(nested_turn_id, str):
|
||||||
|
return thread_id, nested_turn_id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _begin_turn_start(self, thread_id: str) -> None:
|
||||||
|
with self._state_lock:
|
||||||
|
active_turn_id = self._active_turn_by_thread_id.get(thread_id)
|
||||||
|
if active_turn_id is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Thread {thread_id!r} already has active turn {active_turn_id!r}. "
|
||||||
|
"Use TurnHandle.steer() or TurnHandle.interrupt() instead of starting "
|
||||||
|
"another turn on the same thread."
|
||||||
|
)
|
||||||
|
if thread_id in self._turn_starting_by_thread_id:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Thread {thread_id!r} is already starting a turn on this client."
|
||||||
|
)
|
||||||
|
self._turn_starting_by_thread_id.add(thread_id)
|
||||||
|
|
||||||
|
def _cancel_turn_start(self, thread_id: str) -> None:
|
||||||
|
with self._state_lock:
|
||||||
|
self._turn_starting_by_thread_id.discard(thread_id)
|
||||||
|
|
||||||
|
def _finish_turn_start(self, thread_id: str, turn_id: str) -> None:
|
||||||
|
turn_key = (thread_id, turn_id)
|
||||||
|
with self._state_lock:
|
||||||
|
self._turn_starting_by_thread_id.discard(thread_id)
|
||||||
|
stream = self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
|
||||||
|
if not stream.is_closed():
|
||||||
|
self._active_turn_by_thread_id[thread_id] = turn_id
|
||||||
|
|
||||||
|
def _finish_transport(self, error: BaseException) -> None:
|
||||||
|
with self._state_lock:
|
||||||
|
if self._transport_error is not None:
|
||||||
|
return
|
||||||
|
self._transport_error = error
|
||||||
|
pending_requests = list(self._pending_requests.values())
|
||||||
|
self._pending_requests.clear()
|
||||||
|
turn_streams = list(self._turn_streams.values())
|
||||||
|
self._turn_streams.clear()
|
||||||
|
self._turn_starting_by_thread_id.clear()
|
||||||
|
self._active_turn_by_thread_id.clear()
|
||||||
|
self._active_turn_consumers.clear()
|
||||||
|
self._active_turn_stream_count = 0
|
||||||
|
|
||||||
|
for waiter in pending_requests:
|
||||||
|
waiter.error = error
|
||||||
|
waiter.event.set()
|
||||||
|
|
||||||
|
self._pending_notifications.close(error)
|
||||||
|
for stream in turn_streams:
|
||||||
|
stream.close(error)
|
||||||
|
|
||||||
|
|
||||||
def default_codex_home() -> str:
|
def default_codex_home() -> str:
|
||||||
return str(Path.home() / ".codex")
|
return str(Path.home() / ".codex")
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import time
|
|||||||
from codex_app_server.async_client import AsyncAppServerClient
|
from codex_app_server.async_client import AsyncAppServerClient
|
||||||
|
|
||||||
|
|
||||||
def test_async_client_serializes_transport_calls() -> None:
|
def test_async_client_allows_parallel_transport_calls() -> None:
|
||||||
async def scenario() -> int:
|
async def scenario() -> int:
|
||||||
client = AsyncAppServerClient()
|
client = AsyncAppServerClient()
|
||||||
active = 0
|
active = 0
|
||||||
@@ -24,10 +24,10 @@ def test_async_client_serializes_transport_calls() -> None:
|
|||||||
await asyncio.gather(client.model_list(), client.model_list())
|
await asyncio.gather(client.model_list(), client.model_list())
|
||||||
return max_active
|
return max_active
|
||||||
|
|
||||||
assert asyncio.run(scenario()) == 1
|
assert asyncio.run(scenario()) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
|
def test_async_stream_text_is_incremental_and_allows_parallel_calls() -> None:
|
||||||
async def scenario() -> tuple[str, list[str], bool]:
|
async def scenario() -> tuple[str, list[str], bool]:
|
||||||
client = AsyncAppServerClient()
|
client = AsyncAppServerClient()
|
||||||
|
|
||||||
@@ -46,19 +46,19 @@ def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
|
|||||||
stream = client.stream_text("thread-1", "hello")
|
stream = client.stream_text("thread-1", "hello")
|
||||||
first = await anext(stream)
|
first = await anext(stream)
|
||||||
|
|
||||||
blocked_before_stream_done = False
|
completed_before_stream_done = False
|
||||||
competing_call = asyncio.create_task(client.model_list())
|
competing_call = asyncio.create_task(client.model_list())
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
blocked_before_stream_done = not competing_call.done()
|
completed_before_stream_done = competing_call.done()
|
||||||
|
|
||||||
remaining: list[str] = []
|
remaining: list[str] = []
|
||||||
async for item in stream:
|
async for item in stream:
|
||||||
remaining.append(item)
|
remaining.append(item)
|
||||||
|
|
||||||
await competing_call
|
await competing_call
|
||||||
return first, remaining, blocked_before_stream_done
|
return first, remaining, completed_before_stream_done
|
||||||
|
|
||||||
first, remaining, blocked = asyncio.run(scenario())
|
first, remaining, completed = asyncio.run(scenario())
|
||||||
assert first == "first"
|
assert first == "first"
|
||||||
assert remaining == ["second", "third"]
|
assert remaining == ["second", "third"]
|
||||||
assert blocked
|
assert completed
|
||||||
|
|||||||
@@ -4,7 +4,12 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from codex_app_server.client import AppServerClient, _params_dict
|
from codex_app_server.client import AppServerClient, _params_dict
|
||||||
from codex_app_server.generated.v2_all import ThreadListParams, ThreadTokenUsageUpdatedNotification
|
from codex_app_server.generated.v2_all import (
|
||||||
|
AgentMessageDeltaNotification,
|
||||||
|
ThreadListParams,
|
||||||
|
ThreadTokenUsageUpdatedNotification,
|
||||||
|
TurnCompletedNotification,
|
||||||
|
)
|
||||||
from codex_app_server.models import UnknownNotification
|
from codex_app_server.models import UnknownNotification
|
||||||
|
|
||||||
ROOT = Path(__file__).resolve().parents[1]
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
@@ -93,3 +98,58 @@ def test_invalid_notification_payload_falls_back_to_unknown() -> None:
|
|||||||
|
|
||||||
assert event.method == "thread/tokenUsage/updated"
|
assert event.method == "thread/tokenUsage/updated"
|
||||||
assert isinstance(event.payload, UnknownNotification)
|
assert isinstance(event.payload, UnknownNotification)
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_routes_interleaved_turn_notifications_to_matching_turn_queues() -> None:
|
||||||
|
client = AppServerClient()
|
||||||
|
first = client._coerce_notification(
|
||||||
|
"item/agentMessage/delta",
|
||||||
|
{
|
||||||
|
"delta": "first",
|
||||||
|
"itemId": "item-1",
|
||||||
|
"threadId": "thread-1",
|
||||||
|
"turnId": "turn-1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
second = client._coerce_notification(
|
||||||
|
"item/agentMessage/delta",
|
||||||
|
{
|
||||||
|
"delta": "second",
|
||||||
|
"itemId": "item-2",
|
||||||
|
"threadId": "thread-2",
|
||||||
|
"turnId": "turn-2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
client._dispatch_notification(first) # type: ignore[attr-defined]
|
||||||
|
client._dispatch_notification(second) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
first_turn = client.next_turn_notification("thread-1", "turn-1")
|
||||||
|
second_turn = client.next_turn_notification("thread-2", "turn-2")
|
||||||
|
|
||||||
|
assert isinstance(first_turn.payload, AgentMessageDeltaNotification)
|
||||||
|
assert first_turn.payload.delta == "first"
|
||||||
|
assert isinstance(second_turn.payload, AgentMessageDeltaNotification)
|
||||||
|
assert second_turn.payload.delta == "second"
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_notification_still_returns_turn_notifications_without_active_streams() -> None:
|
||||||
|
client = AppServerClient()
|
||||||
|
completed = client._coerce_notification(
|
||||||
|
"turn/completed",
|
||||||
|
{
|
||||||
|
"threadId": "thread-1",
|
||||||
|
"turn": {
|
||||||
|
"id": "turn-1",
|
||||||
|
"items": [],
|
||||||
|
"status": "completed",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
client._dispatch_notification(completed) # type: ignore[attr-defined]
|
||||||
|
event = client.next_notification()
|
||||||
|
|
||||||
|
assert event.method == "turn/completed"
|
||||||
|
assert isinstance(event.payload, TurnCompletedNotification)
|
||||||
|
assert event.payload.turn.id == "turn-1"
|
||||||
|
|||||||
@@ -133,6 +133,15 @@ def _token_usage_notification(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _turn_notification_source(
|
||||||
|
notifications_by_turn: dict[tuple[str, str], deque[Notification]],
|
||||||
|
):
|
||||||
|
def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||||
|
return notifications_by_turn[(thread_id, turn_id)].popleft()
|
||||||
|
|
||||||
|
return fake_next_turn_notification
|
||||||
|
|
||||||
|
|
||||||
def test_codex_init_failure_closes_client(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_codex_init_failure_closes_client(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
closed: list[bool] = []
|
closed: list[bool] = []
|
||||||
|
|
||||||
@@ -226,66 +235,132 @@ def test_async_codex_initializes_only_once_under_concurrency() -> None:
|
|||||||
asyncio.run(scenario())
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
def test_turn_stream_rejects_second_active_consumer() -> None:
|
def test_turn_stream_allows_different_active_threads() -> None:
|
||||||
client = AppServerClient()
|
client = AppServerClient()
|
||||||
notifications: deque[Notification] = deque(
|
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||||
[
|
{
|
||||||
_delta_notification(turn_id="turn-1"),
|
("thread-1", "turn-1"): deque(
|
||||||
_completed_notification(turn_id="turn-1"),
|
[
|
||||||
]
|
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||||
|
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
("thread-2", "turn-2"): deque(
|
||||||
|
[
|
||||||
|
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||||
|
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
|
||||||
|
second_stream = TurnHandle(client, "thread-2", "turn-2").stream()
|
||||||
|
assert next(first_stream).method == "item/agentMessage/delta"
|
||||||
|
assert next(second_stream).method == "item/agentMessage/delta"
|
||||||
|
|
||||||
|
first_stream.close()
|
||||||
|
second_stream.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_turn_stream_blocks_next_notification_while_active() -> None:
|
||||||
|
client = AppServerClient()
|
||||||
|
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||||
|
{
|
||||||
|
("thread-1", "turn-1"): deque(
|
||||||
|
[
|
||||||
|
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||||
|
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
|
||||||
|
|
||||||
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
|
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
|
||||||
assert next(first_stream).method == "item/agentMessage/delta"
|
assert next(first_stream).method == "item/agentMessage/delta"
|
||||||
|
|
||||||
second_stream = TurnHandle(client, "thread-1", "turn-2").stream()
|
with pytest.raises(RuntimeError, match="next_notification\\(\\) is incompatible"):
|
||||||
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
|
client.next_notification()
|
||||||
next(second_stream)
|
|
||||||
|
|
||||||
first_stream.close()
|
first_stream.close()
|
||||||
|
|
||||||
|
|
||||||
def test_async_turn_stream_rejects_second_active_consumer() -> None:
|
def test_turn_start_rejects_same_thread_overlap_and_allows_after_completion() -> None:
|
||||||
|
client = AppServerClient()
|
||||||
|
turn_ids = iter(["turn-1", "turn-2"])
|
||||||
|
|
||||||
|
def fake_request(method: str, params, *, response_model): # type: ignore[no-untyped-def]
|
||||||
|
assert method == "turn/start"
|
||||||
|
return response_model.model_validate(
|
||||||
|
{
|
||||||
|
"turn": {
|
||||||
|
"id": next(turn_ids),
|
||||||
|
"items": [],
|
||||||
|
"status": TurnStatus.in_progress.value,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
client.request = fake_request # type: ignore[method-assign]
|
||||||
|
|
||||||
|
first = client.turn_start("thread-1", "first turn")
|
||||||
|
assert first.turn.id == "turn-1"
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="already has active turn"):
|
||||||
|
client.turn_start("thread-1", "second turn")
|
||||||
|
|
||||||
|
client._dispatch_notification( # type: ignore[attr-defined]
|
||||||
|
_completed_notification(thread_id="thread-1", turn_id="turn-1")
|
||||||
|
)
|
||||||
|
|
||||||
|
second = client.turn_start("thread-1", "second turn")
|
||||||
|
assert second.turn.id == "turn-2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_turn_stream_allows_different_active_threads() -> None:
|
||||||
async def scenario() -> None:
|
async def scenario() -> None:
|
||||||
codex = AsyncCodex()
|
codex = AsyncCodex()
|
||||||
|
|
||||||
async def fake_ensure_initialized() -> None:
|
async def fake_ensure_initialized() -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
notifications: deque[Notification] = deque(
|
notifications_by_turn = {
|
||||||
[
|
("thread-1", "turn-1"): deque(
|
||||||
_delta_notification(turn_id="turn-1"),
|
[
|
||||||
_completed_notification(turn_id="turn-1"),
|
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||||
]
|
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||||
)
|
]
|
||||||
|
),
|
||||||
|
("thread-2", "turn-2"): deque(
|
||||||
|
[
|
||||||
|
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||||
|
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
async def fake_next_notification() -> Notification:
|
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||||
return notifications.popleft()
|
return notifications_by_turn[(thread_id, turn_id)].popleft()
|
||||||
|
|
||||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||||
|
|
||||||
first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream()
|
first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream()
|
||||||
|
second_stream = AsyncTurnHandle(codex, "thread-2", "turn-2").stream()
|
||||||
assert (await anext(first_stream)).method == "item/agentMessage/delta"
|
assert (await anext(first_stream)).method == "item/agentMessage/delta"
|
||||||
|
assert (await anext(second_stream)).method == "item/agentMessage/delta"
|
||||||
second_stream = AsyncTurnHandle(codex, "thread-1", "turn-2").stream()
|
|
||||||
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
|
|
||||||
await anext(second_stream)
|
|
||||||
|
|
||||||
await first_stream.aclose()
|
await first_stream.aclose()
|
||||||
|
await second_stream.aclose()
|
||||||
|
|
||||||
asyncio.run(scenario())
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
def test_turn_run_returns_completed_turn_payload() -> None:
|
def test_turn_run_returns_completed_turn_payload() -> None:
|
||||||
client = AppServerClient()
|
client = AppServerClient()
|
||||||
notifications: deque[Notification] = deque(
|
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||||
[
|
{("thread-1", "turn-1"): deque([_completed_notification()])}
|
||||||
_completed_notification(),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
|
||||||
|
|
||||||
result = TurnHandle(client, "thread-1", "turn-1").run()
|
result = TurnHandle(client, "thread-1", "turn-1").run()
|
||||||
|
|
||||||
@@ -298,14 +373,17 @@ def test_thread_run_accepts_string_input_and_returns_run_result() -> None:
|
|||||||
client = AppServerClient()
|
client = AppServerClient()
|
||||||
item_notification = _item_completed_notification(text="Hello.")
|
item_notification = _item_completed_notification(text="Hello.")
|
||||||
usage_notification = _token_usage_notification()
|
usage_notification = _token_usage_notification()
|
||||||
notifications: deque[Notification] = deque(
|
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||||
[
|
{
|
||||||
item_notification,
|
("thread-1", "turn-1"): deque(
|
||||||
usage_notification,
|
[
|
||||||
_completed_notification(),
|
item_notification,
|
||||||
]
|
usage_notification,
|
||||||
|
_completed_notification(),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
|
||||||
seen: dict[str, object] = {}
|
seen: dict[str, object] = {}
|
||||||
|
|
||||||
def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202
|
def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202
|
||||||
@@ -331,14 +409,17 @@ def test_thread_run_uses_last_completed_assistant_message_as_final_response() ->
|
|||||||
client = AppServerClient()
|
client = AppServerClient()
|
||||||
first_item_notification = _item_completed_notification(text="First message")
|
first_item_notification = _item_completed_notification(text="First message")
|
||||||
second_item_notification = _item_completed_notification(text="Second message")
|
second_item_notification = _item_completed_notification(text="Second message")
|
||||||
notifications: deque[Notification] = deque(
|
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||||
[
|
{
|
||||||
first_item_notification,
|
("thread-1", "turn-1"): deque(
|
||||||
second_item_notification,
|
[
|
||||||
_completed_notification(),
|
first_item_notification,
|
||||||
]
|
second_item_notification,
|
||||||
|
_completed_notification(),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
|
||||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||||
turn=SimpleNamespace(id="turn-1")
|
turn=SimpleNamespace(id="turn-1")
|
||||||
)
|
)
|
||||||
@@ -356,14 +437,17 @@ def test_thread_run_preserves_empty_last_assistant_message() -> None:
|
|||||||
client = AppServerClient()
|
client = AppServerClient()
|
||||||
first_item_notification = _item_completed_notification(text="First message")
|
first_item_notification = _item_completed_notification(text="First message")
|
||||||
second_item_notification = _item_completed_notification(text="")
|
second_item_notification = _item_completed_notification(text="")
|
||||||
notifications: deque[Notification] = deque(
|
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||||
[
|
{
|
||||||
first_item_notification,
|
("thread-1", "turn-1"): deque(
|
||||||
second_item_notification,
|
[
|
||||||
_completed_notification(),
|
first_item_notification,
|
||||||
]
|
second_item_notification,
|
||||||
|
_completed_notification(),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
|
||||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||||
turn=SimpleNamespace(id="turn-1")
|
turn=SimpleNamespace(id="turn-1")
|
||||||
)
|
)
|
||||||
@@ -387,14 +471,17 @@ def test_thread_run_prefers_explicit_final_answer_over_later_commentary() -> Non
|
|||||||
text="Commentary",
|
text="Commentary",
|
||||||
phase=MessagePhase.commentary,
|
phase=MessagePhase.commentary,
|
||||||
)
|
)
|
||||||
notifications: deque[Notification] = deque(
|
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||||
[
|
{
|
||||||
final_answer_notification,
|
("thread-1", "turn-1"): deque(
|
||||||
commentary_notification,
|
[
|
||||||
_completed_notification(),
|
final_answer_notification,
|
||||||
]
|
commentary_notification,
|
||||||
|
_completed_notification(),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
|
||||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||||
turn=SimpleNamespace(id="turn-1")
|
turn=SimpleNamespace(id="turn-1")
|
||||||
)
|
)
|
||||||
@@ -414,13 +501,16 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
|
|||||||
text="Commentary",
|
text="Commentary",
|
||||||
phase=MessagePhase.commentary,
|
phase=MessagePhase.commentary,
|
||||||
)
|
)
|
||||||
notifications: deque[Notification] = deque(
|
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||||
[
|
{
|
||||||
commentary_notification,
|
("thread-1", "turn-1"): deque(
|
||||||
_completed_notification(),
|
[
|
||||||
]
|
commentary_notification,
|
||||||
|
_completed_notification(),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
|
||||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||||
turn=SimpleNamespace(id="turn-1")
|
turn=SimpleNamespace(id="turn-1")
|
||||||
)
|
)
|
||||||
@@ -433,12 +523,13 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
|
|||||||
|
|
||||||
def test_thread_run_raises_on_failed_turn() -> None:
|
def test_thread_run_raises_on_failed_turn() -> None:
|
||||||
client = AppServerClient()
|
client = AppServerClient()
|
||||||
notifications: deque[Notification] = deque(
|
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||||
[
|
{
|
||||||
_completed_notification(status="failed", error_message="boom"),
|
("thread-1", "turn-1"): deque(
|
||||||
]
|
[_completed_notification(status="failed", error_message="boom")]
|
||||||
|
),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
|
||||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||||
turn=SimpleNamespace(id="turn-1")
|
turn=SimpleNamespace(id="turn-1")
|
||||||
)
|
)
|
||||||
@@ -471,12 +562,13 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
|
|||||||
seen["params"] = params
|
seen["params"] = params
|
||||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||||
|
|
||||||
async def fake_next_notification() -> Notification:
|
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||||
|
assert (thread_id, turn_id) == ("thread-1", "turn-1")
|
||||||
return notifications.popleft()
|
return notifications.popleft()
|
||||||
|
|
||||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||||
|
|
||||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||||
|
|
||||||
@@ -511,12 +603,13 @@ def test_async_thread_run_uses_last_completed_assistant_message_as_final_respons
|
|||||||
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
|
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
|
||||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||||
|
|
||||||
async def fake_next_notification() -> Notification:
|
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||||
|
assert (thread_id, turn_id) == ("thread-1", "turn-1")
|
||||||
return notifications.popleft()
|
return notifications.popleft()
|
||||||
|
|
||||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||||
|
|
||||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||||
|
|
||||||
@@ -550,12 +643,13 @@ def test_async_thread_run_returns_none_when_only_commentary_messages_complete()
|
|||||||
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
|
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
|
||||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||||
|
|
||||||
async def fake_next_notification() -> Notification:
|
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||||
|
assert (thread_id, turn_id) == ("thread-1", "turn-1")
|
||||||
return notifications.popleft()
|
return notifications.popleft()
|
||||||
|
|
||||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||||
|
|
||||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user