Route Python SDK turn notifications by ID (#21778)

## Why

The Python SDK previously protected the stdio transport with a single
active turn-consumer guard. That avoided competing reads from stdout,
but it also meant one `Codex`/`AsyncCodex` client could not stream
multiple active turns at the same time. Notifications could also arrive
before the caller received a `TurnHandle` and registered for streaming,
so the SDK needed an explicit routing layer instead of letting
individual API calls read directly from the shared transport.

## What Changed

- Added a private `MessageRouter` that owns per-request response queues,
per-turn notification queues, pending turn-notification replay, and
global notification delivery behind a single stdout reader thread.
- Generated typed notification routing metadata so turn IDs come from
known payload shapes instead of router-side attribute guessing, with
explicit fallback handling for unknown notification payloads.
- Updated sync and async turn streaming so `TurnHandle.stream()`/`run()`
and `stream_text()` consume only notifications for their own turn ID,
while `AsyncAppServerClient` no longer serializes all transport calls
behind one async lock.
- Cleared pending turn-notification buffers when unregistered turns
complete so never-consumed turn handles do not leave stale queues
behind.
- Removed the internal stream-until helper now that turn completion
waiting can register directly with routed turn notifications.
- Updated Python SDK docs and focused tests for concurrent transport
calls, interleaved turn routing, buffered early notifications, unknown
notification routing, async delegation, and routed turn completion
behavior.

## Validation

- `uv run --extra dev ruff format scripts/update_sdk_artifacts.py
src/codex_app_server/_message_router.py src/codex_app_server/client.py
src/codex_app_server/generated/notification_registry.py
tests/test_client_rpc_methods.py
tests/test_public_api_runtime_behavior.py
tests/test_async_client_behavior.py`
- `uv run --extra dev ruff check scripts/update_sdk_artifacts.py
src/codex_app_server/_message_router.py src/codex_app_server/client.py
src/codex_app_server/generated/notification_registry.py
tests/test_client_rpc_methods.py
tests/test_public_api_runtime_behavior.py
tests/test_async_client_behavior.py`
- `uv run --extra dev pytest tests/test_client_rpc_methods.py
tests/test_public_api_runtime_behavior.py
tests/test_async_client_behavior.py`
- `git diff --check`

---------

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-05-09 07:16:23 +03:00
committed by GitHub
parent 77d9223e9f
commit ebe75bb683
11 changed files with 916 additions and 197 deletions

View File

@@ -0,0 +1,158 @@
from __future__ import annotations
import queue
import threading
from collections import deque
from .errors import AppServerError, map_jsonrpc_error
from .generated.notification_registry import notification_turn_id
from .models import JsonValue, Notification, UnknownNotification
ResponseQueueItem = JsonValue | BaseException
NotificationQueueItem = Notification | BaseException
class MessageRouter:
"""Route reader-thread messages to the SDK operation waiting for them.
The app-server stdio transport is a single ordered stream, so only the
reader thread should consume stdout. This router keeps the rest of the SDK
from competing for that stream by giving each in-flight JSON-RPC request
and active turn stream its own queue.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {}
self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
self._pending_turn_notifications: dict[str, deque[Notification]] = {}
self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue()
def create_response_waiter(self, request_id: str) -> queue.Queue[ResponseQueueItem]:
"""Register a one-shot queue for a JSON-RPC response id."""
waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1)
with self._lock:
self._response_waiters[request_id] = waiter
return waiter
def discard_response_waiter(self, request_id: str) -> None:
"""Remove a response waiter when the request could not be written."""
with self._lock:
self._response_waiters.pop(request_id, None)
def next_global_notification(self) -> Notification:
"""Block until the next notification that is not scoped to a turn."""
item = self._global_notifications.get()
if isinstance(item, BaseException):
raise item
return item
def register_turn(self, turn_id: str) -> None:
"""Register a queue for a turn stream and replay early events."""
turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue()
with self._lock:
if turn_id in self._turn_notifications:
return
# A turn can emit events immediately after turn/start, before the
# caller receives the TurnHandle and starts streaming.
pending = self._pending_turn_notifications.pop(turn_id, deque())
self._turn_notifications[turn_id] = turn_queue
for notification in pending:
turn_queue.put(notification)
def unregister_turn(self, turn_id: str) -> None:
"""Stop routing future turn events to the stream queue."""
with self._lock:
self._turn_notifications.pop(turn_id, None)
def next_turn_notification(self, turn_id: str) -> Notification:
"""Block until the next notification for a registered turn."""
with self._lock:
turn_queue = self._turn_notifications.get(turn_id)
if turn_queue is None:
raise RuntimeError(f"turn {turn_id!r} is not registered for streaming")
item = turn_queue.get()
if isinstance(item, BaseException):
raise item
return item
def route_response(self, msg: dict[str, JsonValue]) -> None:
"""Deliver a JSON-RPC response or error to its request waiter."""
request_id = msg.get("id")
with self._lock:
waiter = self._response_waiters.pop(str(request_id), None)
if waiter is None:
return
if "error" in msg:
err = msg["error"]
if isinstance(err, dict):
waiter.put(
map_jsonrpc_error(
int(err.get("code", -32000)),
str(err.get("message", "unknown")),
err.get("data"),
)
)
else:
waiter.put(AppServerError("Malformed JSON-RPC error response"))
return
waiter.put(msg.get("result"))
def route_notification(self, notification: Notification) -> None:
"""Deliver a notification to a turn queue or the global queue."""
turn_id = self._notification_turn_id(notification)
if turn_id is None:
self._global_notifications.put(notification)
return
with self._lock:
turn_queue = self._turn_notifications.get(turn_id)
if turn_queue is None:
if notification.method == "turn/completed":
self._pending_turn_notifications.pop(turn_id, None)
return
self._pending_turn_notifications.setdefault(turn_id, deque()).append(
notification
)
return
turn_queue.put(notification)
def fail_all(self, exc: BaseException) -> None:
"""Wake every blocked waiter when the reader thread exits."""
with self._lock:
response_waiters = list(self._response_waiters.values())
self._response_waiters.clear()
turn_queues = list(self._turn_notifications.values())
self._pending_turn_notifications.clear()
# Put the same transport failure into every queue so no SDK call blocks
# forever waiting for a response that cannot arrive.
for waiter in response_waiters:
waiter.put(exc)
for turn_queue in turn_queues:
turn_queue.put(exc)
self._global_notifications.put(exc)
def _notification_turn_id(self, notification: Notification) -> str | None:
payload = notification.payload
if isinstance(payload, UnknownNotification):
raw_turn_id = payload.params.get("turnId")
if isinstance(raw_turn_id, str):
return raw_turn_id
raw_turn = payload.params.get("turn")
if isinstance(raw_turn, dict):
raw_nested_turn_id = raw_turn.get("id")
if isinstance(raw_nested_turn_id, str):
return raw_nested_turn_id
return None
return notification_turn_id(payload)

View File

@@ -38,14 +38,14 @@ from .generated.v2_all import (
)
from .models import InitializeResponse, JsonObject, Notification, ServerInfo
from ._inputs import (
ImageInput,
ImageInput as ImageInput,
Input,
InputItem,
LocalImageInput,
MentionInput,
InputItem as InputItem,
LocalImageInput as LocalImageInput,
MentionInput as MentionInput,
RunInput,
SkillInput,
TextInput,
SkillInput as SkillInput,
TextInput as TextInput,
_normalize_run_input,
_to_wire_input,
)
@@ -274,6 +274,7 @@ class Codex:
def thread_unarchive(self, thread_id: str) -> Thread:
unarchived = self._client.thread_unarchive(thread_id)
return Thread(self._client, unarchived.thread.id)
# END GENERATED: Codex.flat_methods
def models(self, *, include_hidden: bool = False) -> ModelListResponse:
@@ -476,6 +477,7 @@ class AsyncCodex:
await self._ensure_initialized()
unarchived = await self._client.thread_unarchive(thread_id)
return AsyncThread(self, unarchived.thread.id)
# END GENERATED: AsyncCodex.flat_methods
async def models(self, *, include_hidden: bool = False) -> ModelListResponse:
@@ -555,6 +557,7 @@ class Thread:
)
turn = self._client.turn_start(self.id, wire_input, params=params)
return TurnHandle(self._client, self.id, turn.turn.id)
# END GENERATED: Thread.flat_methods
def read(self, *, include_turns: bool = False) -> ThreadReadResponse:
@@ -644,6 +647,7 @@ class AsyncThread:
params=params,
)
return AsyncTurnHandle(self._codex, self.id, turn.turn.id)
# END GENERATED: AsyncThread.flat_methods
async def read(self, *, include_turns: bool = False) -> ThreadReadResponse:
@@ -674,11 +678,10 @@ class TurnHandle:
return self._client.turn_interrupt(self.thread_id, self.id)
def stream(self) -> Iterator[Notification]:
# TODO: replace this client-wide experimental guard with per-turn event demux.
self._client.acquire_turn_consumer(self.id)
self._client.register_turn_notifications(self.id)
try:
while True:
event = self._client.next_notification()
event = self._client.next_turn_notification(self.id)
yield event
if (
event.method == "turn/completed"
@@ -687,7 +690,7 @@ class TurnHandle:
):
break
finally:
self._client.release_turn_consumer(self.id)
self._client.unregister_turn_notifications(self.id)
def run(self) -> AppServerTurn:
completed: TurnCompletedNotification | None = None
@@ -728,11 +731,10 @@ class AsyncTurnHandle:
async def stream(self) -> AsyncIterator[Notification]:
await self._codex._ensure_initialized()
# TODO: replace this client-wide experimental guard with per-turn event demux.
self._codex._client.acquire_turn_consumer(self.id)
self._codex._client.register_turn_notifications(self.id)
try:
while True:
event = await self._codex._client.next_notification()
event = await self._codex._client.next_turn_notification(self.id)
yield event
if (
event.method == "turn/completed"
@@ -741,7 +743,7 @@ class AsyncTurnHandle:
):
break
finally:
self._codex._client.release_turn_consumer(self.id)
self._codex._client.unregister_turn_notifications(self.id)
async def run(self) -> AppServerTurn:
completed: TurnCompletedNotification | None = None

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Iterator
from typing import AsyncIterator, Callable, Iterable, ParamSpec, TypeVar
from typing import AsyncIterator, Callable, ParamSpec, TypeVar
from pydantic import BaseModel
@@ -41,8 +41,6 @@ class AsyncAppServerClient:
def __init__(self, config: AppServerConfig | None = None) -> None:
self._sync = AppServerClient(config=config)
# Single stdio transport cannot be read safely from multiple threads.
self._transport_lock = asyncio.Lock()
async def __aenter__(self) -> "AsyncAppServerClient":
await self.start()
@@ -58,8 +56,7 @@ class AsyncAppServerClient:
*args: ParamsT.args,
**kwargs: ParamsT.kwargs,
) -> ReturnT:
async with self._transport_lock:
return await asyncio.to_thread(fn, *args, **kwargs)
return await asyncio.to_thread(fn, *args, **kwargs)
@staticmethod
def _next_from_iterator(
@@ -79,11 +76,11 @@ class AsyncAppServerClient:
async def initialize(self) -> InitializeResponse:
return await self._call_sync(self._sync.initialize)
def acquire_turn_consumer(self, turn_id: str) -> None:
self._sync.acquire_turn_consumer(turn_id)
def register_turn_notifications(self, turn_id: str) -> None:
self._sync.register_turn_notifications(turn_id)
def release_turn_consumer(self, turn_id: str) -> None:
self._sync.release_turn_consumer(turn_id)
def unregister_turn_notifications(self, turn_id: str) -> None:
self._sync.unregister_turn_notifications(turn_id)
async def request(
self,
@@ -99,7 +96,9 @@ class AsyncAppServerClient:
response_model=response_model,
)
async def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
async def thread_start(
self, params: V2ThreadStartParams | JsonObject | None = None
) -> ThreadStartResponse:
return await self._call_sync(self._sync.thread_start, params)
async def thread_resume(
@@ -109,10 +108,14 @@ class AsyncAppServerClient:
) -> ThreadResumeResponse:
return await self._call_sync(self._sync.thread_resume, thread_id, params)
async def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse:
async def thread_list(
self, params: V2ThreadListParams | JsonObject | None = None
) -> ThreadListResponse:
return await self._call_sync(self._sync.thread_list, params)
async def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
async def thread_read(
self, thread_id: str, include_turns: bool = False
) -> ThreadReadResponse:
return await self._call_sync(self._sync.thread_read, thread_id, include_turns)
async def thread_fork(
@@ -140,9 +143,13 @@ class AsyncAppServerClient:
input_items: list[JsonObject] | JsonObject | str,
params: V2TurnStartParams | JsonObject | None = None,
) -> TurnStartResponse:
return await self._call_sync(self._sync.turn_start, thread_id, input_items, params)
return await self._call_sync(
self._sync.turn_start, thread_id, input_items, params
)
async def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
async def turn_interrupt(
self, thread_id: str, turn_id: str
) -> TurnInterruptResponse:
return await self._call_sync(self._sync.turn_interrupt, thread_id, turn_id)
async def turn_steer(
@@ -184,25 +191,24 @@ class AsyncAppServerClient:
async def next_notification(self) -> Notification:
return await self._call_sync(self._sync.next_notification)
async def next_turn_notification(self, turn_id: str) -> Notification:
return await self._call_sync(self._sync.next_turn_notification, turn_id)
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
async def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]:
return await self._call_sync(self._sync.stream_until_methods, methods)
async def stream_text(
self,
thread_id: str,
text: str,
params: V2TurnStartParams | JsonObject | None = None,
) -> AsyncIterator[AgentMessageDeltaNotification]:
async with self._transport_lock:
iterator = self._sync.stream_text(thread_id, text, params)
while True:
has_value, chunk = await asyncio.to_thread(
self._next_from_iterator,
iterator,
)
if not has_value:
break
yield chunk
iterator = self._sync.stream_text(thread_id, text, params)
while True:
has_value, chunk = await asyncio.to_thread(
self._next_from_iterator,
iterator,
)
if not has_value:
break
yield chunk

View File

@@ -8,11 +8,11 @@ import uuid
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, Iterator, TypeVar
from typing import Callable, Iterator, TypeVar
from pydantic import BaseModel
from .errors import AppServerError, TransportClosedError, map_jsonrpc_error
from .errors import AppServerError, TransportClosedError
from .generated.notification_registry import NOTIFICATION_MODELS
from .generated.v2_all import (
AgentMessageDeltaNotification,
@@ -43,6 +43,7 @@ from .models import (
Notification,
UnknownNotification,
)
from ._message_router import MessageRouter
from .retry import retry_on_overload
from ._version import __version__ as SDK_VERSION
@@ -75,7 +76,9 @@ def _params_dict(
return dumped
if isinstance(params, dict):
return params
raise TypeError(f"Expected generated params model or dict, got {type(params).__name__}")
raise TypeError(
f"Expected generated params model or dict, got {type(params).__name__}"
)
def _installed_codex_path() -> Path:
@@ -146,11 +149,10 @@ class AppServerClient:
self._approval_handler = approval_handler or self._default_approval_handler
self._proc: subprocess.Popen[str] | None = None
self._lock = threading.Lock()
self._turn_consumer_lock = threading.Lock()
self._active_turn_consumer: str | None = None
self._pending_notifications: deque[Notification] = deque()
self._router = MessageRouter()
self._stderr_lines: deque[str] = deque(maxlen=400)
self._stderr_thread: threading.Thread | None = None
self._reader_thread: threading.Thread | None = None
def __enter__(self) -> "AppServerClient":
self.start()
@@ -189,13 +191,13 @@ class AppServerClient:
)
self._start_stderr_drain_thread()
self._start_reader_thread()
def close(self) -> None:
if self._proc is None:
return
proc = self._proc
self._proc = None
self._active_turn_consumer = None
if proc.stdin:
proc.stdin.close()
@@ -207,6 +209,8 @@ class AppServerClient:
if self._stderr_thread and self._stderr_thread.is_alive():
self._stderr_thread.join(timeout=0.5)
if self._reader_thread and self._reader_thread.is_alive():
self._reader_thread.join(timeout=0.5)
def initialize(self) -> InitializeResponse:
result = self.request(
@@ -240,70 +244,42 @@ class AppServerClient:
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
request_id = str(uuid.uuid4())
self._write_message({"id": request_id, "method": method, "params": params or {}})
waiter = self._router.create_response_waiter(request_id)
while True:
msg = self._read_message()
try:
self._write_message(
{"id": request_id, "method": method, "params": params or {}}
)
except BaseException:
self._router.discard_response_waiter(request_id)
raise
if "method" in msg and "id" in msg:
response = self._handle_server_request(msg)
self._write_message({"id": msg["id"], "result": response})
continue
if "method" in msg and "id" not in msg:
self._pending_notifications.append(
self._coerce_notification(msg["method"], msg.get("params"))
)
continue
if msg.get("id") != request_id:
continue
if "error" in msg:
err = msg["error"]
if isinstance(err, dict):
raise map_jsonrpc_error(
int(err.get("code", -32000)),
str(err.get("message", "unknown")),
err.get("data"),
)
raise AppServerError("Malformed JSON-RPC error response")
return msg.get("result")
item = waiter.get()
if isinstance(item, BaseException):
raise item
return item
def notify(self, method: str, params: JsonObject | None = None) -> None:
self._write_message({"method": method, "params": params or {}})
def next_notification(self) -> Notification:
if self._pending_notifications:
return self._pending_notifications.popleft()
return self._router.next_global_notification()
while True:
msg = self._read_message()
if "method" in msg and "id" in msg:
response = self._handle_server_request(msg)
self._write_message({"id": msg["id"], "result": response})
continue
if "method" in msg and "id" not in msg:
return self._coerce_notification(msg["method"], msg.get("params"))
def register_turn_notifications(self, turn_id: str) -> None:
self._router.register_turn(turn_id)
def acquire_turn_consumer(self, turn_id: str) -> None:
with self._turn_consumer_lock:
if self._active_turn_consumer is not None:
raise RuntimeError(
"Concurrent turn consumers are not yet supported in the experimental SDK. "
f"Client is already streaming turn {self._active_turn_consumer!r}; "
f"cannot start turn {turn_id!r} until the active consumer finishes."
)
self._active_turn_consumer = turn_id
def unregister_turn_notifications(self, turn_id: str) -> None:
self._router.unregister_turn(turn_id)
def release_turn_consumer(self, turn_id: str) -> None:
with self._turn_consumer_lock:
if self._active_turn_consumer == turn_id:
self._active_turn_consumer = None
def next_turn_notification(self, turn_id: str) -> Notification:
return self._router.next_turn_notification(turn_id)
def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse)
def thread_start(
self, params: V2ThreadStartParams | JsonObject | None = None
) -> ThreadStartResponse:
return self.request(
"thread/start", _params_dict(params), response_model=ThreadStartResponse
)
def thread_resume(
self,
@@ -311,12 +287,20 @@ class AppServerClient:
params: V2ThreadResumeParams | JsonObject | None = None,
) -> ThreadResumeResponse:
payload = {"threadId": thread_id, **_params_dict(params)}
return self.request("thread/resume", payload, response_model=ThreadResumeResponse)
return self.request(
"thread/resume", payload, response_model=ThreadResumeResponse
)
def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse:
return self.request("thread/list", _params_dict(params), response_model=ThreadListResponse)
def thread_list(
self, params: V2ThreadListParams | JsonObject | None = None
) -> ThreadListResponse:
return self.request(
"thread/list", _params_dict(params), response_model=ThreadListResponse
)
def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
def thread_read(
self, thread_id: str, include_turns: bool = False
) -> ThreadReadResponse:
return self.request(
"thread/read",
{"threadId": thread_id, "includeTurns": include_turns},
@@ -332,10 +316,18 @@ class AppServerClient:
return self.request("thread/fork", payload, response_model=ThreadForkResponse)
def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:
return self.request("thread/archive", {"threadId": thread_id}, response_model=ThreadArchiveResponse)
return self.request(
"thread/archive",
{"threadId": thread_id},
response_model=ThreadArchiveResponse,
)
def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse:
return self.request("thread/unarchive", {"threadId": thread_id}, response_model=ThreadUnarchiveResponse)
return self.request(
"thread/unarchive",
{"threadId": thread_id},
response_model=ThreadUnarchiveResponse,
)
def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse:
return self.request(
@@ -362,7 +354,9 @@ class AppServerClient:
"threadId": thread_id,
"input": self._normalize_input_items(input_items),
}
return self.request("turn/start", payload, response_model=TurnStartResponse)
started = self.request("turn/start", payload, response_model=TurnStartResponse)
self.register_turn_notifications(started.turn.id)
return started
def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
return self.request(
@@ -412,23 +406,18 @@ class AppServerClient:
)
def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
while True:
notification = self.next_notification()
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
return notification.payload
def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]:
target_methods = {methods} if isinstance(methods, str) else set(methods)
out: list[Notification] = []
while True:
notification = self.next_notification()
out.append(notification)
if notification.method in target_methods:
return out
self.register_turn_notifications(turn_id)
try:
while True:
notification = self.next_turn_notification(turn_id)
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
return notification.payload
finally:
self.unregister_turn_notifications(turn_id)
def stream_text(
self,
@@ -438,33 +427,41 @@ class AppServerClient:
) -> Iterator[AgentMessageDeltaNotification]:
started = self.turn_start(thread_id, text, params=params)
turn_id = started.turn.id
while True:
notification = self.next_notification()
if (
notification.method == "item/agentMessage/delta"
and isinstance(notification.payload, AgentMessageDeltaNotification)
and notification.payload.turn_id == turn_id
):
yield notification.payload
continue
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
break
self.register_turn_notifications(turn_id)
try:
while True:
notification = self.next_turn_notification(turn_id)
if (
notification.method == "item/agentMessage/delta"
and isinstance(notification.payload, AgentMessageDeltaNotification)
and notification.payload.turn_id == turn_id
):
yield notification.payload
continue
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
break
finally:
self.unregister_turn_notifications(turn_id)
def _coerce_notification(self, method: str, params: object) -> Notification:
params_dict = params if isinstance(params, dict) else {}
model = NOTIFICATION_MODELS.get(method)
if model is None:
return Notification(method=method, payload=UnknownNotification(params=params_dict))
return Notification(
method=method, payload=UnknownNotification(params=params_dict)
)
try:
payload = model.model_validate(params_dict)
except Exception: # noqa: BLE001
return Notification(method=method, payload=UnknownNotification(params=params_dict))
return Notification(
method=method, payload=UnknownNotification(params=params_dict)
)
return Notification(method=method, payload=payload)
def _normalize_input_items(
@@ -477,7 +474,9 @@ class AppServerClient:
return [input_items]
return input_items
def _default_approval_handler(self, method: str, params: JsonObject | None) -> JsonObject:
def _default_approval_handler(
self, method: str, params: JsonObject | None
) -> JsonObject:
if method == "item/commandExecution/requestApproval":
return {"decision": "accept"}
if method == "item/fileChange/requestApproval":
@@ -498,6 +497,32 @@ class AppServerClient:
self._stderr_thread = threading.Thread(target=_drain, daemon=True)
self._stderr_thread.start()
def _start_reader_thread(self) -> None:
if self._proc is None or self._proc.stdout is None:
return
self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True)
self._reader_thread.start()
def _reader_loop(self) -> None:
try:
while True:
msg = self._read_message()
if "method" in msg and "id" in msg:
response = self._handle_server_request(msg)
self._write_message({"id": msg["id"], "result": response})
continue
if "method" in msg and "id" not in msg:
method = msg["method"]
if isinstance(method, str):
self._router.route_notification(
self._coerce_notification(method, msg.get("params"))
)
continue
self._router.route_response(msg)
except BaseException as exc:
self._router.fail_all(exc)
def _stderr_tail(self, limit: int = 40) -> str:
return "\n".join(list(self._stderr_lines)[-limit:])

View File

@@ -130,3 +130,43 @@ NOTIFICATION_MODELS: dict[str, type[BaseModel]] = {
"windows/worldWritableWarning": WindowsWorldWritableWarningNotification,
"windowsSandbox/setupCompleted": WindowsSandboxSetupCompletedNotification,
}
DIRECT_TURN_ID_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = (
AgentMessageDeltaNotification,
CommandExecutionOutputDeltaNotification,
ContextCompactedNotification,
ErrorNotification,
FileChangeOutputDeltaNotification,
FileChangePatchUpdatedNotification,
HookCompletedNotification,
HookStartedNotification,
ItemCompletedNotification,
ItemGuardianApprovalReviewCompletedNotification,
ItemGuardianApprovalReviewStartedNotification,
ItemStartedNotification,
McpToolCallProgressNotification,
ModelReroutedNotification,
ModelVerificationNotification,
PlanDeltaNotification,
ReasoningSummaryPartAddedNotification,
ReasoningSummaryTextDeltaNotification,
ReasoningTextDeltaNotification,
TerminalInteractionNotification,
ThreadGoalUpdatedNotification,
ThreadTokenUsageUpdatedNotification,
TurnDiffUpdatedNotification,
TurnPlanUpdatedNotification,
)
NESTED_TURN_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = (
TurnCompletedNotification,
TurnStartedNotification,
)
def notification_turn_id(payload: BaseModel) -> str | None:
if isinstance(payload, DIRECT_TURN_ID_NOTIFICATION_TYPES):
return payload.turn_id if isinstance(payload.turn_id, str) else None
if isinstance(payload, NESTED_TURN_NOTIFICATION_TYPES):
return payload.turn.id
return None