Support concurrent Python SDK turns across threads

This commit is contained in:
Shaqayeq
2026-03-19 16:16:02 -07:00
parent 6b8175c734
commit 083243dca1
8 changed files with 584 additions and 189 deletions

View File

@@ -133,6 +133,15 @@ def _token_usage_notification(
)
def _turn_notification_source(
notifications_by_turn: dict[tuple[str, str], deque[Notification]],
):
def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
return notifications_by_turn[(thread_id, turn_id)].popleft()
return fake_next_turn_notification
def test_codex_init_failure_closes_client(monkeypatch: pytest.MonkeyPatch) -> None:
closed: list[bool] = []
@@ -226,66 +235,132 @@ def test_async_codex_initializes_only_once_under_concurrency() -> None:
asyncio.run(scenario())
def test_turn_stream_rejects_second_active_consumer() -> None:
def test_turn_stream_allows_different_active_threads() -> None:
client = AppServerClient()
notifications: deque[Notification] = deque(
[
_delta_notification(turn_id="turn-1"),
_completed_notification(turn_id="turn-1"),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
]
),
("thread-2", "turn-2"): deque(
[
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
]
),
}
)
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
second_stream = TurnHandle(client, "thread-2", "turn-2").stream()
assert next(first_stream).method == "item/agentMessage/delta"
assert next(second_stream).method == "item/agentMessage/delta"
first_stream.close()
second_stream.close()
def test_turn_stream_blocks_next_notification_while_active() -> None:
client = AppServerClient()
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
assert next(first_stream).method == "item/agentMessage/delta"
second_stream = TurnHandle(client, "thread-1", "turn-2").stream()
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
next(second_stream)
with pytest.raises(RuntimeError, match="next_notification\\(\\) is incompatible"):
client.next_notification()
first_stream.close()
def test_async_turn_stream_rejects_second_active_consumer() -> None:
def test_turn_start_rejects_same_thread_overlap_and_allows_after_completion() -> None:
client = AppServerClient()
turn_ids = iter(["turn-1", "turn-2"])
def fake_request(method: str, params, *, response_model): # type: ignore[no-untyped-def]
assert method == "turn/start"
return response_model.model_validate(
{
"turn": {
"id": next(turn_ids),
"items": [],
"status": TurnStatus.in_progress.value,
}
}
)
client.request = fake_request # type: ignore[method-assign]
first = client.turn_start("thread-1", "first turn")
assert first.turn.id == "turn-1"
with pytest.raises(RuntimeError, match="already has active turn"):
client.turn_start("thread-1", "second turn")
client._dispatch_notification( # type: ignore[attr-defined]
_completed_notification(thread_id="thread-1", turn_id="turn-1")
)
second = client.turn_start("thread-1", "second turn")
assert second.turn.id == "turn-2"
def test_async_turn_stream_allows_different_active_threads() -> None:
async def scenario() -> None:
codex = AsyncCodex()
async def fake_ensure_initialized() -> None:
return None
notifications: deque[Notification] = deque(
[
_delta_notification(turn_id="turn-1"),
_completed_notification(turn_id="turn-1"),
]
)
notifications_by_turn = {
("thread-1", "turn-1"): deque(
[
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
]
),
("thread-2", "turn-2"): deque(
[
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
]
),
}
async def fake_next_notification() -> Notification:
return notifications.popleft()
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
return notifications_by_turn[(thread_id, turn_id)].popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream()
second_stream = AsyncTurnHandle(codex, "thread-2", "turn-2").stream()
assert (await anext(first_stream)).method == "item/agentMessage/delta"
second_stream = AsyncTurnHandle(codex, "thread-1", "turn-2").stream()
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
await anext(second_stream)
assert (await anext(second_stream)).method == "item/agentMessage/delta"
await first_stream.aclose()
await second_stream.aclose()
asyncio.run(scenario())
def test_turn_run_returns_completed_turn_payload() -> None:
client = AppServerClient()
notifications: deque[Notification] = deque(
[
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{("thread-1", "turn-1"): deque([_completed_notification()])}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
result = TurnHandle(client, "thread-1", "turn-1").run()
@@ -298,14 +373,17 @@ def test_thread_run_accepts_string_input_and_returns_run_result() -> None:
client = AppServerClient()
item_notification = _item_completed_notification(text="Hello.")
usage_notification = _token_usage_notification()
notifications: deque[Notification] = deque(
[
item_notification,
usage_notification,
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
item_notification,
usage_notification,
_completed_notification(),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
seen: dict[str, object] = {}
def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202
@@ -331,14 +409,17 @@ def test_thread_run_uses_last_completed_assistant_message_as_final_response() ->
client = AppServerClient()
first_item_notification = _item_completed_notification(text="First message")
second_item_notification = _item_completed_notification(text="Second message")
notifications: deque[Notification] = deque(
[
first_item_notification,
second_item_notification,
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
first_item_notification,
second_item_notification,
_completed_notification(),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -356,14 +437,17 @@ def test_thread_run_preserves_empty_last_assistant_message() -> None:
client = AppServerClient()
first_item_notification = _item_completed_notification(text="First message")
second_item_notification = _item_completed_notification(text="")
notifications: deque[Notification] = deque(
[
first_item_notification,
second_item_notification,
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
first_item_notification,
second_item_notification,
_completed_notification(),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -387,14 +471,17 @@ def test_thread_run_prefers_explicit_final_answer_over_later_commentary() -> Non
text="Commentary",
phase=MessagePhase.commentary,
)
notifications: deque[Notification] = deque(
[
final_answer_notification,
commentary_notification,
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
final_answer_notification,
commentary_notification,
_completed_notification(),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -414,13 +501,16 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
text="Commentary",
phase=MessagePhase.commentary,
)
notifications: deque[Notification] = deque(
[
commentary_notification,
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
commentary_notification,
_completed_notification(),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -433,12 +523,13 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
def test_thread_run_raises_on_failed_turn() -> None:
client = AppServerClient()
notifications: deque[Notification] = deque(
[
_completed_notification(status="failed", error_message="boom"),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[_completed_notification(status="failed", error_message="boom")]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -471,12 +562,13 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
seen["params"] = params
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification:
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
assert (thread_id, turn_id) == ("thread-1", "turn-1")
return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello")
@@ -511,12 +603,13 @@ def test_async_thread_run_uses_last_completed_assistant_message_as_final_respons
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification:
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
assert (thread_id, turn_id) == ("thread-1", "turn-1")
return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello")
@@ -550,12 +643,13 @@ def test_async_thread_run_returns_none_when_only_commentary_messages_complete()
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification:
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
assert (thread_id, turn_id) == ("thread-1", "turn-1")
return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello")