mirror of
https://github.com/openai/codex.git
synced 2026-04-23 22:24:57 +00:00
209 lines
7.1 KiB
Python
209 lines
7.1 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import Iterator
|
|
from typing import AsyncIterator, Callable, Iterable, ParamSpec, TypeVar
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from .client import AppServerClient, AppServerConfig
|
|
from .generated.v2_all import (
|
|
AgentMessageDeltaNotification,
|
|
ModelListResponse,
|
|
ThreadArchiveResponse,
|
|
ThreadCompactStartResponse,
|
|
ThreadForkParams as V2ThreadForkParams,
|
|
ThreadForkResponse,
|
|
ThreadListParams as V2ThreadListParams,
|
|
ThreadListResponse,
|
|
ThreadReadResponse,
|
|
ThreadResumeParams as V2ThreadResumeParams,
|
|
ThreadResumeResponse,
|
|
ThreadSetNameResponse,
|
|
ThreadStartParams as V2ThreadStartParams,
|
|
ThreadStartResponse,
|
|
ThreadUnarchiveResponse,
|
|
TurnCompletedNotification,
|
|
TurnInterruptResponse,
|
|
TurnStartParams as V2TurnStartParams,
|
|
TurnStartResponse,
|
|
TurnSteerResponse,
|
|
)
|
|
from .models import InitializeResponse, JsonObject, Notification
|
|
|
|
ModelT = TypeVar("ModelT", bound=BaseModel)
|
|
ParamsT = ParamSpec("ParamsT")
|
|
ReturnT = TypeVar("ReturnT")
|
|
|
|
|
|
class AsyncAppServerClient:
|
|
"""Async wrapper around AppServerClient using thread offloading."""
|
|
|
|
def __init__(self, config: AppServerConfig | None = None) -> None:
|
|
self._sync = AppServerClient(config=config)
|
|
# Single stdio transport cannot be read safely from multiple threads.
|
|
self._transport_lock = asyncio.Lock()
|
|
|
|
async def __aenter__(self) -> "AsyncAppServerClient":
|
|
await self.start()
|
|
return self
|
|
|
|
async def __aexit__(self, _exc_type, _exc, _tb) -> None:
|
|
await self.close()
|
|
|
|
async def _call_sync(
|
|
self,
|
|
fn: Callable[ParamsT, ReturnT],
|
|
/,
|
|
*args: ParamsT.args,
|
|
**kwargs: ParamsT.kwargs,
|
|
) -> ReturnT:
|
|
async with self._transport_lock:
|
|
return await asyncio.to_thread(fn, *args, **kwargs)
|
|
|
|
@staticmethod
|
|
def _next_from_iterator(
|
|
iterator: Iterator[AgentMessageDeltaNotification],
|
|
) -> tuple[bool, AgentMessageDeltaNotification | None]:
|
|
try:
|
|
return True, next(iterator)
|
|
except StopIteration:
|
|
return False, None
|
|
|
|
async def start(self) -> None:
|
|
await self._call_sync(self._sync.start)
|
|
|
|
async def close(self) -> None:
|
|
await self._call_sync(self._sync.close)
|
|
|
|
async def initialize(self) -> InitializeResponse:
|
|
return await self._call_sync(self._sync.initialize)
|
|
|
|
def acquire_turn_consumer(self, turn_id: str) -> None:
|
|
self._sync.acquire_turn_consumer(turn_id)
|
|
|
|
def release_turn_consumer(self, turn_id: str) -> None:
|
|
self._sync.release_turn_consumer(turn_id)
|
|
|
|
async def request(
|
|
self,
|
|
method: str,
|
|
params: JsonObject | None,
|
|
*,
|
|
response_model: type[ModelT],
|
|
) -> ModelT:
|
|
return await self._call_sync(
|
|
self._sync.request,
|
|
method,
|
|
params,
|
|
response_model=response_model,
|
|
)
|
|
|
|
async def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
|
|
return await self._call_sync(self._sync.thread_start, params)
|
|
|
|
async def thread_resume(
|
|
self,
|
|
thread_id: str,
|
|
params: V2ThreadResumeParams | JsonObject | None = None,
|
|
) -> ThreadResumeResponse:
|
|
return await self._call_sync(self._sync.thread_resume, thread_id, params)
|
|
|
|
async def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse:
|
|
return await self._call_sync(self._sync.thread_list, params)
|
|
|
|
async def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
|
|
return await self._call_sync(self._sync.thread_read, thread_id, include_turns)
|
|
|
|
async def thread_fork(
|
|
self,
|
|
thread_id: str,
|
|
params: V2ThreadForkParams | JsonObject | None = None,
|
|
) -> ThreadForkResponse:
|
|
return await self._call_sync(self._sync.thread_fork, thread_id, params)
|
|
|
|
async def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:
|
|
return await self._call_sync(self._sync.thread_archive, thread_id)
|
|
|
|
async def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse:
|
|
return await self._call_sync(self._sync.thread_unarchive, thread_id)
|
|
|
|
async def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse:
|
|
return await self._call_sync(self._sync.thread_set_name, thread_id, name)
|
|
|
|
async def thread_compact(self, thread_id: str) -> ThreadCompactStartResponse:
|
|
return await self._call_sync(self._sync.thread_compact, thread_id)
|
|
|
|
async def turn_start(
|
|
self,
|
|
thread_id: str,
|
|
input_items: list[JsonObject] | JsonObject | str,
|
|
params: V2TurnStartParams | JsonObject | None = None,
|
|
) -> TurnStartResponse:
|
|
return await self._call_sync(self._sync.turn_start, thread_id, input_items, params)
|
|
|
|
async def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
|
|
return await self._call_sync(self._sync.turn_interrupt, thread_id, turn_id)
|
|
|
|
async def turn_steer(
|
|
self,
|
|
thread_id: str,
|
|
expected_turn_id: str,
|
|
input_items: list[JsonObject] | JsonObject | str,
|
|
) -> TurnSteerResponse:
|
|
return await self._call_sync(
|
|
self._sync.turn_steer,
|
|
thread_id,
|
|
expected_turn_id,
|
|
input_items,
|
|
)
|
|
|
|
async def model_list(self, include_hidden: bool = False) -> ModelListResponse:
|
|
return await self._call_sync(self._sync.model_list, include_hidden)
|
|
|
|
async def request_with_retry_on_overload(
|
|
self,
|
|
method: str,
|
|
params: JsonObject | None,
|
|
*,
|
|
response_model: type[ModelT],
|
|
max_attempts: int = 3,
|
|
initial_delay_s: float = 0.25,
|
|
max_delay_s: float = 2.0,
|
|
) -> ModelT:
|
|
return await self._call_sync(
|
|
self._sync.request_with_retry_on_overload,
|
|
method,
|
|
params,
|
|
response_model=response_model,
|
|
max_attempts=max_attempts,
|
|
initial_delay_s=initial_delay_s,
|
|
max_delay_s=max_delay_s,
|
|
)
|
|
|
|
async def next_notification(self) -> Notification:
|
|
return await self._call_sync(self._sync.next_notification)
|
|
|
|
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
|
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
|
|
|
|
async def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]:
|
|
return await self._call_sync(self._sync.stream_until_methods, methods)
|
|
|
|
async def stream_text(
|
|
self,
|
|
thread_id: str,
|
|
text: str,
|
|
params: V2TurnStartParams | JsonObject | None = None,
|
|
) -> AsyncIterator[AgentMessageDeltaNotification]:
|
|
async with self._transport_lock:
|
|
iterator = self._sync.stream_text(thread_id, text, params)
|
|
while True:
|
|
has_value, chunk = await asyncio.to_thread(
|
|
self._next_from_iterator,
|
|
iterator,
|
|
)
|
|
if not has_value:
|
|
break
|
|
yield chunk
|