Files
codex/prs/bolinfest/PR-1538.md
2025-09-02 15:17:45 -07:00

35 KiB
Raw Blame History

PR #1538: Add tests for chat stream aggregation and tool events

Description

Summary

  • unit test AggregatedChatStream to ensure it merges assistant message deltas and forwards other items
  • verify parsing of function_call_output and local_shell_call SSE events
  • ensure chat request payload encodes tool calls correctly

Testing

  • cargo test -p codex-core --manifest-path codex-rs/Cargo.toml
  • cargo test --manifest-path codex-rs/Cargo.toml --all --tests (fails: Sandbox(LandlockRestrict))

https://chatgpt.com/codex/tasks/task_i_687158d61e748321ba5f1631199bd8a4

Full Diff

diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index ad7b55952a..8eabcaf342 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -458,6 +458,9 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
     ///     // event now contains cumulative text
     /// }
     /// ```
+    ///
+    /// See [`tests::aggregates_consecutive_message_chunks`] for an example.
+    /// ```
     fn aggregate(self) -> AggregatedChatStream<Self> {
         AggregatedChatStream {
             inner: self,
@@ -468,3 +471,237 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
 }
 
 impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+
+    use super::*;
+    use crate::models::FunctionCallOutputPayload;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellExecAction;
+    use crate::models::LocalShellStatus;
+    use crate::openai_tools::create_tools_json_for_chat_completions_api;
+    use futures::StreamExt;
+    use futures::stream;
+    use serde_json::json;
+
+    /// Helper constructing a minimal assistant text chunk.
+    fn text_chunk(txt: &str) -> ResponseEvent {
+        ResponseEvent::OutputItemDone(ResponseItem::Message {
+            role: "assistant".to_string(),
+            content: vec![ContentItem::OutputText { text: txt.into() }],
+        })
+    }
+
+    #[tokio::test]
+    async fn aggregates_consecutive_message_chunks() {
+        let events = vec![
+            Ok(text_chunk("Hello")),
+            Ok(text_chunk(", world")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r1".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        let expected = vec![
+            ResponseEvent::OutputItemDone(ResponseItem::Message {
+                role: "assistant".into(),
+                content: vec![ContentItem::OutputText {
+                    text: "Hello, world".into(),
+                }],
+            }),
+            ResponseEvent::Completed {
+                response_id: "r1".into(),
+                token_usage: None,
+            },
+        ];
+
+        assert_eq!(
+            collected, expected,
+            "aggregated assistant message + Completed"
+        );
+    }
+
+    #[tokio::test]
+    async fn forwards_non_text_items_without_merging() {
+        let func_call = ResponseItem::FunctionCall {
+            name: "shell".to_string(),
+            arguments: "{}".to_string(),
+            call_id: "call1".to_string(),
+        };
+
+        let events = vec![
+            Ok(text_chunk("foo")),
+            Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+            Ok(text_chunk("bar")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r2".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        let expected = vec![
+            ResponseEvent::OutputItemDone(func_call.clone()),
+            ResponseEvent::OutputItemDone(ResponseItem::Message {
+                role: "assistant".into(),
+                content: vec![ContentItem::OutputText {
+                    text: "foobar".into(),
+                }],
+            }),
+            ResponseEvent::Completed {
+                response_id: "r2".into(),
+                token_usage: None,
+            },
+        ];
+
+        assert_eq!(
+            collected, expected,
+            "non-text items forwarded intact; text merged"
+        );
+    }
+
+    #[tokio::test]
+    async fn formats_tool_calls_in_chat_payload() {
+        use std::sync::Arc;
+        use std::sync::Mutex;
+        use wiremock::Mock;
+        use wiremock::MockServer;
+        use wiremock::Request;
+        use wiremock::Respond;
+        use wiremock::ResponseTemplate;
+        use wiremock::matchers::method;
+        use wiremock::matchers::path;
+
+        struct CaptureResponder(Arc<Mutex<Option<serde_json::Value>>>);
+        impl Respond for CaptureResponder {
+            fn respond(&self, req: &Request) -> ResponseTemplate {
+                let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap();
+                *self.0.lock().unwrap() = Some(v);
+                ResponseTemplate::new(200)
+                    .insert_header("content-type", "text/event-stream")
+                    .set_body_raw(
+                        "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n",
+                        "text/event-stream",
+                    )
+            }
+        }
+
+        let server = MockServer::start().await;
+        let captured = Arc::new(Mutex::new(None));
+
+        Mock::given(method("POST"))
+            .and(path("/v1/chat/completions"))
+            .respond_with(CaptureResponder(captured.clone()))
+            .expect(1)
+            .mount(&server)
+            .await;
+
+        // Build provider pointing at mock server; no need to mutate global env vars.
+        let provider = ModelProviderInfo {
+            name: "openai".into(),
+            base_url: format!("{}/v1", server.uri()),
+            env_key: Some("PATH".into()),
+            env_key_instructions: None,
+            wire_api: crate::WireApi::Chat,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+        };
+
+        let mut prompt = Prompt::default();
+        prompt.input.push(ResponseItem::Message {
+            role: "user".into(),
+            content: vec![ContentItem::InputText { text: "hi".into() }],
+        });
+        prompt.input.push(ResponseItem::FunctionCall {
+            name: "shell".into(),
+            arguments: "[]".into(),
+            call_id: "call123".into(),
+        });
+        prompt.input.push(ResponseItem::FunctionCallOutput {
+            call_id: "call123".into(),
+            output: FunctionCallOutputPayload {
+                content: "ok".into(),
+                success: Some(true),
+            },
+        });
+        prompt.input.push(ResponseItem::LocalShellCall {
+            id: Some("ls1".into()),
+            call_id: Some("call456".into()),
+            status: LocalShellStatus::Completed,
+            action: LocalShellAction::Exec(LocalShellExecAction {
+                command: vec!["echo".into(), "hi".into()],
+                timeout_ms: Some(1),
+                working_directory: None,
+                env: None,
+                user: None,
+            }),
+        });
+
+        let client = reqwest::Client::new();
+        let _ = stream_chat_completions(&prompt, "model", &client, &provider)
+            .await
+            .unwrap();
+
+        let body = captured.lock().unwrap().take().unwrap();
+
+        // Build the expected payload exactly as stream_chat_completions() should.
+        let full_instructions = prompt.get_full_instructions("model");
+        let expected_messages = vec![
+            json!({"role":"system","content":full_instructions}),
+            json!({"role":"user","content":"hi"}),
+            json!({
+                "role":"assistant",
+                "content":null,
+                "tool_calls":[{
+                    "id":"call123",
+                    "type":"function",
+                    "function":{
+                        "name":"shell",
+                        "arguments":"[]"
+                    }
+                }]
+            }),
+            json!({
+                "role":"tool",
+                "tool_call_id":"call123",
+                "content":"ok"
+            }),
+            json!({
+                "role":"assistant",
+                "content":null,
+                "tool_calls":[{
+                    "id":"ls1",
+                    "type":"local_shell_call",
+                    "status":"completed",
+                    "action":{
+                        "type":"exec",
+                        "command":["echo","hi"],
+                        "timeout_ms":1,
+                        "working_directory":null,
+                        "env":null,
+                        "user":null
+                    }
+                }]
+            }),
+        ];
+        let tools_json = create_tools_json_for_chat_completions_api(&prompt, "model").unwrap();
+        let expected_body = json!({
+            "model":"model",
+            "messages": expected_messages,
+            "stream": true,
+            "tools": tools_json,
+        });
+
+        assert_eq!(body, expected_body, "chat payload encoded incorrectly");
+    }
+}
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 8ec68d02e8..b2ff284fd0 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -317,7 +317,7 @@ where
             // duplicated `output` array embedded in the `response.completed`
             // payload.  That produced two concrete issues:
             //   1. No real-time streaming – the user only saw output after the
-            //      entire turn had finished, which broke the “typing” UX and
+            //      entire turn had finished, which broke the "typing" UX and
             //      made long-running turns look stalled.
             //   2. Duplicate `function_call_output` items – both the
             //      individual *and* the completed array were forwarded, which
@@ -390,6 +390,7 @@ where
 }
 
 /// used in tests to stream from a text SSE file
+#[allow(dead_code)]
 async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
     let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
     let f = std::fs::File::open(path.as_ref())?;
@@ -413,6 +414,8 @@ mod tests {
     #![allow(clippy::expect_used, clippy::unwrap_used)]
 
     use super::*;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellStatus;
     use serde_json::json;
     use tokio::sync::mpsc;
     use tokio_test::io::Builder as IoBuilder;
@@ -422,6 +425,17 @@ mod tests {
     // Helpers
     // ────────────────────────────
 
+    /// Build a tiny SSE string with the provided *raw* event chunks (already formatted as
+    /// `"event: ...\ndata: ..."` lines). Each chunk is separated by a blank line.
+    fn build_sse(chunks: &[&str]) -> String {
+        let mut out = String::new();
+        for c in chunks {
+            out.push_str(c);
+            out.push_str("\n\n");
+        }
+        out
+    }
+
     /// Runs the SSE parser on pre-chunked byte slices and returns every event
     /// (including any final `Err` from a stream-closure check).
     async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent>> {
@@ -469,6 +483,65 @@ mod tests {
         out
     }
 
+    // ────────────────────────────
+    // Tests from `implement-unit-tests-for-event-aggregation-and-tool-calls`
+    // ────────────────────────────
+
+    #[tokio::test]
+    async fn parses_function_and_local_shell_items() {
+        let func = "event: response.output_item.done\n\
+data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call_output\",\"call_id\":\"call1\",\"output\":{\"content\":\"ok\",\"success\":true}}}";
+        let shell = "event: response.output_item.done\n\
+data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"local_shell_call\",\"id\":\"ls1\",\"call_id\":\"call2\",\"status\":\"in_progress\",\"action\":{\"type\":\"exec\",\"command\":[\"echo\",\"hi\"],\"timeout_ms\":123,\"working_directory\":null,\"env\":null,\"user\":null}}}";
+        let done = "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}";
+
+        let content = build_sse(&[func, shell, done]);
+
+        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<ResponseEvent>>(8);
+        let stream = ReaderStream::new(std::io::Cursor::new(content)).map_err(CodexErr::Io);
+        tokio::spawn(super::process_sse(stream, tx));
+
+        // function_call_output
+        match rx.recv().await.unwrap().unwrap() {
+            ResponseEvent::OutputItemDone(ResponseItem::FunctionCallOutput { call_id, output }) => {
+                assert_eq!(call_id, "call1");
+                assert_eq!(output.content, "ok");
+                assert_eq!(output.success, Some(true));
+            }
+            other => panic!("unexpected first event: {other:?}"),
+        }
+
+        // local_shell_call
+        match rx.recv().await.unwrap().unwrap() {
+            ResponseEvent::OutputItemDone(ResponseItem::LocalShellCall {
+                id,
+                call_id,
+                status,
+                action,
+            }) => {
+                assert_eq!(id.as_deref(), Some("ls1"));
+                assert_eq!(call_id.as_deref(), Some("call2"));
+                if !matches!(status, LocalShellStatus::InProgress) {
+                    panic!("unexpected status: {status:?}");
+                }
+                match action {
+                    LocalShellAction::Exec(act) => {
+                        assert_eq!(act.command, vec!["echo".to_string(), "hi".to_string()]);
+                        assert_eq!(act.timeout_ms, Some(123));
+                    }
+                }
+            }
+            other => panic!("unexpected second event: {other:?}"),
+        }
+
+        // completed
+        assert!(matches!(
+            rx.recv().await.unwrap().unwrap(),
+            ResponseEvent::Completed { response_id, .. } if response_id == "resp"
+        ));
+    }
+
     // ────────────────────────────
     // Tests from `implement-test-for-responses-api-sse-parser`
     // ────────────────────────────
@@ -549,6 +622,7 @@ mod tests {
 
         let events = collect_events(&[sse1.as_bytes()]).await;
 
+        // We expect the item + a final Err complaining about the missing completed event.
         assert_eq!(events.len(), 2);
 
         matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)));
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index 3e3c2e7efa..6b220d4fff 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -49,7 +49,7 @@ impl Prompt {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone, PartialEq)]
 pub enum ResponseEvent {
     Created,
     OutputItemDone(ResponseItem),
diff --git a/codex-rs/core/src/models.rs b/codex-rs/core/src/models.rs
index 6b392fb19d..26babba715 100644
--- a/codex-rs/core/src/models.rs
+++ b/codex-rs/core/src/models.rs
@@ -8,7 +8,7 @@ use serde::ser::Serializer;
 
 use crate::protocol::InputItem;
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum ResponseInputItem {
     Message {
@@ -25,7 +25,7 @@ pub enum ResponseInputItem {
     },
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum ContentItem {
     InputText { text: String },
@@ -33,7 +33,7 @@ pub enum ContentItem {
     OutputText { text: String },
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum ResponseItem {
     Message {
@@ -99,7 +99,7 @@ impl From<ResponseInputItem> for ResponseItem {
     }
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(rename_all = "snake_case")]
 pub enum LocalShellStatus {
     Completed,
@@ -107,13 +107,13 @@ pub enum LocalShellStatus {
     Incomplete,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum LocalShellAction {
     Exec(LocalShellExecAction),
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct LocalShellExecAction {
     pub command: Vec<String>,
     pub timeout_ms: Option<u64>,
@@ -122,7 +122,7 @@ pub struct LocalShellExecAction {
     pub user: Option<String>,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum ReasoningItemReasoningSummary {
     SummaryText { text: String },
@@ -177,10 +177,10 @@ pub struct ShellToolCallParams {
     pub timeout_ms: Option<u64>,
 }
 
-#[derive(Deserialize, Debug, Clone)]
+#[derive(Deserialize, Debug, Clone, PartialEq)]
 pub struct FunctionCallOutputPayload {
     pub content: String,
-    #[expect(dead_code)]
+    #[allow(dead_code)]
     pub success: Option<bool>,
 }
 
diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs
index b233d4f27b..c14b2e190a 100644
--- a/codex-rs/core/src/protocol.rs
+++ b/codex-rs/core/src/protocol.rs
@@ -332,7 +332,7 @@ pub struct TaskCompleteEvent {
     pub last_agent_message: Option<String>,
 }
 
-#[derive(Debug, Clone, Deserialize, Serialize, Default)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)]
 pub struct TokenUsage {
     pub input_tokens: u64,
     pub cached_input_tokens: Option<u64>,

Review Comments

codex-rs/core/src/chat_completions.rs

@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
 }
 
 impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+
+    use super::*;
+    use crate::models::FunctionCallOutputPayload;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellExecAction;
+    use crate::models::LocalShellStatus;
+    use futures::StreamExt;
+    use futures::stream;
+
+    /// Helper constructing a minimal assistant text chunk.
+    fn text_chunk(txt: &str) -> ResponseEvent {
+        ResponseEvent::OutputItemDone(ResponseItem::Message {
+            role: "assistant".to_string(),
+            content: vec![ContentItem::OutputText { text: txt.into() }],
+        })
+    }
+
+    #[tokio::test]
+    async fn aggregates_consecutive_message_chunks() {
+        let events = vec![
+            Ok(text_chunk("Hello")),
+            Ok(text_chunk(", world")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r1".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 2, "only final message and Completed");

just assert_eq!() on all of collected?

@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
 }
 
 impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+
+    use super::*;
+    use crate::models::FunctionCallOutputPayload;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellExecAction;
+    use crate::models::LocalShellStatus;
+    use futures::StreamExt;
+    use futures::stream;
+
+    /// Helper constructing a minimal assistant text chunk.
+    fn text_chunk(txt: &str) -> ResponseEvent {
+        ResponseEvent::OutputItemDone(ResponseItem::Message {
+            role: "assistant".to_string(),
+            content: vec![ContentItem::OutputText { text: txt.into() }],
+        })
+    }
+
+    #[tokio::test]
+    async fn aggregates_consecutive_message_chunks() {
+        let events = vec![
+            Ok(text_chunk("Hello")),
+            Ok(text_chunk(", world")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r1".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 2, "only final message and Completed");
+
+        match &collected[0] {
+            ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+                let text = match &content[0] {
+                    ContentItem::OutputText { text } => text,
+                    _ => panic!("unexpected content item"),
+                };
+                assert_eq!(text, "Hello, world");
+            }
+            other => panic!("unexpected first event: {other:?}"),
+        }
+
+        assert!(matches!(
+            collected[1],
+            ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1"
+        ));
+    }
+
+    #[tokio::test]
+    async fn forwards_non_text_items_without_merging() {
+        let func_call = ResponseItem::FunctionCall {
+            name: "shell".to_string(),
+            arguments: "{}".to_string(),
+            call_id: "call1".to_string(),
+        };
+
+        let events = vec![
+            Ok(text_chunk("foo")),
+            Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+            Ok(text_chunk("bar")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r2".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 3);

same here

@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
 }
 
 impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+
+    use super::*;
+    use crate::models::FunctionCallOutputPayload;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellExecAction;
+    use crate::models::LocalShellStatus;
+    use futures::StreamExt;
+    use futures::stream;
+
+    /// Helper constructing a minimal assistant text chunk.
+    fn text_chunk(txt: &str) -> ResponseEvent {
+        ResponseEvent::OutputItemDone(ResponseItem::Message {
+            role: "assistant".to_string(),
+            content: vec![ContentItem::OutputText { text: txt.into() }],
+        })
+    }
+
+    #[tokio::test]
+    async fn aggregates_consecutive_message_chunks() {
+        let events = vec![
+            Ok(text_chunk("Hello")),
+            Ok(text_chunk(", world")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r1".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 2, "only final message and Completed");
+
+        match &collected[0] {
+            ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+                let text = match &content[0] {
+                    ContentItem::OutputText { text } => text,
+                    _ => panic!("unexpected content item"),
+                };
+                assert_eq!(text, "Hello, world");
+            }
+            other => panic!("unexpected first event: {other:?}"),
+        }
+
+        assert!(matches!(
+            collected[1],
+            ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1"
+        ));
+    }
+
+    #[tokio::test]
+    async fn forwards_non_text_items_without_merging() {
+        let func_call = ResponseItem::FunctionCall {
+            name: "shell".to_string(),
+            arguments: "{}".to_string(),
+            call_id: "call1".to_string(),
+        };
+
+        let events = vec![
+            Ok(text_chunk("foo")),
+            Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+            Ok(text_chunk("bar")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r2".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 3);
+
+        // First event should be the function call forwarded directly.
+        assert!(matches!(
+            collected[0],
+            ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. })
+        ));
+
+        // Second is the combined assistant message.
+        match &collected[1] {
+            ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+                let text = match &content[0] {
+                    ContentItem::OutputText { text } => text,
+                    _ => panic!("unexpected content item"),
+                };
+                assert_eq!(text, "foobar");
+            }
+            other => panic!("unexpected second event: {other:?}"),
+        }
+
+        // Final Completed event.
+        assert!(matches!(
+            collected[2],
+            ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r2"
+        ));
+    }
+
+    #[tokio::test]
+    async fn formats_tool_calls_in_chat_payload() {
+        use serde_json::Value;
+        use std::sync::Arc;
+        use std::sync::Mutex;
+        use wiremock::Mock;
+        use wiremock::MockServer;
+        use wiremock::Request;
+        use wiremock::Respond;
+        use wiremock::ResponseTemplate;
+        use wiremock::matchers::method;
+        use wiremock::matchers::path;
+
+        struct CaptureResponder(Arc<Mutex<Option<Value>>>);
+        impl Respond for CaptureResponder {
+            fn respond(&self, req: &Request) -> ResponseTemplate {
+                let v: Value = serde_json::from_slice(&req.body).unwrap();
+                *self.0.lock().unwrap() = Some(v);
+                ResponseTemplate::new(200)
+                    .insert_header("content-type", "text/event-stream")
+                    .set_body_raw(
+                        "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n",
+                        "text/event-stream",
+                    )
+            }
+        }
+
+        let server = MockServer::start().await;
+        let captured = Arc::new(Mutex::new(None));
+
+        Mock::given(method("POST"))
+            .and(path("/v1/chat/completions"))
+            .respond_with(CaptureResponder(captured.clone()))
+            .expect(1)
+            .mount(&server)
+            .await;
+
+        unsafe {
+            std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");

At some point, we should really find another way to thread this through so we can eliminate all these unsafe blocks.

@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
 }
 
 impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+
+    use super::*;
+    use crate::models::FunctionCallOutputPayload;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellExecAction;
+    use crate::models::LocalShellStatus;
+    use futures::StreamExt;
+    use futures::stream;
+
+    /// Helper constructing a minimal assistant text chunk.
+    fn text_chunk(txt: &str) -> ResponseEvent {
+        ResponseEvent::OutputItemDone(ResponseItem::Message {
+            role: "assistant".to_string(),
+            content: vec![ContentItem::OutputText { text: txt.into() }],
+        })
+    }
+
+    #[tokio::test]
+    async fn aggregates_consecutive_message_chunks() {
+        let events = vec![
+            Ok(text_chunk("Hello")),
+            Ok(text_chunk(", world")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r1".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 2, "only final message and Completed");
+
+        match &collected[0] {
+            ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+                let text = match &content[0] {
+                    ContentItem::OutputText { text } => text,
+                    _ => panic!("unexpected content item"),
+                };
+                assert_eq!(text, "Hello, world");
+            }
+            other => panic!("unexpected first event: {other:?}"),
+        }
+
+        assert!(matches!(
+            collected[1],
+            ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1"
+        ));
+    }
+
+    #[tokio::test]
+    async fn forwards_non_text_items_without_merging() {
+        let func_call = ResponseItem::FunctionCall {
+            name: "shell".to_string(),
+            arguments: "{}".to_string(),
+            call_id: "call1".to_string(),
+        };
+
+        let events = vec![
+            Ok(text_chunk("foo")),
+            Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+            Ok(text_chunk("bar")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r2".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 3);
+
+        // First event should be the function call forwarded directly.
+        assert!(matches!(
+            collected[0],
+            ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. })
+        ));
+
+        // Second is the combined assistant message.
+        match &collected[1] {
+            ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+                let text = match &content[0] {
+                    ContentItem::OutputText { text } => text,
+                    _ => panic!("unexpected content item"),
+                };
+                assert_eq!(text, "foobar");
+            }
+            other => panic!("unexpected second event: {other:?}"),
+        }
+
+        // Final Completed event.
+        assert!(matches!(
+            collected[2],
+            ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r2"
+        ));
+    }
+
+    #[tokio::test]
+    async fn formats_tool_calls_in_chat_payload() {
+        use serde_json::Value;
+        use std::sync::Arc;
+        use std::sync::Mutex;
+        use wiremock::Mock;
+        use wiremock::MockServer;
+        use wiremock::Request;
+        use wiremock::Respond;
+        use wiremock::ResponseTemplate;
+        use wiremock::matchers::method;
+        use wiremock::matchers::path;
+
+        struct CaptureResponder(Arc<Mutex<Option<Value>>>);
+        impl Respond for CaptureResponder {
+            fn respond(&self, req: &Request) -> ResponseTemplate {
+                let v: Value = serde_json::from_slice(&req.body).unwrap();
+                *self.0.lock().unwrap() = Some(v);
+                ResponseTemplate::new(200)
+                    .insert_header("content-type", "text/event-stream")
+                    .set_body_raw(
+                        "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n",
+                        "text/event-stream",
+                    )
+            }
+        }
+
+        let server = MockServer::start().await;
+        let captured = Arc::new(Mutex::new(None));
+
+        Mock::given(method("POST"))
+            .and(path("/v1/chat/completions"))
+            .respond_with(CaptureResponder(captured.clone()))
+            .expect(1)
+            .mount(&server)
+            .await;
+
+        unsafe {
+            std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
+        }
+
+        let provider = ModelProviderInfo {
+            name: "openai".into(),
+            base_url: format!("{}/v1", server.uri()),
+            env_key: Some("PATH".into()),
+            env_key_instructions: None,
+            wire_api: crate::WireApi::Chat,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+        };
+
+        let mut prompt = Prompt::default();
+        prompt.input.push(ResponseItem::Message {
+            role: "user".into(),
+            content: vec![ContentItem::InputText { text: "hi".into() }],
+        });
+        prompt.input.push(ResponseItem::FunctionCall {
+            name: "shell".into(),
+            arguments: "[]".into(),
+            call_id: "call123".into(),
+        });
+        prompt.input.push(ResponseItem::FunctionCallOutput {
+            call_id: "call123".into(),
+            output: FunctionCallOutputPayload {
+                content: "ok".into(),
+                success: Some(true),
+            },
+        });
+        prompt.input.push(ResponseItem::LocalShellCall {
+            id: Some("ls1".into()),
+            call_id: Some("call456".into()),
+            status: LocalShellStatus::Completed,
+            action: LocalShellAction::Exec(LocalShellExecAction {
+                command: vec!["echo".into(), "hi".into()],
+                timeout_ms: Some(1),
+                working_directory: None,
+                env: None,
+                user: None,
+            }),
+        });
+
+        let client = reqwest::Client::new();
+        let _ = stream_chat_completions(&prompt, "model", &client, &provider)
+            .await
+            .unwrap();
+
+        let body = captured.lock().unwrap().take().unwrap();
+        let messages = body.get("messages").unwrap().as_array().unwrap();

assert_eq!() for body