from __future__ import annotations import json import os import subprocess import threading import uuid from collections import deque from dataclasses import dataclass from pathlib import Path from typing import Callable, Iterable, Iterator, TypeVar from pydantic import BaseModel from .errors import AppServerError, TransportClosedError, map_jsonrpc_error from .generated.notification_registry import NOTIFICATION_MODELS from .generated.v2_all import ( AgentMessageDeltaNotification, ModelListResponse, ThreadArchiveResponse, ThreadCompactStartResponse, ThreadForkParams as V2ThreadForkParams, ThreadForkResponse, ThreadListParams as V2ThreadListParams, ThreadListResponse, ThreadReadResponse, ThreadResumeParams as V2ThreadResumeParams, ThreadResumeResponse, ThreadSetNameResponse, ThreadStartParams as V2ThreadStartParams, ThreadStartResponse, ThreadUnarchiveResponse, TurnCompletedNotification, TurnInterruptResponse, TurnStartParams as V2TurnStartParams, TurnStartResponse, TurnSteerResponse, ) from .models import ( InitializeResponse, JsonObject, JsonValue, Notification, UnknownNotification, ) from .retry import retry_on_overload ModelT = TypeVar("ModelT", bound=BaseModel) ApprovalHandler = Callable[[str, JsonObject | None], JsonObject] BUNDLED_SDK_PKG_NAME = "codex-app-server-sdk" RUNTIME_PKG_NAME = "codex-cli-bin" def _params_dict( params: ( V2ThreadStartParams | V2ThreadResumeParams | V2ThreadListParams | V2ThreadForkParams | V2TurnStartParams | JsonObject | None ), ) -> JsonObject: if params is None: return {} if hasattr(params, "model_dump"): dumped = params.model_dump( by_alias=True, exclude_none=True, mode="json", ) if not isinstance(dumped, dict): raise TypeError("Expected model_dump() to return dict") return dumped if isinstance(params, dict): return params raise TypeError(f"Expected generated params model or dict, got {type(params).__name__}") def _installed_codex_path() -> Path: try: from codex_cli_bin import bundled_codex_path except ImportError as exc: raise FileNotFoundError( "Unable to locate the Codex runtime. Install the published " f"{BUNDLED_SDK_PKG_NAME} package, install {RUNTIME_PKG_NAME} " "alongside the core SDK, or set AppServerConfig.codex_bin explicitly." ) from exc return bundled_codex_path() @dataclass(frozen=True) class CodexBinResolverOps: installed_codex_path: Callable[[], Path] path_exists: Callable[[Path], bool] def _default_codex_bin_resolver_ops() -> CodexBinResolverOps: return CodexBinResolverOps( installed_codex_path=_installed_codex_path, path_exists=lambda path: path.exists(), ) def resolve_codex_bin(config: "AppServerConfig", ops: CodexBinResolverOps) -> Path: if config.codex_bin is not None: codex_bin = Path(config.codex_bin) if not ops.path_exists(codex_bin): raise FileNotFoundError( f"Codex binary not found at {codex_bin}. Set AppServerConfig.codex_bin " "to a valid binary path." ) return codex_bin return ops.installed_codex_path() def _resolve_codex_bin(config: "AppServerConfig") -> Path: return resolve_codex_bin(config, _default_codex_bin_resolver_ops()) @dataclass(slots=True) class AppServerConfig: codex_bin: str | None = None launch_args_override: tuple[str, ...] | None = None config_overrides: tuple[str, ...] = () cwd: str | None = None env: dict[str, str] | None = None client_name: str = "codex_python_sdk" client_title: str = "Codex Python SDK" client_version: str = "0.2.0" experimental_api: bool = True class AppServerClient: """Synchronous typed JSON-RPC client for `codex app-server` over stdio.""" def __init__( self, config: AppServerConfig | None = None, approval_handler: ApprovalHandler | None = None, ) -> None: self.config = config or AppServerConfig() self._approval_handler = approval_handler or self._default_approval_handler self._proc: subprocess.Popen[str] | None = None self._lock = threading.Lock() self._turn_consumer_lock = threading.Lock() self._active_turn_consumer: str | None = None self._pending_notifications: deque[Notification] = deque() self._stderr_lines: deque[str] = deque(maxlen=400) self._stderr_thread: threading.Thread | None = None def __enter__(self) -> "AppServerClient": self.start() return self def __exit__(self, _exc_type, _exc, _tb) -> None: self.close() def start(self) -> None: if self._proc is not None: return if self.config.launch_args_override is not None: args = list(self.config.launch_args_override) else: codex_bin = _resolve_codex_bin(self.config) args = [str(codex_bin)] for kv in self.config.config_overrides: args.extend(["--config", kv]) args.extend(["app-server", "--listen", "stdio://"]) env = os.environ.copy() if self.config.env: env.update(self.config.env) self._proc = subprocess.Popen( args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, cwd=self.config.cwd, env=env, bufsize=1, ) self._start_stderr_drain_thread() def close(self) -> None: if self._proc is None: return proc = self._proc self._proc = None self._active_turn_consumer = None if proc.stdin: proc.stdin.close() try: proc.terminate() proc.wait(timeout=2) except Exception: proc.kill() if self._stderr_thread and self._stderr_thread.is_alive(): self._stderr_thread.join(timeout=0.5) def initialize(self) -> InitializeResponse: result = self.request( "initialize", { "clientInfo": { "name": self.config.client_name, "title": self.config.client_title, "version": self.config.client_version, }, "capabilities": { "experimentalApi": self.config.experimental_api, }, }, response_model=InitializeResponse, ) self.notify("initialized", None) return result def request( self, method: str, params: JsonObject | None, *, response_model: type[ModelT], ) -> ModelT: result = self._request_raw(method, params) if not isinstance(result, dict): raise AppServerError(f"{method} response must be a JSON object") return response_model.model_validate(result) def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue: request_id = str(uuid.uuid4()) self._write_message({"id": request_id, "method": method, "params": params or {}}) while True: msg = self._read_message() if "method" in msg and "id" in msg: response = self._handle_server_request(msg) self._write_message({"id": msg["id"], "result": response}) continue if "method" in msg and "id" not in msg: self._pending_notifications.append( self._coerce_notification(msg["method"], msg.get("params")) ) continue if msg.get("id") != request_id: continue if "error" in msg: err = msg["error"] if isinstance(err, dict): raise map_jsonrpc_error( int(err.get("code", -32000)), str(err.get("message", "unknown")), err.get("data"), ) raise AppServerError("Malformed JSON-RPC error response") return msg.get("result") def notify(self, method: str, params: JsonObject | None = None) -> None: self._write_message({"method": method, "params": params or {}}) def next_notification(self) -> Notification: if self._pending_notifications: return self._pending_notifications.popleft() while True: msg = self._read_message() if "method" in msg and "id" in msg: response = self._handle_server_request(msg) self._write_message({"id": msg["id"], "result": response}) continue if "method" in msg and "id" not in msg: return self._coerce_notification(msg["method"], msg.get("params")) def acquire_turn_consumer(self, turn_id: str) -> None: with self._turn_consumer_lock: if self._active_turn_consumer is not None: raise RuntimeError( "Concurrent turn consumers are not yet supported in the experimental SDK. " f"Client is already streaming turn {self._active_turn_consumer!r}; " f"cannot start turn {turn_id!r} until the active consumer finishes." ) self._active_turn_consumer = turn_id def release_turn_consumer(self, turn_id: str) -> None: with self._turn_consumer_lock: if self._active_turn_consumer == turn_id: self._active_turn_consumer = None def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse: return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse) def thread_resume( self, thread_id: str, params: V2ThreadResumeParams | JsonObject | None = None, ) -> ThreadResumeResponse: payload = {"threadId": thread_id, **_params_dict(params)} return self.request("thread/resume", payload, response_model=ThreadResumeResponse) def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse: return self.request("thread/list", _params_dict(params), response_model=ThreadListResponse) def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse: return self.request( "thread/read", {"threadId": thread_id, "includeTurns": include_turns}, response_model=ThreadReadResponse, ) def thread_fork( self, thread_id: str, params: V2ThreadForkParams | JsonObject | None = None, ) -> ThreadForkResponse: payload = {"threadId": thread_id, **_params_dict(params)} return self.request("thread/fork", payload, response_model=ThreadForkResponse) def thread_archive(self, thread_id: str) -> ThreadArchiveResponse: return self.request("thread/archive", {"threadId": thread_id}, response_model=ThreadArchiveResponse) def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse: return self.request("thread/unarchive", {"threadId": thread_id}, response_model=ThreadUnarchiveResponse) def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse: return self.request( "thread/name/set", {"threadId": thread_id, "name": name}, response_model=ThreadSetNameResponse, ) def thread_compact(self, thread_id: str) -> ThreadCompactStartResponse: return self.request( "thread/compact/start", {"threadId": thread_id}, response_model=ThreadCompactStartResponse, ) def turn_start( self, thread_id: str, input_items: list[JsonObject] | JsonObject | str, params: V2TurnStartParams | JsonObject | None = None, ) -> TurnStartResponse: payload = { **_params_dict(params), "threadId": thread_id, "input": self._normalize_input_items(input_items), } return self.request("turn/start", payload, response_model=TurnStartResponse) def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse: return self.request( "turn/interrupt", {"threadId": thread_id, "turnId": turn_id}, response_model=TurnInterruptResponse, ) def turn_steer( self, thread_id: str, expected_turn_id: str, input_items: list[JsonObject] | JsonObject | str, ) -> TurnSteerResponse: return self.request( "turn/steer", { "threadId": thread_id, "expectedTurnId": expected_turn_id, "input": self._normalize_input_items(input_items), }, response_model=TurnSteerResponse, ) def model_list(self, include_hidden: bool = False) -> ModelListResponse: return self.request( "model/list", {"includeHidden": include_hidden}, response_model=ModelListResponse, ) def request_with_retry_on_overload( self, method: str, params: JsonObject | None, *, response_model: type[ModelT], max_attempts: int = 3, initial_delay_s: float = 0.25, max_delay_s: float = 2.0, ) -> ModelT: return retry_on_overload( lambda: self.request(method, params, response_model=response_model), max_attempts=max_attempts, initial_delay_s=initial_delay_s, max_delay_s=max_delay_s, ) def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification: while True: notification = self.next_notification() if ( notification.method == "turn/completed" and isinstance(notification.payload, TurnCompletedNotification) and notification.payload.turn.id == turn_id ): return notification.payload def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]: target_methods = {methods} if isinstance(methods, str) else set(methods) out: list[Notification] = [] while True: notification = self.next_notification() out.append(notification) if notification.method in target_methods: return out def stream_text( self, thread_id: str, text: str, params: V2TurnStartParams | JsonObject | None = None, ) -> Iterator[AgentMessageDeltaNotification]: started = self.turn_start(thread_id, text, params=params) turn_id = started.turn.id while True: notification = self.next_notification() if ( notification.method == "item/agentMessage/delta" and isinstance(notification.payload, AgentMessageDeltaNotification) and notification.payload.turn_id == turn_id ): yield notification.payload continue if ( notification.method == "turn/completed" and isinstance(notification.payload, TurnCompletedNotification) and notification.payload.turn.id == turn_id ): break def _coerce_notification(self, method: str, params: object) -> Notification: params_dict = params if isinstance(params, dict) else {} model = NOTIFICATION_MODELS.get(method) if model is None: return Notification(method=method, payload=UnknownNotification(params=params_dict)) try: payload = model.model_validate(params_dict) except Exception: # noqa: BLE001 return Notification(method=method, payload=UnknownNotification(params=params_dict)) return Notification(method=method, payload=payload) def _normalize_input_items( self, input_items: list[JsonObject] | JsonObject | str, ) -> list[JsonObject]: if isinstance(input_items, str): return [{"type": "text", "text": input_items}] if isinstance(input_items, dict): return [input_items] return input_items def _default_approval_handler(self, method: str, params: JsonObject | None) -> JsonObject: if method == "item/commandExecution/requestApproval": return {"decision": "accept"} if method == "item/fileChange/requestApproval": return {"decision": "accept"} return {} def _start_stderr_drain_thread(self) -> None: if self._proc is None or self._proc.stderr is None: return def _drain() -> None: stderr = self._proc.stderr if stderr is None: return for line in stderr: self._stderr_lines.append(line.rstrip("\n")) self._stderr_thread = threading.Thread(target=_drain, daemon=True) self._stderr_thread.start() def _stderr_tail(self, limit: int = 40) -> str: return "\n".join(list(self._stderr_lines)[-limit:]) def _handle_server_request(self, msg: dict[str, JsonValue]) -> JsonObject: method = msg["method"] params = msg.get("params") if not isinstance(method, str): return {} return self._approval_handler( method, params if isinstance(params, dict) else None, ) def _write_message(self, payload: JsonObject) -> None: if self._proc is None or self._proc.stdin is None: raise TransportClosedError("app-server is not running") with self._lock: self._proc.stdin.write(json.dumps(payload) + "\n") self._proc.stdin.flush() def _read_message(self) -> dict[str, JsonValue]: if self._proc is None or self._proc.stdout is None: raise TransportClosedError("app-server is not running") line = self._proc.stdout.readline() if not line: raise TransportClosedError( f"app-server closed stdout. stderr_tail={self._stderr_tail()[:2000]}" ) try: message = json.loads(line) except json.JSONDecodeError as exc: raise AppServerError(f"Invalid JSON-RPC line: {line!r}") from exc if not isinstance(message, dict): raise AppServerError(f"Invalid JSON-RPC payload: {message!r}") return message def default_codex_home() -> str: return str(Path.home() / ".codex")