mirror of
https://github.com/openai/codex.git
synced 2026-02-01 22:47:52 +00:00
Support response.done and add integration tests (#9129)
The agent loop now uses a persistent, incremental WebSocket connection.
This commit is contained in:
@@ -88,6 +88,14 @@ struct ResponseCompleted {
|
||||
usage: Option<ResponseCompletedUsage>,
|
||||
}
|
||||
|
||||
/// Body of the `response` field carried by a `response.done` event.
///
/// Both fields are optional when deserializing, so a payload that has
/// neither an `id` nor a `usage` object still parses successfully.
#[derive(Debug, Deserialize)]
|
||||
struct ResponseDone {
|
||||
/// Response identifier, when the payload includes one.
#[serde(default)]
|
||||
id: Option<String>,
|
||||
/// Token usage statistics, when the payload includes them.
#[serde(default)]
|
||||
usage: Option<ResponseCompletedUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCompletedUsage {
|
||||
input_tokens: i64,
|
||||
@@ -229,6 +237,29 @@ pub fn process_responses_event(
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.done" => {
|
||||
if let Some(resp_val) = event.response {
|
||||
match serde_json::from_value::<ResponseDone>(resp_val) {
|
||||
Ok(resp) => {
|
||||
return Ok(Some(ResponseEvent::Completed {
|
||||
response_id: resp.id.unwrap_or_default(),
|
||||
token_usage: resp.usage.map(Into::into),
|
||||
}));
|
||||
}
|
||||
Err(err) => {
    // The payload of a `response.done` event is deserialized as
    // `ResponseDone`, so the diagnostic names that type rather than
    // `ResponseCompleted` (which belongs to a different event).
    let error = format!("failed to parse ResponseDone: {err}");
    debug!("{error}");
    // Surface the parse failure to the caller as a stream-level API error.
    return Err(ResponsesEventError::Api(ApiError::Stream(error)));
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("response.done missing response payload");
|
||||
return Ok(Some(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: None,
|
||||
}));
|
||||
}
|
||||
"response.output_item.added" => {
|
||||
if let Some(item_val) = event.item {
|
||||
if let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) {
|
||||
@@ -517,6 +548,65 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn response_done_emits_completed() {
|
||||
let done = json!({
|
||||
"type": "response.done",
|
||||
"response": {
|
||||
"usage": {
|
||||
"input_tokens": 1,
|
||||
"input_tokens_details": null,
|
||||
"output_tokens": 2,
|
||||
"output_tokens_details": null,
|
||||
"total_tokens": 3
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let sse1 = format!("event: response.done\ndata: {done}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
match &events[0] {
|
||||
Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}) => {
|
||||
assert_eq!(response_id, "");
|
||||
assert!(token_usage.is_some());
|
||||
}
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn response_done_without_payload_emits_completed() {
|
||||
let done = json!({
|
||||
"type": "response.done"
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let sse1 = format!("event: response.done\ndata: {done}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
match &events[0] {
|
||||
Ok(ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage,
|
||||
}) => {
|
||||
assert_eq!(response_id, "");
|
||||
assert!(token_usage.is_none());
|
||||
}
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn error_when_error_event() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_689bcf18d7f08194bf3440ba62fe05d803fee0cdac429894","object":"response","created_at":1755041560,"status":"failed","background":false,"error":{"code":"rate_limit_exceeded","message":"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."}, "usage":null,"user":null,"metadata":{}}}"#;
|
||||
|
||||
@@ -2543,6 +2543,8 @@ pub(crate) async fn run_turn(
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
|
||||
let mut client_session = turn_context.client.new_session();
|
||||
|
||||
loop {
|
||||
// Note that pending_input would be something like a message the user
|
||||
// submitted through the UI while the model was running. Though the UI
|
||||
@@ -2573,6 +2575,7 @@ pub(crate) async fn run_turn(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&mut client_session,
|
||||
turn_input,
|
||||
cancellation_token.child_token(),
|
||||
)
|
||||
@@ -2650,6 +2653,7 @@ async fn run_model_turn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
client_session: &mut ModelClientSession,
|
||||
input: Vec<ResponseItem>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
@@ -2684,15 +2688,13 @@ async fn run_model_turn(
|
||||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
};
|
||||
|
||||
let mut client_session = turn_context.client.new_session();
|
||||
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
let err = match try_run_turn(
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
&mut client_session,
|
||||
client_session,
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&prompt,
|
||||
cancellation_token.child_token(),
|
||||
|
||||
@@ -319,6 +319,15 @@ pub fn ev_completed(id: &str) -> Value {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn ev_done() -> Value {
|
||||
serde_json::json!({
|
||||
"type": "response.done",
|
||||
"response": {
|
||||
"usage": {"input_tokens":0,"input_tokens_details":null,"output_tokens":0,"output_tokens_details":null,"total_tokens":0}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Convenience: SSE event for a created response with a specific id.
|
||||
pub fn ev_response_created(id: &str) -> Value {
|
||||
serde_json::json!({
|
||||
|
||||
@@ -8,6 +8,7 @@ use codex_core::CodexAuth;
|
||||
use codex_core::CodexThread;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::ThreadManager;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::built_in_model_providers;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::features::Feature;
|
||||
@@ -23,6 +24,7 @@ use tempfile::TempDir;
|
||||
use wiremock::MockServer;
|
||||
|
||||
use crate::load_default_config_for_test;
|
||||
use crate::responses::WebSocketTestServer;
|
||||
use crate::responses::start_mock_server;
|
||||
use crate::streaming_sse::StreamingSseServer;
|
||||
use crate::wait_for_event;
|
||||
@@ -101,6 +103,21 @@ impl TestCodexBuilder {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn build_with_websocket_server(
|
||||
&mut self,
|
||||
server: &WebSocketTestServer,
|
||||
) -> anyhow::Result<TestCodex> {
|
||||
let base_url = format!("{}/v1", server.uri());
|
||||
let home = Arc::new(TempDir::new()?);
|
||||
let base_url_clone = base_url.clone();
|
||||
self.config_mutators.push(Box::new(move |config| {
|
||||
config.model_provider.base_url = Some(base_url_clone);
|
||||
config.model_provider.wire_api = WireApi::ResponsesWebsocket;
|
||||
}));
|
||||
self.build_with_home_and_base_url(base_url, home, None)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn resume(
|
||||
&mut self,
|
||||
server: &wiremock::MockServer,
|
||||
|
||||
69
codex-rs/core/tests/suite/agent_websocket.rs
Normal file
69
codex-rs/core/tests/suite/agent_websocket.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
use anyhow::Result;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_done;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::ev_shell_command_call;
|
||||
use core_test_support::responses::start_websocket_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn websocket_test_codex_shell_chain() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let call_id = "shell-command-call";
|
||||
let server = start_websocket_server(vec![vec![
|
||||
vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_shell_command_call(call_id, "echo websocket"),
|
||||
ev_done(),
|
||||
],
|
||||
vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
],
|
||||
]])
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex();
|
||||
|
||||
let test = builder.build_with_websocket_server(&server).await?;
|
||||
test.submit_turn("run the echo command").await?;
|
||||
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 2);
|
||||
|
||||
let first = connection
|
||||
.first()
|
||||
.expect("missing first request")
|
||||
.body_json();
|
||||
let second = connection
|
||||
.get(1)
|
||||
.expect("missing second request")
|
||||
.body_json();
|
||||
|
||||
assert_eq!(first["type"].as_str(), Some("response.create"));
|
||||
assert_eq!(second["type"].as_str(), Some("response.append"));
|
||||
|
||||
let append_items = second
|
||||
.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.expect("response.append input array");
|
||||
assert!(!append_items.is_empty());
|
||||
|
||||
let output_item = append_items
|
||||
.iter()
|
||||
.find(|item| item.get("type").and_then(Value::as_str) == Some("function_call_output"))
|
||||
.expect("function_call_output in append");
|
||||
assert_eq!(
|
||||
output_item.get("call_id").and_then(Value::as_str),
|
||||
Some(call_id)
|
||||
);
|
||||
|
||||
server.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
@@ -15,12 +15,14 @@ pub static CODEX_ALIASES_TEMP_DIR: TempDir = unsafe {
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
mod abort_tasks;
|
||||
mod agent_websocket;
|
||||
mod apply_patch_cli;
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
mod approvals;
|
||||
mod auth_refresh;
|
||||
mod cli_stream;
|
||||
mod client;
|
||||
mod client_websockets;
|
||||
mod codex_delegate;
|
||||
mod compact;
|
||||
mod compact_remote;
|
||||
@@ -72,4 +74,3 @@ mod user_notification;
|
||||
mod user_shell_cmd;
|
||||
mod view_image;
|
||||
mod web_search_cached;
|
||||
mod websocket;
|
||||
|
||||
Reference in New Issue
Block a user