mirror of
https://github.com/openai/codex.git
synced 2026-04-24 22:54:54 +00:00
35 KiB
35 KiB
PR #1538: Add tests for chat stream aggregation and tool events
- URL: https://github.com/openai/codex/pull/1538
- Author: aibrahim-oai
- Created: 2025-07-11 19:03:03 UTC
- Updated: 2025-07-21 20:58:18 UTC
- Changes: +323/-12, Files changed: 5, Commits: 12
Description
Summary
- unit test AggregatedChatStream to ensure it merges assistant message deltas and forwards other items
- verify parsing of function_call_output and local_shell_call SSE events
- ensure chat request payload encodes tool calls correctly
Testing
cargo test -p codex-core --manifest-path codex-rs/Cargo.toml
cargo test --manifest-path codex-rs/Cargo.toml --all --tests (fails: Sandbox(LandlockRestrict))
https://chatgpt.com/codex/tasks/task_i_687158d61e748321ba5f1631199bd8a4
Full Diff
diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index ad7b55952a..8eabcaf342 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -458,6 +458,9 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
/// // event now contains cumulative text
/// }
/// ```
+ ///
+ /// See [`tests::aggregates_consecutive_message_chunks`] for an example.
+ /// ```
fn aggregate(self) -> AggregatedChatStream<Self> {
AggregatedChatStream {
inner: self,
@@ -468,3 +471,237 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+ #![allow(clippy::unwrap_used)]
+
+ use super::*;
+ use crate::models::FunctionCallOutputPayload;
+ use crate::models::LocalShellAction;
+ use crate::models::LocalShellExecAction;
+ use crate::models::LocalShellStatus;
+ use crate::openai_tools::create_tools_json_for_chat_completions_api;
+ use futures::StreamExt;
+ use futures::stream;
+ use serde_json::json;
+
+ /// Helper constructing a minimal assistant text chunk.
+ fn text_chunk(txt: &str) -> ResponseEvent {
+ ResponseEvent::OutputItemDone(ResponseItem::Message {
+ role: "assistant".to_string(),
+ content: vec![ContentItem::OutputText { text: txt.into() }],
+ })
+ }
+
+ #[tokio::test]
+ async fn aggregates_consecutive_message_chunks() {
+ let events = vec![
+ Ok(text_chunk("Hello")),
+ Ok(text_chunk(", world")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r1".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ let expected = vec![
+ ResponseEvent::OutputItemDone(ResponseItem::Message {
+ role: "assistant".into(),
+ content: vec![ContentItem::OutputText {
+ text: "Hello, world".into(),
+ }],
+ }),
+ ResponseEvent::Completed {
+ response_id: "r1".into(),
+ token_usage: None,
+ },
+ ];
+
+ assert_eq!(
+ collected, expected,
+ "aggregated assistant message + Completed"
+ );
+ }
+
+ #[tokio::test]
+ async fn forwards_non_text_items_without_merging() {
+ let func_call = ResponseItem::FunctionCall {
+ name: "shell".to_string(),
+ arguments: "{}".to_string(),
+ call_id: "call1".to_string(),
+ };
+
+ let events = vec![
+ Ok(text_chunk("foo")),
+ Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+ Ok(text_chunk("bar")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r2".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ let expected = vec![
+ ResponseEvent::OutputItemDone(func_call.clone()),
+ ResponseEvent::OutputItemDone(ResponseItem::Message {
+ role: "assistant".into(),
+ content: vec![ContentItem::OutputText {
+ text: "foobar".into(),
+ }],
+ }),
+ ResponseEvent::Completed {
+ response_id: "r2".into(),
+ token_usage: None,
+ },
+ ];
+
+ assert_eq!(
+ collected, expected,
+ "non-text items forwarded intact; text merged"
+ );
+ }
+
+ #[tokio::test]
+ async fn formats_tool_calls_in_chat_payload() {
+ use std::sync::Arc;
+ use std::sync::Mutex;
+ use wiremock::Mock;
+ use wiremock::MockServer;
+ use wiremock::Request;
+ use wiremock::Respond;
+ use wiremock::ResponseTemplate;
+ use wiremock::matchers::method;
+ use wiremock::matchers::path;
+
+ struct CaptureResponder(Arc<Mutex<Option<serde_json::Value>>>);
+ impl Respond for CaptureResponder {
+ fn respond(&self, req: &Request) -> ResponseTemplate {
+ let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap();
+ *self.0.lock().unwrap() = Some(v);
+ ResponseTemplate::new(200)
+ .insert_header("content-type", "text/event-stream")
+ .set_body_raw(
+ "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n",
+ "text/event-stream",
+ )
+ }
+ }
+
+ let server = MockServer::start().await;
+ let captured = Arc::new(Mutex::new(None));
+
+ Mock::given(method("POST"))
+ .and(path("/v1/chat/completions"))
+ .respond_with(CaptureResponder(captured.clone()))
+ .expect(1)
+ .mount(&server)
+ .await;
+
+ // Build provider pointing at mock server; no need to mutate global env vars.
+ let provider = ModelProviderInfo {
+ name: "openai".into(),
+ base_url: format!("{}/v1", server.uri()),
+ env_key: Some("PATH".into()),
+ env_key_instructions: None,
+ wire_api: crate::WireApi::Chat,
+ query_params: None,
+ http_headers: None,
+ env_http_headers: None,
+ };
+
+ let mut prompt = Prompt::default();
+ prompt.input.push(ResponseItem::Message {
+ role: "user".into(),
+ content: vec![ContentItem::InputText { text: "hi".into() }],
+ });
+ prompt.input.push(ResponseItem::FunctionCall {
+ name: "shell".into(),
+ arguments: "[]".into(),
+ call_id: "call123".into(),
+ });
+ prompt.input.push(ResponseItem::FunctionCallOutput {
+ call_id: "call123".into(),
+ output: FunctionCallOutputPayload {
+ content: "ok".into(),
+ success: Some(true),
+ },
+ });
+ prompt.input.push(ResponseItem::LocalShellCall {
+ id: Some("ls1".into()),
+ call_id: Some("call456".into()),
+ status: LocalShellStatus::Completed,
+ action: LocalShellAction::Exec(LocalShellExecAction {
+ command: vec!["echo".into(), "hi".into()],
+ timeout_ms: Some(1),
+ working_directory: None,
+ env: None,
+ user: None,
+ }),
+ });
+
+ let client = reqwest::Client::new();
+ let _ = stream_chat_completions(&prompt, "model", &client, &provider)
+ .await
+ .unwrap();
+
+ let body = captured.lock().unwrap().take().unwrap();
+
+ // Build the expected payload exactly as stream_chat_completions() should.
+ let full_instructions = prompt.get_full_instructions("model");
+ let expected_messages = vec![
+ json!({"role":"system","content":full_instructions}),
+ json!({"role":"user","content":"hi"}),
+ json!({
+ "role":"assistant",
+ "content":null,
+ "tool_calls":[{
+ "id":"call123",
+ "type":"function",
+ "function":{
+ "name":"shell",
+ "arguments":"[]"
+ }
+ }]
+ }),
+ json!({
+ "role":"tool",
+ "tool_call_id":"call123",
+ "content":"ok"
+ }),
+ json!({
+ "role":"assistant",
+ "content":null,
+ "tool_calls":[{
+ "id":"ls1",
+ "type":"local_shell_call",
+ "status":"completed",
+ "action":{
+ "type":"exec",
+ "command":["echo","hi"],
+ "timeout_ms":1,
+ "working_directory":null,
+ "env":null,
+ "user":null
+ }
+ }]
+ }),
+ ];
+ let tools_json = create_tools_json_for_chat_completions_api(&prompt, "model").unwrap();
+ let expected_body = json!({
+ "model":"model",
+ "messages": expected_messages,
+ "stream": true,
+ "tools": tools_json,
+ });
+
+ assert_eq!(body, expected_body, "chat payload encoded incorrectly");
+ }
+}
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 8ec68d02e8..b2ff284fd0 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -317,7 +317,7 @@ where
// duplicated `output` array embedded in the `response.completed`
// payload. That produced two concrete issues:
// 1. No real‑time streaming – the user only saw output after the
- // entire turn had finished, which broke the “typing” UX and
+ // entire turn had finished, which broke the "typing" UX and
// made long‑running turns look stalled.
// 2. Duplicate `function_call_output` items – both the
// individual *and* the completed array were forwarded, which
@@ -390,6 +390,7 @@ where
}
/// used in tests to stream from a text SSE file
+#[allow(dead_code)]
async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let f = std::fs::File::open(path.as_ref())?;
@@ -413,6 +414,8 @@ mod tests {
#![allow(clippy::expect_used, clippy::unwrap_used)]
use super::*;
+ use crate::models::LocalShellAction;
+ use crate::models::LocalShellStatus;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_test::io::Builder as IoBuilder;
@@ -422,6 +425,17 @@ mod tests {
// Helpers
// ────────────────────────────
+ /// Build a tiny SSE string with the provided *raw* event chunks (already formatted as
+ /// `"event: ...\ndata: ..."` lines). Each chunk is separated by a blank line.
+ fn build_sse(chunks: &[&str]) -> String {
+ let mut out = String::new();
+ for c in chunks {
+ out.push_str(c);
+ out.push_str("\n\n");
+ }
+ out
+ }
+
/// Runs the SSE parser on pre-chunked byte slices and returns every event
/// (including any final `Err` from a stream-closure check).
async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent>> {
@@ -469,6 +483,65 @@ mod tests {
out
}
+ // ────────────────────────────
+ // Tests from `implement-unit-tests-for-event-aggregation-and-tool-calls`
+ // ────────────────────────────
+
+ #[tokio::test]
+ async fn parses_function_and_local_shell_items() {
+ let func = "event: response.output_item.done\n\
+data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call_output\",\"call_id\":\"call1\",\"output\":{\"content\":\"ok\",\"success\":true}}}";
+ let shell = "event: response.output_item.done\n\
+data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"local_shell_call\",\"id\":\"ls1\",\"call_id\":\"call2\",\"status\":\"in_progress\",\"action\":{\"type\":\"exec\",\"command\":[\"echo\",\"hi\"],\"timeout_ms\":123,\"working_directory\":null,\"env\":null,\"user\":null}}}";
+ let done = "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}";
+
+ let content = build_sse(&[func, shell, done]);
+
+ let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<ResponseEvent>>(8);
+ let stream = ReaderStream::new(std::io::Cursor::new(content)).map_err(CodexErr::Io);
+ tokio::spawn(super::process_sse(stream, tx));
+
+ // function_call_output
+ match rx.recv().await.unwrap().unwrap() {
+ ResponseEvent::OutputItemDone(ResponseItem::FunctionCallOutput { call_id, output }) => {
+ assert_eq!(call_id, "call1");
+ assert_eq!(output.content, "ok");
+ assert_eq!(output.success, Some(true));
+ }
+ other => panic!("unexpected first event: {other:?}"),
+ }
+
+ // local_shell_call
+ match rx.recv().await.unwrap().unwrap() {
+ ResponseEvent::OutputItemDone(ResponseItem::LocalShellCall {
+ id,
+ call_id,
+ status,
+ action,
+ }) => {
+ assert_eq!(id.as_deref(), Some("ls1"));
+ assert_eq!(call_id.as_deref(), Some("call2"));
+ if !matches!(status, LocalShellStatus::InProgress) {
+ panic!("unexpected status: {status:?}");
+ }
+ match action {
+ LocalShellAction::Exec(act) => {
+ assert_eq!(act.command, vec!["echo".to_string(), "hi".to_string()]);
+ assert_eq!(act.timeout_ms, Some(123));
+ }
+ }
+ }
+ other => panic!("unexpected second event: {other:?}"),
+ }
+
+ // completed
+ assert!(matches!(
+ rx.recv().await.unwrap().unwrap(),
+ ResponseEvent::Completed { response_id, .. } if response_id == "resp"
+ ));
+ }
+
// ────────────────────────────
// Tests from `implement-test-for-responses-api-sse-parser`
// ────────────────────────────
@@ -549,6 +622,7 @@ mod tests {
let events = collect_events(&[sse1.as_bytes()]).await;
+ // We expect the item + a final Err complaining about the missing completed event.
assert_eq!(events.len(), 2);
matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)));
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index 3e3c2e7efa..6b220d4fff 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -49,7 +49,7 @@ impl Prompt {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone, PartialEq)]
pub enum ResponseEvent {
Created,
OutputItemDone(ResponseItem),
diff --git a/codex-rs/core/src/models.rs b/codex-rs/core/src/models.rs
index 6b392fb19d..26babba715 100644
--- a/codex-rs/core/src/models.rs
+++ b/codex-rs/core/src/models.rs
@@ -8,7 +8,7 @@ use serde::ser::Serializer;
use crate::protocol::InputItem;
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseInputItem {
Message {
@@ -25,7 +25,7 @@ pub enum ResponseInputItem {
},
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentItem {
InputText { text: String },
@@ -33,7 +33,7 @@ pub enum ContentItem {
OutputText { text: String },
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseItem {
Message {
@@ -99,7 +99,7 @@ impl From<ResponseInputItem> for ResponseItem {
}
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LocalShellStatus {
Completed,
@@ -107,13 +107,13 @@ pub enum LocalShellStatus {
Incomplete,
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum LocalShellAction {
Exec(LocalShellExecAction),
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LocalShellExecAction {
pub command: Vec<String>,
pub timeout_ms: Option<u64>,
@@ -122,7 +122,7 @@ pub struct LocalShellExecAction {
pub user: Option<String>,
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ReasoningItemReasoningSummary {
SummaryText { text: String },
@@ -177,10 +177,10 @@ pub struct ShellToolCallParams {
pub timeout_ms: Option<u64>,
}
-#[derive(Deserialize, Debug, Clone)]
+#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionCallOutputPayload {
pub content: String,
- #[expect(dead_code)]
+ #[allow(dead_code)]
pub success: Option<bool>,
}
diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs
index b233d4f27b..c14b2e190a 100644
--- a/codex-rs/core/src/protocol.rs
+++ b/codex-rs/core/src/protocol.rs
@@ -332,7 +332,7 @@ pub struct TaskCompleteEvent {
pub last_agent_message: Option<String>,
}
-#[derive(Debug, Clone, Deserialize, Serialize, Default)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)]
pub struct TokenUsage {
pub input_tokens: u64,
pub cached_input_tokens: Option<u64>,
Review Comments
codex-rs/core/src/chat_completions.rs
- Created: 2025-07-12 19:43:31 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202890339
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+ #![allow(clippy::unwrap_used)]
+
+ use super::*;
+ use crate::models::FunctionCallOutputPayload;
+ use crate::models::LocalShellAction;
+ use crate::models::LocalShellExecAction;
+ use crate::models::LocalShellStatus;
+ use futures::StreamExt;
+ use futures::stream;
+
+ /// Helper constructing a minimal assistant text chunk.
+ fn text_chunk(txt: &str) -> ResponseEvent {
+ ResponseEvent::OutputItemDone(ResponseItem::Message {
+ role: "assistant".to_string(),
+ content: vec![ContentItem::OutputText { text: txt.into() }],
+ })
+ }
+
+ #[tokio::test]
+ async fn aggregates_consecutive_message_chunks() {
+ let events = vec![
+ Ok(text_chunk("Hello")),
+ Ok(text_chunk(", world")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r1".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ assert_eq!(collected.len(), 2, "only final message and Completed");
just `assert_eq!()` on all of `collected`?
- Created: 2025-07-12 19:44:08 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202890464
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+ #![allow(clippy::unwrap_used)]
+
+ use super::*;
+ use crate::models::FunctionCallOutputPayload;
+ use crate::models::LocalShellAction;
+ use crate::models::LocalShellExecAction;
+ use crate::models::LocalShellStatus;
+ use futures::StreamExt;
+ use futures::stream;
+
+ /// Helper constructing a minimal assistant text chunk.
+ fn text_chunk(txt: &str) -> ResponseEvent {
+ ResponseEvent::OutputItemDone(ResponseItem::Message {
+ role: "assistant".to_string(),
+ content: vec![ContentItem::OutputText { text: txt.into() }],
+ })
+ }
+
+ #[tokio::test]
+ async fn aggregates_consecutive_message_chunks() {
+ let events = vec![
+ Ok(text_chunk("Hello")),
+ Ok(text_chunk(", world")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r1".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ assert_eq!(collected.len(), 2, "only final message and Completed");
+
+ match &collected[0] {
+ ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+ let text = match &content[0] {
+ ContentItem::OutputText { text } => text,
+ _ => panic!("unexpected content item"),
+ };
+ assert_eq!(text, "Hello, world");
+ }
+ other => panic!("unexpected first event: {other:?}"),
+ }
+
+ assert!(matches!(
+ collected[1],
+ ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1"
+ ));
+ }
+
+ #[tokio::test]
+ async fn forwards_non_text_items_without_merging() {
+ let func_call = ResponseItem::FunctionCall {
+ name: "shell".to_string(),
+ arguments: "{}".to_string(),
+ call_id: "call1".to_string(),
+ };
+
+ let events = vec![
+ Ok(text_chunk("foo")),
+ Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+ Ok(text_chunk("bar")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r2".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ assert_eq!(collected.len(), 3);
same here
- Created: 2025-07-12 19:45:40 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202891601
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+ #![allow(clippy::unwrap_used)]
+
+ use super::*;
+ use crate::models::FunctionCallOutputPayload;
+ use crate::models::LocalShellAction;
+ use crate::models::LocalShellExecAction;
+ use crate::models::LocalShellStatus;
+ use futures::StreamExt;
+ use futures::stream;
+
+ /// Helper constructing a minimal assistant text chunk.
+ fn text_chunk(txt: &str) -> ResponseEvent {
+ ResponseEvent::OutputItemDone(ResponseItem::Message {
+ role: "assistant".to_string(),
+ content: vec![ContentItem::OutputText { text: txt.into() }],
+ })
+ }
+
+ #[tokio::test]
+ async fn aggregates_consecutive_message_chunks() {
+ let events = vec![
+ Ok(text_chunk("Hello")),
+ Ok(text_chunk(", world")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r1".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ assert_eq!(collected.len(), 2, "only final message and Completed");
+
+ match &collected[0] {
+ ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+ let text = match &content[0] {
+ ContentItem::OutputText { text } => text,
+ _ => panic!("unexpected content item"),
+ };
+ assert_eq!(text, "Hello, world");
+ }
+ other => panic!("unexpected first event: {other:?}"),
+ }
+
+ assert!(matches!(
+ collected[1],
+ ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1"
+ ));
+ }
+
+ #[tokio::test]
+ async fn forwards_non_text_items_without_merging() {
+ let func_call = ResponseItem::FunctionCall {
+ name: "shell".to_string(),
+ arguments: "{}".to_string(),
+ call_id: "call1".to_string(),
+ };
+
+ let events = vec![
+ Ok(text_chunk("foo")),
+ Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+ Ok(text_chunk("bar")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r2".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ assert_eq!(collected.len(), 3);
+
+ // First event should be the function call forwarded directly.
+ assert!(matches!(
+ collected[0],
+ ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. })
+ ));
+
+ // Second is the combined assistant message.
+ match &collected[1] {
+ ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+ let text = match &content[0] {
+ ContentItem::OutputText { text } => text,
+ _ => panic!("unexpected content item"),
+ };
+ assert_eq!(text, "foobar");
+ }
+ other => panic!("unexpected second event: {other:?}"),
+ }
+
+ // Final Completed event.
+ assert!(matches!(
+ collected[2],
+ ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r2"
+ ));
+ }
+
+ #[tokio::test]
+ async fn formats_tool_calls_in_chat_payload() {
+ use serde_json::Value;
+ use std::sync::Arc;
+ use std::sync::Mutex;
+ use wiremock::Mock;
+ use wiremock::MockServer;
+ use wiremock::Request;
+ use wiremock::Respond;
+ use wiremock::ResponseTemplate;
+ use wiremock::matchers::method;
+ use wiremock::matchers::path;
+
+ struct CaptureResponder(Arc<Mutex<Option<Value>>>);
+ impl Respond for CaptureResponder {
+ fn respond(&self, req: &Request) -> ResponseTemplate {
+ let v: Value = serde_json::from_slice(&req.body).unwrap();
+ *self.0.lock().unwrap() = Some(v);
+ ResponseTemplate::new(200)
+ .insert_header("content-type", "text/event-stream")
+ .set_body_raw(
+ "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n",
+ "text/event-stream",
+ )
+ }
+ }
+
+ let server = MockServer::start().await;
+ let captured = Arc::new(Mutex::new(None));
+
+ Mock::given(method("POST"))
+ .and(path("/v1/chat/completions"))
+ .respond_with(CaptureResponder(captured.clone()))
+ .expect(1)
+ .mount(&server)
+ .await;
+
+ unsafe {
+ std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
At some point, we should really find another way to thread this through so we can eliminate all these
`unsafe` blocks.
- Created: 2025-07-12 19:46:18 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202892586
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+ #![allow(clippy::unwrap_used)]
+
+ use super::*;
+ use crate::models::FunctionCallOutputPayload;
+ use crate::models::LocalShellAction;
+ use crate::models::LocalShellExecAction;
+ use crate::models::LocalShellStatus;
+ use futures::StreamExt;
+ use futures::stream;
+
+ /// Helper constructing a minimal assistant text chunk.
+ fn text_chunk(txt: &str) -> ResponseEvent {
+ ResponseEvent::OutputItemDone(ResponseItem::Message {
+ role: "assistant".to_string(),
+ content: vec![ContentItem::OutputText { text: txt.into() }],
+ })
+ }
+
+ #[tokio::test]
+ async fn aggregates_consecutive_message_chunks() {
+ let events = vec![
+ Ok(text_chunk("Hello")),
+ Ok(text_chunk(", world")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r1".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ assert_eq!(collected.len(), 2, "only final message and Completed");
+
+ match &collected[0] {
+ ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+ let text = match &content[0] {
+ ContentItem::OutputText { text } => text,
+ _ => panic!("unexpected content item"),
+ };
+ assert_eq!(text, "Hello, world");
+ }
+ other => panic!("unexpected first event: {other:?}"),
+ }
+
+ assert!(matches!(
+ collected[1],
+ ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1"
+ ));
+ }
+
+ #[tokio::test]
+ async fn forwards_non_text_items_without_merging() {
+ let func_call = ResponseItem::FunctionCall {
+ name: "shell".to_string(),
+ arguments: "{}".to_string(),
+ call_id: "call1".to_string(),
+ };
+
+ let events = vec![
+ Ok(text_chunk("foo")),
+ Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+ Ok(text_chunk("bar")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r2".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ assert_eq!(collected.len(), 3);
+
+ // First event should be the function call forwarded directly.
+ assert!(matches!(
+ collected[0],
+ ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. })
+ ));
+
+ // Second is the combined assistant message.
+ match &collected[1] {
+ ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+ let text = match &content[0] {
+ ContentItem::OutputText { text } => text,
+ _ => panic!("unexpected content item"),
+ };
+ assert_eq!(text, "foobar");
+ }
+ other => panic!("unexpected second event: {other:?}"),
+ }
+
+ // Final Completed event.
+ assert!(matches!(
+ collected[2],
+ ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r2"
+ ));
+ }
+
+ #[tokio::test]
+ async fn formats_tool_calls_in_chat_payload() {
+ use serde_json::Value;
+ use std::sync::Arc;
+ use std::sync::Mutex;
+ use wiremock::Mock;
+ use wiremock::MockServer;
+ use wiremock::Request;
+ use wiremock::Respond;
+ use wiremock::ResponseTemplate;
+ use wiremock::matchers::method;
+ use wiremock::matchers::path;
+
+ struct CaptureResponder(Arc<Mutex<Option<Value>>>);
+ impl Respond for CaptureResponder {
+ fn respond(&self, req: &Request) -> ResponseTemplate {
+ let v: Value = serde_json::from_slice(&req.body).unwrap();
+ *self.0.lock().unwrap() = Some(v);
+ ResponseTemplate::new(200)
+ .insert_header("content-type", "text/event-stream")
+ .set_body_raw(
+ "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n",
+ "text/event-stream",
+ )
+ }
+ }
+
+ let server = MockServer::start().await;
+ let captured = Arc::new(Mutex::new(None));
+
+ Mock::given(method("POST"))
+ .and(path("/v1/chat/completions"))
+ .respond_with(CaptureResponder(captured.clone()))
+ .expect(1)
+ .mount(&server)
+ .await;
+
+ unsafe {
+ std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
+ }
+
+ let provider = ModelProviderInfo {
+ name: "openai".into(),
+ base_url: format!("{}/v1", server.uri()),
+ env_key: Some("PATH".into()),
+ env_key_instructions: None,
+ wire_api: crate::WireApi::Chat,
+ query_params: None,
+ http_headers: None,
+ env_http_headers: None,
+ };
+
+ let mut prompt = Prompt::default();
+ prompt.input.push(ResponseItem::Message {
+ role: "user".into(),
+ content: vec![ContentItem::InputText { text: "hi".into() }],
+ });
+ prompt.input.push(ResponseItem::FunctionCall {
+ name: "shell".into(),
+ arguments: "[]".into(),
+ call_id: "call123".into(),
+ });
+ prompt.input.push(ResponseItem::FunctionCallOutput {
+ call_id: "call123".into(),
+ output: FunctionCallOutputPayload {
+ content: "ok".into(),
+ success: Some(true),
+ },
+ });
+ prompt.input.push(ResponseItem::LocalShellCall {
+ id: Some("ls1".into()),
+ call_id: Some("call456".into()),
+ status: LocalShellStatus::Completed,
+ action: LocalShellAction::Exec(LocalShellExecAction {
+ command: vec!["echo".into(), "hi".into()],
+ timeout_ms: Some(1),
+ working_directory: None,
+ env: None,
+ user: None,
+ }),
+ });
+
+ let client = reqwest::Client::new();
+ let _ = stream_chat_completions(&prompt, "model", &client, &provider)
+ .await
+ .unwrap();
+
+ let body = captured.lock().unwrap().take().unwrap();
+ let messages = body.get("messages").unwrap().as_array().unwrap();
`assert_eq!()` for `body`?