mirror of
https://github.com/openai/codex.git
synced 2026-05-24 13:04:29 +00:00
Route Python SDK turn notifications by id
Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -38,14 +38,14 @@ from .generated.v2_all import (
|
||||
)
|
||||
from .models import InitializeResponse, JsonObject, Notification, ServerInfo
|
||||
from ._inputs import (
|
||||
ImageInput,
|
||||
ImageInput, # noqa: F401
|
||||
Input,
|
||||
InputItem,
|
||||
LocalImageInput,
|
||||
MentionInput,
|
||||
InputItem, # noqa: F401
|
||||
LocalImageInput, # noqa: F401
|
||||
MentionInput, # noqa: F401
|
||||
RunInput,
|
||||
SkillInput,
|
||||
TextInput,
|
||||
SkillInput, # noqa: F401
|
||||
TextInput, # noqa: F401
|
||||
_normalize_run_input,
|
||||
_to_wire_input,
|
||||
)
|
||||
@@ -274,6 +274,7 @@ class Codex:
|
||||
def thread_unarchive(self, thread_id: str) -> Thread:
|
||||
unarchived = self._client.thread_unarchive(thread_id)
|
||||
return Thread(self._client, unarchived.thread.id)
|
||||
|
||||
# END GENERATED: Codex.flat_methods
|
||||
|
||||
def models(self, *, include_hidden: bool = False) -> ModelListResponse:
|
||||
@@ -476,6 +477,7 @@ class AsyncCodex:
|
||||
await self._ensure_initialized()
|
||||
unarchived = await self._client.thread_unarchive(thread_id)
|
||||
return AsyncThread(self, unarchived.thread.id)
|
||||
|
||||
# END GENERATED: AsyncCodex.flat_methods
|
||||
|
||||
async def models(self, *, include_hidden: bool = False) -> ModelListResponse:
|
||||
@@ -555,6 +557,7 @@ class Thread:
|
||||
)
|
||||
turn = self._client.turn_start(self.id, wire_input, params=params)
|
||||
return TurnHandle(self._client, self.id, turn.turn.id)
|
||||
|
||||
# END GENERATED: Thread.flat_methods
|
||||
|
||||
def read(self, *, include_turns: bool = False) -> ThreadReadResponse:
|
||||
@@ -644,6 +647,7 @@ class AsyncThread:
|
||||
params=params,
|
||||
)
|
||||
return AsyncTurnHandle(self._codex, self.id, turn.turn.id)
|
||||
|
||||
# END GENERATED: AsyncThread.flat_methods
|
||||
|
||||
async def read(self, *, include_turns: bool = False) -> ThreadReadResponse:
|
||||
@@ -674,11 +678,10 @@ class TurnHandle:
|
||||
return self._client.turn_interrupt(self.thread_id, self.id)
|
||||
|
||||
def stream(self) -> Iterator[Notification]:
|
||||
# TODO: replace this client-wide experimental guard with per-turn event demux.
|
||||
self._client.acquire_turn_consumer(self.id)
|
||||
self._client.register_turn_notifications(self.id)
|
||||
try:
|
||||
while True:
|
||||
event = self._client.next_notification()
|
||||
event = self._client.next_turn_notification(self.id)
|
||||
yield event
|
||||
if (
|
||||
event.method == "turn/completed"
|
||||
@@ -687,7 +690,7 @@ class TurnHandle:
|
||||
):
|
||||
break
|
||||
finally:
|
||||
self._client.release_turn_consumer(self.id)
|
||||
self._client.unregister_turn_notifications(self.id)
|
||||
|
||||
def run(self) -> AppServerTurn:
|
||||
completed: TurnCompletedNotification | None = None
|
||||
@@ -728,11 +731,10 @@ class AsyncTurnHandle:
|
||||
|
||||
async def stream(self) -> AsyncIterator[Notification]:
|
||||
await self._codex._ensure_initialized()
|
||||
# TODO: replace this client-wide experimental guard with per-turn event demux.
|
||||
self._codex._client.acquire_turn_consumer(self.id)
|
||||
self._codex._client.register_turn_notifications(self.id)
|
||||
try:
|
||||
while True:
|
||||
event = await self._codex._client.next_notification()
|
||||
event = await self._codex._client.next_turn_notification(self.id)
|
||||
yield event
|
||||
if (
|
||||
event.method == "turn/completed"
|
||||
@@ -741,7 +743,7 @@ class AsyncTurnHandle:
|
||||
):
|
||||
break
|
||||
finally:
|
||||
self._codex._client.release_turn_consumer(self.id)
|
||||
self._codex._client.unregister_turn_notifications(self.id)
|
||||
|
||||
async def run(self) -> AppServerTurn:
|
||||
completed: TurnCompletedNotification | None = None
|
||||
|
||||
@@ -79,11 +79,11 @@ class AsyncAppServerClient:
|
||||
async def initialize(self) -> InitializeResponse:
|
||||
return await self._call_sync(self._sync.initialize)
|
||||
|
||||
def acquire_turn_consumer(self, turn_id: str) -> None:
|
||||
self._sync.acquire_turn_consumer(turn_id)
|
||||
def register_turn_notifications(self, turn_id: str) -> None:
|
||||
self._sync.register_turn_notifications(turn_id)
|
||||
|
||||
def release_turn_consumer(self, turn_id: str) -> None:
|
||||
self._sync.release_turn_consumer(turn_id)
|
||||
def unregister_turn_notifications(self, turn_id: str) -> None:
|
||||
self._sync.unregister_turn_notifications(turn_id)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
@@ -99,7 +99,9 @@ class AsyncAppServerClient:
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
async def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
|
||||
async def thread_start(
|
||||
self, params: V2ThreadStartParams | JsonObject | None = None
|
||||
) -> ThreadStartResponse:
|
||||
return await self._call_sync(self._sync.thread_start, params)
|
||||
|
||||
async def thread_resume(
|
||||
@@ -109,10 +111,14 @@ class AsyncAppServerClient:
|
||||
) -> ThreadResumeResponse:
|
||||
return await self._call_sync(self._sync.thread_resume, thread_id, params)
|
||||
|
||||
async def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse:
|
||||
async def thread_list(
|
||||
self, params: V2ThreadListParams | JsonObject | None = None
|
||||
) -> ThreadListResponse:
|
||||
return await self._call_sync(self._sync.thread_list, params)
|
||||
|
||||
async def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
|
||||
async def thread_read(
|
||||
self, thread_id: str, include_turns: bool = False
|
||||
) -> ThreadReadResponse:
|
||||
return await self._call_sync(self._sync.thread_read, thread_id, include_turns)
|
||||
|
||||
async def thread_fork(
|
||||
@@ -140,9 +146,13 @@ class AsyncAppServerClient:
|
||||
input_items: list[JsonObject] | JsonObject | str,
|
||||
params: V2TurnStartParams | JsonObject | None = None,
|
||||
) -> TurnStartResponse:
|
||||
return await self._call_sync(self._sync.turn_start, thread_id, input_items, params)
|
||||
return await self._call_sync(
|
||||
self._sync.turn_start, thread_id, input_items, params
|
||||
)
|
||||
|
||||
async def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
|
||||
async def turn_interrupt(
|
||||
self, thread_id: str, turn_id: str
|
||||
) -> TurnInterruptResponse:
|
||||
return await self._call_sync(self._sync.turn_interrupt, thread_id, turn_id)
|
||||
|
||||
async def turn_steer(
|
||||
@@ -184,10 +194,15 @@ class AsyncAppServerClient:
|
||||
async def next_notification(self) -> Notification:
|
||||
return await self._call_sync(self._sync.next_notification)
|
||||
|
||||
async def next_turn_notification(self, turn_id: str) -> Notification:
|
||||
return await self._call_sync(self._sync.next_turn_notification, turn_id)
|
||||
|
||||
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
||||
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
|
||||
|
||||
async def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]:
|
||||
async def stream_until_methods(
|
||||
self, methods: Iterable[str] | str
|
||||
) -> list[Notification]:
|
||||
return await self._call_sync(self._sync.stream_until_methods, methods)
|
||||
|
||||
async def stream_text(
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import subprocess
|
||||
import threading
|
||||
import uuid
|
||||
@@ -49,6 +50,8 @@ from ._version import __version__ as SDK_VERSION
|
||||
ModelT = TypeVar("ModelT", bound=BaseModel)
|
||||
ApprovalHandler = Callable[[str, JsonObject | None], JsonObject]
|
||||
RUNTIME_PKG_NAME = "openai-codex-cli-bin"
|
||||
ResponseQueueItem = JsonValue | BaseException
|
||||
NotificationQueueItem = Notification | BaseException
|
||||
|
||||
|
||||
def _params_dict(
|
||||
@@ -75,7 +78,9 @@ def _params_dict(
|
||||
return dumped
|
||||
if isinstance(params, dict):
|
||||
return params
|
||||
raise TypeError(f"Expected generated params model or dict, got {type(params).__name__}")
|
||||
raise TypeError(
|
||||
f"Expected generated params model or dict, got {type(params).__name__}"
|
||||
)
|
||||
|
||||
|
||||
def _installed_codex_path() -> Path:
|
||||
@@ -146,11 +151,14 @@ class AppServerClient:
|
||||
self._approval_handler = approval_handler or self._default_approval_handler
|
||||
self._proc: subprocess.Popen[str] | None = None
|
||||
self._lock = threading.Lock()
|
||||
self._turn_consumer_lock = threading.Lock()
|
||||
self._active_turn_consumer: str | None = None
|
||||
self._pending_notifications: deque[Notification] = deque()
|
||||
self._router_lock = threading.Lock()
|
||||
self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {}
|
||||
self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
|
||||
self._pending_turn_notifications: dict[str, deque[Notification]] = {}
|
||||
self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue()
|
||||
self._stderr_lines: deque[str] = deque(maxlen=400)
|
||||
self._stderr_thread: threading.Thread | None = None
|
||||
self._reader_thread: threading.Thread | None = None
|
||||
|
||||
def __enter__(self) -> "AppServerClient":
|
||||
self.start()
|
||||
@@ -189,13 +197,13 @@ class AppServerClient:
|
||||
)
|
||||
|
||||
self._start_stderr_drain_thread()
|
||||
self._start_reader_thread()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._proc is None:
|
||||
return
|
||||
proc = self._proc
|
||||
self._proc = None
|
||||
self._active_turn_consumer = None
|
||||
|
||||
if proc.stdin:
|
||||
proc.stdin.close()
|
||||
@@ -207,6 +215,8 @@ class AppServerClient:
|
||||
|
||||
if self._stderr_thread and self._stderr_thread.is_alive():
|
||||
self._stderr_thread.join(timeout=0.5)
|
||||
if self._reader_thread and self._reader_thread.is_alive():
|
||||
self._reader_thread.join(timeout=0.5)
|
||||
|
||||
def initialize(self) -> InitializeResponse:
|
||||
result = self.request(
|
||||
@@ -240,70 +250,63 @@ class AppServerClient:
|
||||
|
||||
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
|
||||
request_id = str(uuid.uuid4())
|
||||
self._write_message({"id": request_id, "method": method, "params": params or {}})
|
||||
waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1)
|
||||
with self._router_lock:
|
||||
self._response_waiters[request_id] = waiter
|
||||
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
try:
|
||||
self._write_message(
|
||||
{"id": request_id, "method": method, "params": params or {}}
|
||||
)
|
||||
except BaseException:
|
||||
with self._router_lock:
|
||||
self._response_waiters.pop(request_id, None)
|
||||
raise
|
||||
|
||||
if "method" in msg and "id" in msg:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
continue
|
||||
|
||||
if "method" in msg and "id" not in msg:
|
||||
self._pending_notifications.append(
|
||||
self._coerce_notification(msg["method"], msg.get("params"))
|
||||
)
|
||||
continue
|
||||
|
||||
if msg.get("id") != request_id:
|
||||
continue
|
||||
|
||||
if "error" in msg:
|
||||
err = msg["error"]
|
||||
if isinstance(err, dict):
|
||||
raise map_jsonrpc_error(
|
||||
int(err.get("code", -32000)),
|
||||
str(err.get("message", "unknown")),
|
||||
err.get("data"),
|
||||
)
|
||||
raise AppServerError("Malformed JSON-RPC error response")
|
||||
|
||||
return msg.get("result")
|
||||
item = waiter.get()
|
||||
if isinstance(item, BaseException):
|
||||
raise item
|
||||
return item
|
||||
|
||||
def notify(self, method: str, params: JsonObject | None = None) -> None:
|
||||
self._write_message({"method": method, "params": params or {}})
|
||||
|
||||
def next_notification(self) -> Notification:
|
||||
if self._pending_notifications:
|
||||
return self._pending_notifications.popleft()
|
||||
item = self._global_notifications.get()
|
||||
if isinstance(item, BaseException):
|
||||
raise item
|
||||
return item
|
||||
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
if "method" in msg and "id" in msg:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
continue
|
||||
if "method" in msg and "id" not in msg:
|
||||
return self._coerce_notification(msg["method"], msg.get("params"))
|
||||
def register_turn_notifications(self, turn_id: str) -> None:
|
||||
turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue()
|
||||
with self._router_lock:
|
||||
if turn_id in self._turn_notifications:
|
||||
return
|
||||
pending = self._pending_turn_notifications.pop(turn_id, deque())
|
||||
self._turn_notifications[turn_id] = turn_queue
|
||||
for notification in pending:
|
||||
turn_queue.put(notification)
|
||||
|
||||
def acquire_turn_consumer(self, turn_id: str) -> None:
|
||||
with self._turn_consumer_lock:
|
||||
if self._active_turn_consumer is not None:
|
||||
raise RuntimeError(
|
||||
"Concurrent turn consumers are not yet supported in the experimental SDK. "
|
||||
f"Client is already streaming turn {self._active_turn_consumer!r}; "
|
||||
f"cannot start turn {turn_id!r} until the active consumer finishes."
|
||||
)
|
||||
self._active_turn_consumer = turn_id
|
||||
def unregister_turn_notifications(self, turn_id: str) -> None:
|
||||
with self._router_lock:
|
||||
self._turn_notifications.pop(turn_id, None)
|
||||
|
||||
def release_turn_consumer(self, turn_id: str) -> None:
|
||||
with self._turn_consumer_lock:
|
||||
if self._active_turn_consumer == turn_id:
|
||||
self._active_turn_consumer = None
|
||||
def next_turn_notification(self, turn_id: str) -> Notification:
|
||||
with self._router_lock:
|
||||
turn_queue = self._turn_notifications.get(turn_id)
|
||||
if turn_queue is None:
|
||||
raise RuntimeError(f"turn {turn_id!r} is not registered for streaming")
|
||||
item = turn_queue.get()
|
||||
if isinstance(item, BaseException):
|
||||
raise item
|
||||
return item
|
||||
|
||||
def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
|
||||
return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse)
|
||||
def thread_start(
|
||||
self, params: V2ThreadStartParams | JsonObject | None = None
|
||||
) -> ThreadStartResponse:
|
||||
return self.request(
|
||||
"thread/start", _params_dict(params), response_model=ThreadStartResponse
|
||||
)
|
||||
|
||||
def thread_resume(
|
||||
self,
|
||||
@@ -311,12 +314,20 @@ class AppServerClient:
|
||||
params: V2ThreadResumeParams | JsonObject | None = None,
|
||||
) -> ThreadResumeResponse:
|
||||
payload = {"threadId": thread_id, **_params_dict(params)}
|
||||
return self.request("thread/resume", payload, response_model=ThreadResumeResponse)
|
||||
return self.request(
|
||||
"thread/resume", payload, response_model=ThreadResumeResponse
|
||||
)
|
||||
|
||||
def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse:
|
||||
return self.request("thread/list", _params_dict(params), response_model=ThreadListResponse)
|
||||
def thread_list(
|
||||
self, params: V2ThreadListParams | JsonObject | None = None
|
||||
) -> ThreadListResponse:
|
||||
return self.request(
|
||||
"thread/list", _params_dict(params), response_model=ThreadListResponse
|
||||
)
|
||||
|
||||
def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
|
||||
def thread_read(
|
||||
self, thread_id: str, include_turns: bool = False
|
||||
) -> ThreadReadResponse:
|
||||
return self.request(
|
||||
"thread/read",
|
||||
{"threadId": thread_id, "includeTurns": include_turns},
|
||||
@@ -332,10 +343,18 @@ class AppServerClient:
|
||||
return self.request("thread/fork", payload, response_model=ThreadForkResponse)
|
||||
|
||||
def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:
|
||||
return self.request("thread/archive", {"threadId": thread_id}, response_model=ThreadArchiveResponse)
|
||||
return self.request(
|
||||
"thread/archive",
|
||||
{"threadId": thread_id},
|
||||
response_model=ThreadArchiveResponse,
|
||||
)
|
||||
|
||||
def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse:
|
||||
return self.request("thread/unarchive", {"threadId": thread_id}, response_model=ThreadUnarchiveResponse)
|
||||
return self.request(
|
||||
"thread/unarchive",
|
||||
{"threadId": thread_id},
|
||||
response_model=ThreadUnarchiveResponse,
|
||||
)
|
||||
|
||||
def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse:
|
||||
return self.request(
|
||||
@@ -362,7 +381,9 @@ class AppServerClient:
|
||||
"threadId": thread_id,
|
||||
"input": self._normalize_input_items(input_items),
|
||||
}
|
||||
return self.request("turn/start", payload, response_model=TurnStartResponse)
|
||||
started = self.request("turn/start", payload, response_model=TurnStartResponse)
|
||||
self.register_turn_notifications(started.turn.id)
|
||||
return started
|
||||
|
||||
def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
|
||||
return self.request(
|
||||
@@ -412,14 +433,18 @@ class AppServerClient:
|
||||
)
|
||||
|
||||
def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
||||
while True:
|
||||
notification = self.next_notification()
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
return notification.payload
|
||||
self.register_turn_notifications(turn_id)
|
||||
try:
|
||||
while True:
|
||||
notification = self.next_turn_notification(turn_id)
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
return notification.payload
|
||||
finally:
|
||||
self.unregister_turn_notifications(turn_id)
|
||||
|
||||
def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]:
|
||||
target_methods = {methods} if isinstance(methods, str) else set(methods)
|
||||
@@ -438,33 +463,40 @@ class AppServerClient:
|
||||
) -> Iterator[AgentMessageDeltaNotification]:
|
||||
started = self.turn_start(thread_id, text, params=params)
|
||||
turn_id = started.turn.id
|
||||
while True:
|
||||
notification = self.next_notification()
|
||||
if (
|
||||
notification.method == "item/agentMessage/delta"
|
||||
and isinstance(notification.payload, AgentMessageDeltaNotification)
|
||||
and notification.payload.turn_id == turn_id
|
||||
):
|
||||
yield notification.payload
|
||||
continue
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
break
|
||||
try:
|
||||
while True:
|
||||
notification = self.next_turn_notification(turn_id)
|
||||
if (
|
||||
notification.method == "item/agentMessage/delta"
|
||||
and isinstance(notification.payload, AgentMessageDeltaNotification)
|
||||
and notification.payload.turn_id == turn_id
|
||||
):
|
||||
yield notification.payload
|
||||
continue
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
break
|
||||
finally:
|
||||
self.unregister_turn_notifications(turn_id)
|
||||
|
||||
def _coerce_notification(self, method: str, params: object) -> Notification:
|
||||
params_dict = params if isinstance(params, dict) else {}
|
||||
|
||||
model = NOTIFICATION_MODELS.get(method)
|
||||
if model is None:
|
||||
return Notification(method=method, payload=UnknownNotification(params=params_dict))
|
||||
return Notification(
|
||||
method=method, payload=UnknownNotification(params=params_dict)
|
||||
)
|
||||
|
||||
try:
|
||||
payload = model.model_validate(params_dict)
|
||||
except Exception: # noqa: BLE001
|
||||
return Notification(method=method, payload=UnknownNotification(params=params_dict))
|
||||
return Notification(
|
||||
method=method, payload=UnknownNotification(params=params_dict)
|
||||
)
|
||||
return Notification(method=method, payload=payload)
|
||||
|
||||
def _normalize_input_items(
|
||||
@@ -477,7 +509,9 @@ class AppServerClient:
|
||||
return [input_items]
|
||||
return input_items
|
||||
|
||||
def _default_approval_handler(self, method: str, params: JsonObject | None) -> JsonObject:
|
||||
def _default_approval_handler(
|
||||
self, method: str, params: JsonObject | None
|
||||
) -> JsonObject:
|
||||
if method == "item/commandExecution/requestApproval":
|
||||
return {"decision": "accept"}
|
||||
if method == "item/fileChange/requestApproval":
|
||||
@@ -498,9 +532,95 @@ class AppServerClient:
|
||||
self._stderr_thread = threading.Thread(target=_drain, daemon=True)
|
||||
self._stderr_thread.start()
|
||||
|
||||
def _start_reader_thread(self) -> None:
|
||||
if self._proc is None or self._proc.stdout is None:
|
||||
return
|
||||
|
||||
self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True)
|
||||
self._reader_thread.start()
|
||||
|
||||
def _reader_loop(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
if "method" in msg and "id" in msg:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
continue
|
||||
if "method" in msg and "id" not in msg:
|
||||
method = msg["method"]
|
||||
if isinstance(method, str):
|
||||
self._route_notification(
|
||||
self._coerce_notification(method, msg.get("params"))
|
||||
)
|
||||
continue
|
||||
self._route_response(msg)
|
||||
except BaseException as exc:
|
||||
self._fail_pending_queues(exc)
|
||||
|
||||
def _stderr_tail(self, limit: int = 40) -> str:
|
||||
return "\n".join(list(self._stderr_lines)[-limit:])
|
||||
|
||||
def _route_response(self, msg: dict[str, JsonValue]) -> None:
|
||||
request_id = msg.get("id")
|
||||
with self._router_lock:
|
||||
waiter = self._response_waiters.pop(str(request_id), None)
|
||||
if waiter is None:
|
||||
return
|
||||
|
||||
if "error" in msg:
|
||||
err = msg["error"]
|
||||
if isinstance(err, dict):
|
||||
waiter.put(
|
||||
map_jsonrpc_error(
|
||||
int(err.get("code", -32000)),
|
||||
str(err.get("message", "unknown")),
|
||||
err.get("data"),
|
||||
)
|
||||
)
|
||||
else:
|
||||
waiter.put(AppServerError("Malformed JSON-RPC error response"))
|
||||
return
|
||||
|
||||
waiter.put(msg.get("result"))
|
||||
|
||||
def _route_notification(self, notification: Notification) -> None:
|
||||
turn_id = self._notification_turn_id(notification)
|
||||
if turn_id is None:
|
||||
self._global_notifications.put(notification)
|
||||
return
|
||||
|
||||
with self._router_lock:
|
||||
turn_queue = self._turn_notifications.get(turn_id)
|
||||
if turn_queue is None:
|
||||
self._pending_turn_notifications.setdefault(turn_id, deque()).append(
|
||||
notification
|
||||
)
|
||||
return
|
||||
turn_queue.put(notification)
|
||||
|
||||
def _notification_turn_id(self, notification: Notification) -> str | None:
|
||||
payload = notification.payload
|
||||
turn_id = getattr(payload, "turn_id", None)
|
||||
if isinstance(turn_id, str):
|
||||
return turn_id
|
||||
turn = getattr(payload, "turn", None)
|
||||
nested_turn_id = getattr(turn, "id", None)
|
||||
if isinstance(nested_turn_id, str):
|
||||
return nested_turn_id
|
||||
return None
|
||||
|
||||
def _fail_pending_queues(self, exc: BaseException) -> None:
|
||||
with self._router_lock:
|
||||
response_waiters = list(self._response_waiters.values())
|
||||
self._response_waiters.clear()
|
||||
turn_queues = list(self._turn_notifications.values())
|
||||
for waiter in response_waiters:
|
||||
waiter.put(exc)
|
||||
for turn_queue in turn_queues:
|
||||
turn_queue.put(exc)
|
||||
self._global_notifications.put(exc)
|
||||
|
||||
def _handle_server_request(self, msg: dict[str, JsonValue]) -> JsonObject:
|
||||
method = msg["method"]
|
||||
params = msg.get("params")
|
||||
|
||||
Reference in New Issue
Block a user