# PR #1538: Add tests for chat stream aggregation and tool events - URL: https://github.com/openai/codex/pull/1538 - Author: aibrahim-oai - Created: 2025-07-11 19:03:03 UTC - Updated: 2025-07-21 20:58:18 UTC - Changes: +323/-12, Files changed: 5, Commits: 12 ## Description ## Summary - unit test AggregatedChatStream to ensure it merges assistant message deltas and forwards other items - verify parsing of function_call_output and local_shell_call SSE events - ensure chat request payload encodes tool calls correctly ## Testing - `cargo test -p codex-core --manifest-path codex-rs/Cargo.toml` - `cargo test --manifest-path codex-rs/Cargo.toml --all --tests` *(fails: Sandbox(LandlockRestrict))* ------ https://chatgpt.com/codex/tasks/task_i_687158d61e748321ba5f1631199bd8a4 ## Full Diff ```diff diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index ad7b55952a..8eabcaf342 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -458,6 +458,9 @@ pub(crate) trait AggregateStreamExt: Stream> + Size /// // event now contains cumulative text /// } /// ``` + /// + /// See [`tests::aggregates_consecutive_message_chunks`] for an example. + /// ``` fn aggregate(self) -> AggregatedChatStream { AggregatedChatStream { inner: self, @@ -468,3 +471,237 @@ pub(crate) trait AggregateStreamExt: Stream> + Size } impl AggregateStreamExt for T where T: Stream> + Sized {} + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + + use super::*; + use crate::models::FunctionCallOutputPayload; + use crate::models::LocalShellAction; + use crate::models::LocalShellExecAction; + use crate::models::LocalShellStatus; + use crate::openai_tools::create_tools_json_for_chat_completions_api; + use futures::StreamExt; + use futures::stream; + use serde_json::json; + + /// Helper constructing a minimal assistant text chunk. 
+ fn text_chunk(txt: &str) -> ResponseEvent { + ResponseEvent::OutputItemDone(ResponseItem::Message { + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { text: txt.into() }], + }) + } + + #[tokio::test] + async fn aggregates_consecutive_message_chunks() { + let events = vec![ + Ok(text_chunk("Hello")), + Ok(text_chunk(", world")), + Ok(ResponseEvent::Completed { + response_id: "r1".to_string(), + token_usage: None, + }), + ]; + + let stream = stream::iter(events).aggregate(); + let collected: Vec<_> = stream.map(Result::unwrap).collect().await; + + let expected = vec![ + ResponseEvent::OutputItemDone(ResponseItem::Message { + role: "assistant".into(), + content: vec![ContentItem::OutputText { + text: "Hello, world".into(), + }], + }), + ResponseEvent::Completed { + response_id: "r1".into(), + token_usage: None, + }, + ]; + + assert_eq!( + collected, expected, + "aggregated assistant message + Completed" + ); + } + + #[tokio::test] + async fn forwards_non_text_items_without_merging() { + let func_call = ResponseItem::FunctionCall { + name: "shell".to_string(), + arguments: "{}".to_string(), + call_id: "call1".to_string(), + }; + + let events = vec![ + Ok(text_chunk("foo")), + Ok(ResponseEvent::OutputItemDone(func_call.clone())), + Ok(text_chunk("bar")), + Ok(ResponseEvent::Completed { + response_id: "r2".to_string(), + token_usage: None, + }), + ]; + + let stream = stream::iter(events).aggregate(); + let collected: Vec<_> = stream.map(Result::unwrap).collect().await; + + let expected = vec![ + ResponseEvent::OutputItemDone(func_call.clone()), + ResponseEvent::OutputItemDone(ResponseItem::Message { + role: "assistant".into(), + content: vec![ContentItem::OutputText { + text: "foobar".into(), + }], + }), + ResponseEvent::Completed { + response_id: "r2".into(), + token_usage: None, + }, + ]; + + assert_eq!( + collected, expected, + "non-text items forwarded intact; text merged" + ); + } + + #[tokio::test] + async fn 
formats_tool_calls_in_chat_payload() { + use std::sync::Arc; + use std::sync::Mutex; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::Request; + use wiremock::Respond; + use wiremock::ResponseTemplate; + use wiremock::matchers::method; + use wiremock::matchers::path; + + struct CaptureResponder(Arc<Mutex<Option<serde_json::Value>>>); + impl Respond for CaptureResponder { + fn respond(&self, req: &Request) -> ResponseTemplate { + let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap(); + *self.0.lock().unwrap() = Some(v); + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw( + "event: response.completed\n\ +data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n", + "text/event-stream", + ) + } + } + + let server = MockServer::start().await; + let captured = Arc::new(Mutex::new(None)); + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(CaptureResponder(captured.clone())) + .expect(1) + .mount(&server) + .await; + + // Build provider pointing at mock server; no need to mutate global env vars. 
+ let provider = ModelProviderInfo { + name: "openai".into(), + base_url: format!("{}/v1", server.uri()), + env_key: Some("PATH".into()), + env_key_instructions: None, + wire_api: crate::WireApi::Chat, + query_params: None, + http_headers: None, + env_http_headers: None, + }; + + let mut prompt = Prompt::default(); + prompt.input.push(ResponseItem::Message { + role: "user".into(), + content: vec![ContentItem::InputText { text: "hi".into() }], + }); + prompt.input.push(ResponseItem::FunctionCall { + name: "shell".into(), + arguments: "[]".into(), + call_id: "call123".into(), + }); + prompt.input.push(ResponseItem::FunctionCallOutput { + call_id: "call123".into(), + output: FunctionCallOutputPayload { + content: "ok".into(), + success: Some(true), + }, + }); + prompt.input.push(ResponseItem::LocalShellCall { + id: Some("ls1".into()), + call_id: Some("call456".into()), + status: LocalShellStatus::Completed, + action: LocalShellAction::Exec(LocalShellExecAction { + command: vec!["echo".into(), "hi".into()], + timeout_ms: Some(1), + working_directory: None, + env: None, + user: None, + }), + }); + + let client = reqwest::Client::new(); + let _ = stream_chat_completions(&prompt, "model", &client, &provider) + .await + .unwrap(); + + let body = captured.lock().unwrap().take().unwrap(); + + // Build the expected payload exactly as stream_chat_completions() should. 
+ let full_instructions = prompt.get_full_instructions("model"); + let expected_messages = vec![ + json!({"role":"system","content":full_instructions}), + json!({"role":"user","content":"hi"}), + json!({ + "role":"assistant", + "content":null, + "tool_calls":[{ + "id":"call123", + "type":"function", + "function":{ + "name":"shell", + "arguments":"[]" + } + }] + }), + json!({ + "role":"tool", + "tool_call_id":"call123", + "content":"ok" + }), + json!({ + "role":"assistant", + "content":null, + "tool_calls":[{ + "id":"ls1", + "type":"local_shell_call", + "status":"completed", + "action":{ + "type":"exec", + "command":["echo","hi"], + "timeout_ms":1, + "working_directory":null, + "env":null, + "user":null + } + }] + }), + ]; + let tools_json = create_tools_json_for_chat_completions_api(&prompt, "model").unwrap(); + let expected_body = json!({ + "model":"model", + "messages": expected_messages, + "stream": true, + "tools": tools_json, + }); + + assert_eq!(body, expected_body, "chat payload encoded incorrectly"); + } +} diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 8ec68d02e8..b2ff284fd0 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -317,7 +317,7 @@ where // duplicated `output` array embedded in the `response.completed` // payload. That produced two concrete issues: // 1. No real‑time streaming – the user only saw output after the - // entire turn had finished, which broke the “typing” UX and + // entire turn had finished, which broke the "typing" UX and // made long‑running turns look stalled. // 2. 
Duplicate `function_call_output` items – both the // individual *and* the completed array were forwarded, which @@ -390,6 +390,7 @@ where } /// used in tests to stream from a text SSE file +#[allow(dead_code)] async fn stream_from_fixture(path: impl AsRef) -> Result { let (tx_event, rx_event) = mpsc::channel::>(1600); let f = std::fs::File::open(path.as_ref())?; @@ -413,6 +414,8 @@ mod tests { #![allow(clippy::expect_used, clippy::unwrap_used)] use super::*; + use crate::models::LocalShellAction; + use crate::models::LocalShellStatus; use serde_json::json; use tokio::sync::mpsc; use tokio_test::io::Builder as IoBuilder; @@ -422,6 +425,17 @@ mod tests { // Helpers // ──────────────────────────── + /// Build a tiny SSE string with the provided *raw* event chunks (already formatted as + /// `"event: ...\ndata: ..."` lines). Each chunk is separated by a blank line. + fn build_sse(chunks: &[&str]) -> String { + let mut out = String::new(); + for c in chunks { + out.push_str(c); + out.push_str("\n\n"); + } + out + } + /// Runs the SSE parser on pre-chunked byte slices and returns every event /// (including any final `Err` from a stream-closure check). 
async fn collect_events(chunks: &[&[u8]]) -> Vec> { @@ -469,6 +483,65 @@ mod tests { out } + // ──────────────────────────── + // Tests from `implement-unit-tests-for-event-aggregation-and-tool-calls` + // ──────────────────────────── + + #[tokio::test] + async fn parses_function_and_local_shell_items() { + let func = "event: response.output_item.done\n\ +data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call_output\",\"call_id\":\"call1\",\"output\":{\"content\":\"ok\",\"success\":true}}}"; + let shell = "event: response.output_item.done\n\ +data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"local_shell_call\",\"id\":\"ls1\",\"call_id\":\"call2\",\"status\":\"in_progress\",\"action\":{\"type\":\"exec\",\"command\":[\"echo\",\"hi\"],\"timeout_ms\":123,\"working_directory\":null,\"env\":null,\"user\":null}}}"; + let done = "event: response.completed\n\ +data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}"; + + let content = build_sse(&[func, shell, done]); + + let (tx, mut rx) = tokio::sync::mpsc::channel::>(8); + let stream = ReaderStream::new(std::io::Cursor::new(content)).map_err(CodexErr::Io); + tokio::spawn(super::process_sse(stream, tx)); + + // function_call_output + match rx.recv().await.unwrap().unwrap() { + ResponseEvent::OutputItemDone(ResponseItem::FunctionCallOutput { call_id, output }) => { + assert_eq!(call_id, "call1"); + assert_eq!(output.content, "ok"); + assert_eq!(output.success, Some(true)); + } + other => panic!("unexpected first event: {other:?}"), + } + + // local_shell_call + match rx.recv().await.unwrap().unwrap() { + ResponseEvent::OutputItemDone(ResponseItem::LocalShellCall { + id, + call_id, + status, + action, + }) => { + assert_eq!(id.as_deref(), Some("ls1")); + assert_eq!(call_id.as_deref(), Some("call2")); + if !matches!(status, LocalShellStatus::InProgress) { + panic!("unexpected status: {status:?}"); + } + match action { + LocalShellAction::Exec(act) => { + 
assert_eq!(act.command, vec!["echo".to_string(), "hi".to_string()]); + assert_eq!(act.timeout_ms, Some(123)); + } + } + } + other => panic!("unexpected second event: {other:?}"), + } + + // completed + assert!(matches!( + rx.recv().await.unwrap().unwrap(), + ResponseEvent::Completed { response_id, .. } if response_id == "resp" + )); + } + // ──────────────────────────── // Tests from `implement-test-for-responses-api-sse-parser` // ──────────────────────────── @@ -549,6 +622,7 @@ mod tests { let events = collect_events(&[sse1.as_bytes()]).await; + // We expect the item + a final Err complaining about the missing completed event. assert_eq!(events.len(), 2); matches!(events[0], Ok(ResponseEvent::OutputItemDone(_))); diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index 3e3c2e7efa..6b220d4fff 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -49,7 +49,7 @@ impl Prompt { } } -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub enum ResponseEvent { Created, OutputItemDone(ResponseItem), diff --git a/codex-rs/core/src/models.rs b/codex-rs/core/src/models.rs index 6b392fb19d..26babba715 100644 --- a/codex-rs/core/src/models.rs +++ b/codex-rs/core/src/models.rs @@ -8,7 +8,7 @@ use serde::ser::Serializer; use crate::protocol::InputItem; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ResponseInputItem { Message { @@ -25,7 +25,7 @@ pub enum ResponseInputItem { }, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ContentItem { InputText { text: String }, @@ -33,7 +33,7 @@ pub enum ContentItem { OutputText { text: String }, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(tag = "type", 
rename_all = "snake_case")] pub enum ResponseItem { Message { @@ -99,7 +99,7 @@ impl From for ResponseItem { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "snake_case")] pub enum LocalShellStatus { Completed, @@ -107,13 +107,13 @@ pub enum LocalShellStatus { Incomplete, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(tag = "type", rename_all = "snake_case")] pub enum LocalShellAction { Exec(LocalShellExecAction), } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct LocalShellExecAction { pub command: Vec, pub timeout_ms: Option, @@ -122,7 +122,7 @@ pub struct LocalShellExecAction { pub user: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ReasoningItemReasoningSummary { SummaryText { text: String }, @@ -177,10 +177,10 @@ pub struct ShellToolCallParams { pub timeout_ms: Option, } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, PartialEq)] pub struct FunctionCallOutputPayload { pub content: String, - #[expect(dead_code)] + #[allow(dead_code)] pub success: Option, } diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs index b233d4f27b..c14b2e190a 100644 --- a/codex-rs/core/src/protocol.rs +++ b/codex-rs/core/src/protocol.rs @@ -332,7 +332,7 @@ pub struct TaskCompleteEvent { pub last_agent_message: Option, } -#[derive(Debug, Clone, Deserialize, Serialize, Default)] +#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)] pub struct TokenUsage { pub input_tokens: u64, pub cached_input_tokens: Option, ``` ## Review Comments ### codex-rs/core/src/chat_completions.rs - Created: 2025-07-12 19:43:31 UTC | Link: 
https://github.com/openai/codex/pull/1538#discussion_r2202890339 ```diff @@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream> + Size } impl AggregateStreamExt for T where T: Stream> + Sized {} + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + + use super::*; + use crate::models::FunctionCallOutputPayload; + use crate::models::LocalShellAction; + use crate::models::LocalShellExecAction; + use crate::models::LocalShellStatus; + use futures::StreamExt; + use futures::stream; + + /// Helper constructing a minimal assistant text chunk. + fn text_chunk(txt: &str) -> ResponseEvent { + ResponseEvent::OutputItemDone(ResponseItem::Message { + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { text: txt.into() }], + }) + } + + #[tokio::test] + async fn aggregates_consecutive_message_chunks() { + let events = vec![ + Ok(text_chunk("Hello")), + Ok(text_chunk(", world")), + Ok(ResponseEvent::Completed { + response_id: "r1".to_string(), + token_usage: None, + }), + ]; + + let stream = stream::iter(events).aggregate(); + let collected: Vec<_> = stream.map(Result::unwrap).collect().await; + + assert_eq!(collected.len(), 2, "only final message and Completed"); ``` > just `assert_eq!()` on all of `collected`? - Created: 2025-07-12 19:44:08 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202890464 ```diff @@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream> + Size } impl AggregateStreamExt for T where T: Stream> + Sized {} + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + + use super::*; + use crate::models::FunctionCallOutputPayload; + use crate::models::LocalShellAction; + use crate::models::LocalShellExecAction; + use crate::models::LocalShellStatus; + use futures::StreamExt; + use futures::stream; + + /// Helper constructing a minimal assistant text chunk. 
+ fn text_chunk(txt: &str) -> ResponseEvent { + ResponseEvent::OutputItemDone(ResponseItem::Message { + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { text: txt.into() }], + }) + } + + #[tokio::test] + async fn aggregates_consecutive_message_chunks() { + let events = vec![ + Ok(text_chunk("Hello")), + Ok(text_chunk(", world")), + Ok(ResponseEvent::Completed { + response_id: "r1".to_string(), + token_usage: None, + }), + ]; + + let stream = stream::iter(events).aggregate(); + let collected: Vec<_> = stream.map(Result::unwrap).collect().await; + + assert_eq!(collected.len(), 2, "only final message and Completed"); + + match &collected[0] { + ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => { + let text = match &content[0] { + ContentItem::OutputText { text } => text, + _ => panic!("unexpected content item"), + }; + assert_eq!(text, "Hello, world"); + } + other => panic!("unexpected first event: {other:?}"), + } + + assert!(matches!( + collected[1], + ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1" + )); + } + + #[tokio::test] + async fn forwards_non_text_items_without_merging() { + let func_call = ResponseItem::FunctionCall { + name: "shell".to_string(), + arguments: "{}".to_string(), + call_id: "call1".to_string(), + }; + + let events = vec![ + Ok(text_chunk("foo")), + Ok(ResponseEvent::OutputItemDone(func_call.clone())), + Ok(text_chunk("bar")), + Ok(ResponseEvent::Completed { + response_id: "r2".to_string(), + token_usage: None, + }), + ]; + + let stream = stream::iter(events).aggregate(); + let collected: Vec<_> = stream.map(Result::unwrap).collect().await; + + assert_eq!(collected.len(), 3); ``` > same here - Created: 2025-07-12 19:45:40 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202891601 ```diff @@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream> + Size } impl AggregateStreamExt for T where T: Stream> + Sized {} + +#[cfg(test)] +mod tests { 
+ #![allow(clippy::unwrap_used)] + + use super::*; + use crate::models::FunctionCallOutputPayload; + use crate::models::LocalShellAction; + use crate::models::LocalShellExecAction; + use crate::models::LocalShellStatus; + use futures::StreamExt; + use futures::stream; + + /// Helper constructing a minimal assistant text chunk. + fn text_chunk(txt: &str) -> ResponseEvent { + ResponseEvent::OutputItemDone(ResponseItem::Message { + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { text: txt.into() }], + }) + } + + #[tokio::test] + async fn aggregates_consecutive_message_chunks() { + let events = vec![ + Ok(text_chunk("Hello")), + Ok(text_chunk(", world")), + Ok(ResponseEvent::Completed { + response_id: "r1".to_string(), + token_usage: None, + }), + ]; + + let stream = stream::iter(events).aggregate(); + let collected: Vec<_> = stream.map(Result::unwrap).collect().await; + + assert_eq!(collected.len(), 2, "only final message and Completed"); + + match &collected[0] { + ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. 
}) => { + let text = match &content[0] { + ContentItem::OutputText { text } => text, + _ => panic!("unexpected content item"), + }; + assert_eq!(text, "Hello, world"); + } + other => panic!("unexpected first event: {other:?}"), + } + + assert!(matches!( + collected[1], + ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1" + )); + } + + #[tokio::test] + async fn forwards_non_text_items_without_merging() { + let func_call = ResponseItem::FunctionCall { + name: "shell".to_string(), + arguments: "{}".to_string(), + call_id: "call1".to_string(), + }; + + let events = vec![ + Ok(text_chunk("foo")), + Ok(ResponseEvent::OutputItemDone(func_call.clone())), + Ok(text_chunk("bar")), + Ok(ResponseEvent::Completed { + response_id: "r2".to_string(), + token_usage: None, + }), + ]; + + let stream = stream::iter(events).aggregate(); + let collected: Vec<_> = stream.map(Result::unwrap).collect().await; + + assert_eq!(collected.len(), 3); + + // First event should be the function call forwarded directly. + assert!(matches!( + collected[0], + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. }) + )); + + // Second is the combined assistant message. + match &collected[1] { + ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => { + let text = match &content[0] { + ContentItem::OutputText { text } => text, + _ => panic!("unexpected content item"), + }; + assert_eq!(text, "foobar"); + } + other => panic!("unexpected second event: {other:?}"), + } + + // Final Completed event. 
+ assert!(matches!( + collected[2], + ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r2" + )); + } + + #[tokio::test] + async fn formats_tool_calls_in_chat_payload() { + use serde_json::Value; + use std::sync::Arc; + use std::sync::Mutex; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::Request; + use wiremock::Respond; + use wiremock::ResponseTemplate; + use wiremock::matchers::method; + use wiremock::matchers::path; + + struct CaptureResponder(Arc<Mutex<Option<Value>>>); + impl Respond for CaptureResponder { + fn respond(&self, req: &Request) -> ResponseTemplate { + let v: Value = serde_json::from_slice(&req.body).unwrap(); + *self.0.lock().unwrap() = Some(v); + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw( + "event: response.completed\n\ +data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n", + "text/event-stream", + ) + } + } + + let server = MockServer::start().await; + let captured = Arc::new(Mutex::new(None)); + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(CaptureResponder(captured.clone())) + .expect(1) + .mount(&server) + .await; + + unsafe { + std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0"); ``` > At some point, we should really find another way to thread this through so we can eliminate all these `unsafe` blocks. 
- Created: 2025-07-12 19:46:18 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202892586 ```diff @@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream> + Size } impl AggregateStreamExt for T where T: Stream> + Sized {} + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + + use super::*; + use crate::models::FunctionCallOutputPayload; + use crate::models::LocalShellAction; + use crate::models::LocalShellExecAction; + use crate::models::LocalShellStatus; + use futures::StreamExt; + use futures::stream; + + /// Helper constructing a minimal assistant text chunk. + fn text_chunk(txt: &str) -> ResponseEvent { + ResponseEvent::OutputItemDone(ResponseItem::Message { + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { text: txt.into() }], + }) + } + + #[tokio::test] + async fn aggregates_consecutive_message_chunks() { + let events = vec![ + Ok(text_chunk("Hello")), + Ok(text_chunk(", world")), + Ok(ResponseEvent::Completed { + response_id: "r1".to_string(), + token_usage: None, + }), + ]; + + let stream = stream::iter(events).aggregate(); + let collected: Vec<_> = stream.map(Result::unwrap).collect().await; + + assert_eq!(collected.len(), 2, "only final message and Completed"); + + match &collected[0] { + ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. 
}) => { + let text = match &content[0] { + ContentItem::OutputText { text } => text, + _ => panic!("unexpected content item"), + }; + assert_eq!(text, "Hello, world"); + } + other => panic!("unexpected first event: {other:?}"), + } + + assert!(matches!( + collected[1], + ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1" + )); + } + + #[tokio::test] + async fn forwards_non_text_items_without_merging() { + let func_call = ResponseItem::FunctionCall { + name: "shell".to_string(), + arguments: "{}".to_string(), + call_id: "call1".to_string(), + }; + + let events = vec![ + Ok(text_chunk("foo")), + Ok(ResponseEvent::OutputItemDone(func_call.clone())), + Ok(text_chunk("bar")), + Ok(ResponseEvent::Completed { + response_id: "r2".to_string(), + token_usage: None, + }), + ]; + + let stream = stream::iter(events).aggregate(); + let collected: Vec<_> = stream.map(Result::unwrap).collect().await; + + assert_eq!(collected.len(), 3); + + // First event should be the function call forwarded directly. + assert!(matches!( + collected[0], + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. }) + )); + + // Second is the combined assistant message. + match &collected[1] { + ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => { + let text = match &content[0] { + ContentItem::OutputText { text } => text, + _ => panic!("unexpected content item"), + }; + assert_eq!(text, "foobar"); + } + other => panic!("unexpected second event: {other:?}"), + } + + // Final Completed event. 
+ assert!(matches!( + collected[2], + ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r2" + )); + } + + #[tokio::test] + async fn formats_tool_calls_in_chat_payload() { + use serde_json::Value; + use std::sync::Arc; + use std::sync::Mutex; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::Request; + use wiremock::Respond; + use wiremock::ResponseTemplate; + use wiremock::matchers::method; + use wiremock::matchers::path; + + struct CaptureResponder(Arc<Mutex<Option<Value>>>); + impl Respond for CaptureResponder { + fn respond(&self, req: &Request) -> ResponseTemplate { + let v: Value = serde_json::from_slice(&req.body).unwrap(); + *self.0.lock().unwrap() = Some(v); + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw( + "event: response.completed\n\ +data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n", + "text/event-stream", + ) + } + } + + let server = MockServer::start().await; + let captured = Arc::new(Mutex::new(None)); + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(CaptureResponder(captured.clone())) + .expect(1) + .mount(&server) + .await; + + unsafe { + std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0"); + } + + let provider = ModelProviderInfo { + name: "openai".into(), + base_url: format!("{}/v1", server.uri()), + env_key: Some("PATH".into()), + env_key_instructions: None, + wire_api: crate::WireApi::Chat, + query_params: None, + http_headers: None, + env_http_headers: None, + }; + + let mut prompt = Prompt::default(); + prompt.input.push(ResponseItem::Message { + role: "user".into(), + content: vec![ContentItem::InputText { text: "hi".into() }], + }); + prompt.input.push(ResponseItem::FunctionCall { + name: "shell".into(), + arguments: "[]".into(), + call_id: "call123".into(), + }); + prompt.input.push(ResponseItem::FunctionCallOutput { + call_id: "call123".into(), + output: FunctionCallOutputPayload { + 
content: "ok".into(), + success: Some(true), + }, + }); + prompt.input.push(ResponseItem::LocalShellCall { + id: Some("ls1".into()), + call_id: Some("call456".into()), + status: LocalShellStatus::Completed, + action: LocalShellAction::Exec(LocalShellExecAction { + command: vec!["echo".into(), "hi".into()], + timeout_ms: Some(1), + working_directory: None, + env: None, + user: None, + }), + }); + + let client = reqwest::Client::new(); + let _ = stream_chat_completions(&prompt, "model", &client, &provider) + .await + .unwrap(); + + let body = captured.lock().unwrap().take().unwrap(); + let messages = body.get("messages").unwrap().as_array().unwrap(); ``` > `assert_eq!()` for `body`