from __future__ import annotations import asyncio from collections import deque from pathlib import Path import pytest import codex_app_server.public_api as public_api_module from codex_app_server.client import AppServerClient from codex_app_server.generated.v2_all import ( AgentMessageDeltaNotification, RawResponseItemCompletedNotification, ThreadTokenUsageUpdatedNotification, ) from codex_app_server.models import InitializeResponse, Notification from codex_app_server.public_api import AsyncCodex, AsyncTurn, Codex, Turn from codex_app_server.public_types import TurnStatus ROOT = Path(__file__).resolve().parents[1] def _delta_notification( *, thread_id: str = "thread-1", turn_id: str = "turn-1", text: str = "delta-text", ) -> Notification: return Notification( method="item/agentMessage/delta", payload=AgentMessageDeltaNotification.model_validate( { "delta": text, "itemId": "item-1", "threadId": thread_id, "turnId": turn_id, } ), ) def _raw_response_notification( *, thread_id: str = "thread-1", turn_id: str = "turn-1", text: str = "raw-text", ) -> Notification: return Notification( method="rawResponseItem/completed", payload=RawResponseItemCompletedNotification.model_validate( { "item": { "type": "message", "role": "assistant", "content": [{"type": "output_text", "text": text}], }, "threadId": thread_id, "turnId": turn_id, } ), ) def _usage_notification( *, thread_id: str = "thread-1", turn_id: str = "turn-1", ) -> Notification: return Notification( method="thread/tokenUsage/updated", payload=ThreadTokenUsageUpdatedNotification.model_validate( { "threadId": thread_id, "turnId": turn_id, "tokenUsage": { "last": { "cachedInputTokens": 0, "inputTokens": 1, "outputTokens": 2, "reasoningOutputTokens": 0, "totalTokens": 3, }, "total": { "cachedInputTokens": 0, "inputTokens": 1, "outputTokens": 2, "reasoningOutputTokens": 0, "totalTokens": 3, }, }, } ), ) def _completed_notification( *, thread_id: str = "thread-1", turn_id: str = "turn-1", status: str = "completed", ) -> Notification: return Notification( method="turn/completed", payload=public_api_module.TurnCompletedNotificationPayload.model_validate( { "threadId": thread_id, "turn": { "id": turn_id, "items": [], "status": status, }, } ), ) def test_codex_init_failure_closes_client(monkeypatch: pytest.MonkeyPatch) -> None: closed: list[bool] = [] class FakeClient: def __init__(self, config=None) -> None: # noqa: ANN001,ARG002 self._closed = False def start(self) -> None: return None def initialize(self) -> InitializeResponse: return InitializeResponse.model_validate({}) def close(self) -> None: self._closed = True closed.append(True) monkeypatch.setattr(public_api_module, "AppServerClient", FakeClient) with pytest.raises(RuntimeError, match="missing required metadata"): Codex() assert closed == [True] def test_async_codex_init_failure_closes_client() -> None: async def scenario() -> None: codex = AsyncCodex() close_calls = 0 async def fake_start() -> None: return None async def fake_initialize() -> InitializeResponse: return InitializeResponse.model_validate({}) async def fake_close() -> None: nonlocal close_calls close_calls += 1 codex._client.start = fake_start # type: ignore[method-assign] codex._client.initialize = fake_initialize # type: ignore[method-assign] codex._client.close = fake_close # type: ignore[method-assign] with pytest.raises(RuntimeError, match="missing required metadata"): await codex.models() assert close_calls == 1 assert codex._initialized is False assert codex._init is None asyncio.run(scenario()) def test_async_codex_initializes_only_once_under_concurrency() -> None: async def scenario() -> None: codex = AsyncCodex() start_calls = 0 initialize_calls = 0 ready = asyncio.Event() async def fake_start() -> None: nonlocal start_calls start_calls += 1 async def fake_initialize() -> InitializeResponse: nonlocal initialize_calls initialize_calls += 1 ready.set() await asyncio.sleep(0.02) return InitializeResponse.model_validate( { "userAgent": "codex-cli/1.2.3", "serverInfo": {"name": "codex-cli", "version": "1.2.3"}, } ) async def fake_model_list(include_hidden: bool = False): # noqa: ANN202,ARG001 await ready.wait() return object() codex._client.start = fake_start # type: ignore[method-assign] codex._client.initialize = fake_initialize # type: ignore[method-assign] codex._client.model_list = fake_model_list # type: ignore[method-assign] await asyncio.gather(codex.models(), codex.models()) assert start_calls == 1 assert initialize_calls == 1 asyncio.run(scenario()) def test_turn_stream_rejects_second_active_consumer() -> None: client = AppServerClient() notifications: deque[Notification] = deque( [ _delta_notification(turn_id="turn-1"), _completed_notification(turn_id="turn-1"), ] ) client.next_notification = notifications.popleft # type: ignore[method-assign] first_stream = Turn(client, "thread-1", "turn-1").stream() assert next(first_stream).method == "item/agentMessage/delta" second_stream = Turn(client, "thread-1", "turn-2").stream() with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"): next(second_stream) first_stream.close() def test_async_turn_stream_rejects_second_active_consumer() -> None: async def scenario() -> None: codex = AsyncCodex() async def fake_ensure_initialized() -> None: return None notifications: deque[Notification] = deque( [ _delta_notification(turn_id="turn-1"), _completed_notification(turn_id="turn-1"), ] ) async def fake_next_notification() -> Notification: return notifications.popleft() codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign] codex._client.next_notification = fake_next_notification # type: ignore[method-assign] first_stream = AsyncTurn(codex, "thread-1", "turn-1").stream() assert (await anext(first_stream)).method == "item/agentMessage/delta" second_stream = AsyncTurn(codex, "thread-1", "turn-2").stream() with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"): await anext(second_stream) await first_stream.aclose() asyncio.run(scenario()) def test_turn_run_falls_back_to_completed_raw_response_text() -> None: client = AppServerClient() notifications: deque[Notification] = deque( [ _raw_response_notification(text="hello from raw response"), _usage_notification(), _completed_notification(), ] ) client.next_notification = notifications.popleft # type: ignore[method-assign] result = Turn(client, "thread-1", "turn-1").run() assert result.status == TurnStatus.completed assert result.text == "hello from raw response" def test_retry_examples_compare_status_with_enum() -> None: for path in ( ROOT / "examples" / "10_error_handling_and_retry" / "sync.py", ROOT / "examples" / "10_error_handling_and_retry" / "async.py", ): source = path.read_text() assert '== "failed"' not in source assert "TurnStatus.failed" in source