mirror of https://github.com/openai/codex.git
synced 2026-04-28 08:34:54 +00:00

# PR #1538: Add tests for chat stream aggregation and tool events

- URL: https://github.com/openai/codex/pull/1538
- Author: aibrahim-oai
- Created: 2025-07-11 19:03:03 UTC
- Updated: 2025-07-21 20:58:18 UTC
- Changes: +323/-12, Files changed: 5, Commits: 12

## Description

## Summary

- unit test `AggregatedChatStream` to ensure it merges assistant message deltas and forwards other items (see the sketch after this list)
- verify parsing of `function_call_output` and `local_shell_call` SSE events
- ensure the chat request payload encodes tool calls correctly
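
The aggregation behavior under test, in brief: `.aggregate()` collapses consecutive assistant text deltas into one cumulative message and passes every other event through untouched. A minimal sketch lifted from the tests in the diff below (`text_chunk` is the helper defined there; this is test-context code, not a standalone program):

```rust
// Two assistant text deltas collapse into a single cumulative message;
// the Completed event is forwarded as-is.
let events = vec![
    Ok(text_chunk("Hello")),
    Ok(text_chunk(", world")),
    Ok(ResponseEvent::Completed { response_id: "r1".into(), token_usage: None }),
];
let merged: Vec<_> = stream::iter(events)
    .aggregate()
    .map(Result::unwrap)
    .collect()
    .await;
// merged[0] is one assistant message whose text is "Hello, world";
// merged[1] is the Completed event.
```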

## Testing

- `cargo test -p codex-core --manifest-path codex-rs/Cargo.toml`
- `cargo test --manifest-path codex-rs/Cargo.toml --all --tests` *(fails: Sandbox(LandlockRestrict))*

------

https://chatgpt.com/codex/tasks/task_i_687158d61e748321ba5f1631199bd8a4

## Full Diff

```diff
diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index ad7b55952a..8eabcaf342 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -458,6 +458,9 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
     ///     // event now contains cumulative text
     /// }
     /// ```
+    ///
+    /// See [`tests::aggregates_consecutive_message_chunks`] for an example.
+    /// ```
     fn aggregate(self) -> AggregatedChatStream<Self> {
         AggregatedChatStream {
             inner: self,
@@ -468,3 +471,237 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
 }
 
 impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+
+    use super::*;
+    use crate::models::FunctionCallOutputPayload;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellExecAction;
+    use crate::models::LocalShellStatus;
+    use crate::openai_tools::create_tools_json_for_chat_completions_api;
+    use futures::StreamExt;
+    use futures::stream;
+    use serde_json::json;
+
+    /// Helper constructing a minimal assistant text chunk.
+    fn text_chunk(txt: &str) -> ResponseEvent {
+        ResponseEvent::OutputItemDone(ResponseItem::Message {
+            role: "assistant".to_string(),
+            content: vec![ContentItem::OutputText { text: txt.into() }],
+        })
+    }
+
+    #[tokio::test]
+    async fn aggregates_consecutive_message_chunks() {
+        let events = vec![
+            Ok(text_chunk("Hello")),
+            Ok(text_chunk(", world")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r1".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        let expected = vec![
+            ResponseEvent::OutputItemDone(ResponseItem::Message {
+                role: "assistant".into(),
+                content: vec![ContentItem::OutputText {
+                    text: "Hello, world".into(),
+                }],
+            }),
+            ResponseEvent::Completed {
+                response_id: "r1".into(),
+                token_usage: None,
+            },
+        ];
+
+        assert_eq!(
+            collected, expected,
+            "aggregated assistant message + Completed"
+        );
+    }
+
+    #[tokio::test]
+    async fn forwards_non_text_items_without_merging() {
+        let func_call = ResponseItem::FunctionCall {
+            name: "shell".to_string(),
+            arguments: "{}".to_string(),
+            call_id: "call1".to_string(),
+        };
+
+        let events = vec![
+            Ok(text_chunk("foo")),
+            Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+            Ok(text_chunk("bar")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r2".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        let expected = vec![
+            ResponseEvent::OutputItemDone(func_call.clone()),
+            ResponseEvent::OutputItemDone(ResponseItem::Message {
+                role: "assistant".into(),
+                content: vec![ContentItem::OutputText {
+                    text: "foobar".into(),
+                }],
+            }),
+            ResponseEvent::Completed {
+                response_id: "r2".into(),
+                token_usage: None,
+            },
+        ];
+
+        assert_eq!(
+            collected, expected,
+            "non-text items forwarded intact; text merged"
+        );
+    }
+
+    #[tokio::test]
+    async fn formats_tool_calls_in_chat_payload() {
+        use std::sync::Arc;
+        use std::sync::Mutex;
+        use wiremock::Mock;
+        use wiremock::MockServer;
+        use wiremock::Request;
+        use wiremock::Respond;
+        use wiremock::ResponseTemplate;
+        use wiremock::matchers::method;
+        use wiremock::matchers::path;
+
+        struct CaptureResponder(Arc<Mutex<Option<serde_json::Value>>>);
+        impl Respond for CaptureResponder {
+            fn respond(&self, req: &Request) -> ResponseTemplate {
+                let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap();
+                *self.0.lock().unwrap() = Some(v);
+                ResponseTemplate::new(200)
+                    .insert_header("content-type", "text/event-stream")
+                    .set_body_raw(
+                        "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n",
+                        "text/event-stream",
+                    )
+            }
+        }
+
+        let server = MockServer::start().await;
+        let captured = Arc::new(Mutex::new(None));
+
+        Mock::given(method("POST"))
+            .and(path("/v1/chat/completions"))
+            .respond_with(CaptureResponder(captured.clone()))
+            .expect(1)
+            .mount(&server)
+            .await;
+
+        // Build provider pointing at mock server; no need to mutate global env vars.
+        let provider = ModelProviderInfo {
+            name: "openai".into(),
+            base_url: format!("{}/v1", server.uri()),
+            env_key: Some("PATH".into()),
+            env_key_instructions: None,
+            wire_api: crate::WireApi::Chat,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+        };
+
+        let mut prompt = Prompt::default();
+        prompt.input.push(ResponseItem::Message {
+            role: "user".into(),
+            content: vec![ContentItem::InputText { text: "hi".into() }],
+        });
+        prompt.input.push(ResponseItem::FunctionCall {
+            name: "shell".into(),
+            arguments: "[]".into(),
+            call_id: "call123".into(),
+        });
+        prompt.input.push(ResponseItem::FunctionCallOutput {
+            call_id: "call123".into(),
+            output: FunctionCallOutputPayload {
+                content: "ok".into(),
+                success: Some(true),
+            },
+        });
+        prompt.input.push(ResponseItem::LocalShellCall {
+            id: Some("ls1".into()),
+            call_id: Some("call456".into()),
+            status: LocalShellStatus::Completed,
+            action: LocalShellAction::Exec(LocalShellExecAction {
+                command: vec!["echo".into(), "hi".into()],
+                timeout_ms: Some(1),
+                working_directory: None,
+                env: None,
+                user: None,
+            }),
+        });
+
+        let client = reqwest::Client::new();
+        let _ = stream_chat_completions(&prompt, "model", &client, &provider)
+            .await
+            .unwrap();
+
+        let body = captured.lock().unwrap().take().unwrap();
+
+        // Build the expected payload exactly as stream_chat_completions() should.
+        let full_instructions = prompt.get_full_instructions("model");
+        let expected_messages = vec![
+            json!({"role":"system","content":full_instructions}),
+            json!({"role":"user","content":"hi"}),
+            json!({
+                "role":"assistant",
+                "content":null,
+                "tool_calls":[{
+                    "id":"call123",
+                    "type":"function",
+                    "function":{
+                        "name":"shell",
+                        "arguments":"[]"
+                    }
+                }]
+            }),
+            json!({
+                "role":"tool",
+                "tool_call_id":"call123",
+                "content":"ok"
+            }),
+            json!({
+                "role":"assistant",
+                "content":null,
+                "tool_calls":[{
+                    "id":"ls1",
+                    "type":"local_shell_call",
+                    "status":"completed",
+                    "action":{
+                        "type":"exec",
+                        "command":["echo","hi"],
+                        "timeout_ms":1,
+                        "working_directory":null,
+                        "env":null,
+                        "user":null
+                    }
+                }]
+            }),
+        ];
+        let tools_json = create_tools_json_for_chat_completions_api(&prompt, "model").unwrap();
+        let expected_body = json!({
+            "model":"model",
+            "messages": expected_messages,
+            "stream": true,
+            "tools": tools_json,
+        });
+
+        assert_eq!(body, expected_body, "chat payload encoded incorrectly");
+    }
+}
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 8ec68d02e8..b2ff284fd0 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -317,7 +317,7 @@ where
         // duplicated `output` array embedded in the `response.completed`
         // payload. That produced two concrete issues:
         //   1. No real‑time streaming – the user only saw output after the
-        //      entire turn had finished, which broke the “typing” UX and
+        //      entire turn had finished, which broke the "typing" UX and
         //      made long‑running turns look stalled.
         //   2. Duplicate `function_call_output` items – both the
         //      individual *and* the completed array were forwarded, which
@@ -390,6 +390,7 @@ where
 }
 
 /// used in tests to stream from a text SSE file
+#[allow(dead_code)]
 async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
     let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
     let f = std::fs::File::open(path.as_ref())?;
@@ -413,6 +414,8 @@ mod tests {
     #![allow(clippy::expect_used, clippy::unwrap_used)]
 
     use super::*;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellStatus;
     use serde_json::json;
     use tokio::sync::mpsc;
     use tokio_test::io::Builder as IoBuilder;
@@ -422,6 +425,17 @@ mod tests {
     // Helpers
     // ────────────────────────────
 
+    /// Build a tiny SSE string with the provided *raw* event chunks (already formatted as
+    /// `"event: ...\ndata: ..."` lines). Each chunk is separated by a blank line.
+    fn build_sse(chunks: &[&str]) -> String {
+        let mut out = String::new();
+        for c in chunks {
+            out.push_str(c);
+            out.push_str("\n\n");
+        }
+        out
+    }
+
     /// Runs the SSE parser on pre-chunked byte slices and returns every event
     /// (including any final `Err` from a stream-closure check).
     async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent>> {
@@ -469,6 +483,65 @@ mod tests {
         out
     }
 
+    // ────────────────────────────
+    // Tests from `implement-unit-tests-for-event-aggregation-and-tool-calls`
+    // ────────────────────────────
+
+    #[tokio::test]
+    async fn parses_function_and_local_shell_items() {
+        let func = "event: response.output_item.done\n\
+data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call_output\",\"call_id\":\"call1\",\"output\":{\"content\":\"ok\",\"success\":true}}}";
+        let shell = "event: response.output_item.done\n\
+data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"local_shell_call\",\"id\":\"ls1\",\"call_id\":\"call2\",\"status\":\"in_progress\",\"action\":{\"type\":\"exec\",\"command\":[\"echo\",\"hi\"],\"timeout_ms\":123,\"working_directory\":null,\"env\":null,\"user\":null}}}";
+        let done = "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}";
+
+        let content = build_sse(&[func, shell, done]);
+
+        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<ResponseEvent>>(8);
+        let stream = ReaderStream::new(std::io::Cursor::new(content)).map_err(CodexErr::Io);
+        tokio::spawn(super::process_sse(stream, tx));
+
+        // function_call_output
+        match rx.recv().await.unwrap().unwrap() {
+            ResponseEvent::OutputItemDone(ResponseItem::FunctionCallOutput { call_id, output }) => {
+                assert_eq!(call_id, "call1");
+                assert_eq!(output.content, "ok");
+                assert_eq!(output.success, Some(true));
+            }
+            other => panic!("unexpected first event: {other:?}"),
+        }
+
+        // local_shell_call
+        match rx.recv().await.unwrap().unwrap() {
+            ResponseEvent::OutputItemDone(ResponseItem::LocalShellCall {
+                id,
+                call_id,
+                status,
+                action,
+            }) => {
+                assert_eq!(id.as_deref(), Some("ls1"));
+                assert_eq!(call_id.as_deref(), Some("call2"));
+                if !matches!(status, LocalShellStatus::InProgress) {
+                    panic!("unexpected status: {status:?}");
+                }
+                match action {
+                    LocalShellAction::Exec(act) => {
+                        assert_eq!(act.command, vec!["echo".to_string(), "hi".to_string()]);
+                        assert_eq!(act.timeout_ms, Some(123));
+                    }
+                }
+            }
+            other => panic!("unexpected second event: {other:?}"),
+        }
+
+        // completed
+        assert!(matches!(
+            rx.recv().await.unwrap().unwrap(),
+            ResponseEvent::Completed { response_id, .. } if response_id == "resp"
+        ));
+    }
+
     // ────────────────────────────
     // Tests from `implement-test-for-responses-api-sse-parser`
     // ────────────────────────────
@@ -549,6 +622,7 @@ mod tests {
 
         let events = collect_events(&[sse1.as_bytes()]).await;
 
+        // We expect the item + a final Err complaining about the missing completed event.
         assert_eq!(events.len(), 2);
 
         matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)));
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index 3e3c2e7efa..6b220d4fff 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -49,7 +49,7 @@ impl Prompt {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone, PartialEq)]
 pub enum ResponseEvent {
     Created,
     OutputItemDone(ResponseItem),
diff --git a/codex-rs/core/src/models.rs b/codex-rs/core/src/models.rs
index 6b392fb19d..26babba715 100644
--- a/codex-rs/core/src/models.rs
+++ b/codex-rs/core/src/models.rs
@@ -8,7 +8,7 @@ use serde::ser::Serializer;
 
 use crate::protocol::InputItem;
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum ResponseInputItem {
     Message {
@@ -25,7 +25,7 @@ pub enum ResponseInputItem {
     },
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum ContentItem {
     InputText { text: String },
@@ -33,7 +33,7 @@ pub enum ContentItem {
     OutputText { text: String },
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum ResponseItem {
     Message {
@@ -99,7 +99,7 @@ impl From<ResponseInputItem> for ResponseItem {
     }
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(rename_all = "snake_case")]
 pub enum LocalShellStatus {
     Completed,
@@ -107,13 +107,13 @@ pub enum LocalShellStatus {
     Incomplete,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum LocalShellAction {
     Exec(LocalShellExecAction),
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct LocalShellExecAction {
     pub command: Vec<String>,
     pub timeout_ms: Option<u64>,
@@ -122,7 +122,7 @@ pub struct LocalShellExecAction {
     pub user: Option<String>,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum ReasoningItemReasoningSummary {
     SummaryText { text: String },
@@ -177,10 +177,10 @@ pub struct ShellToolCallParams {
     pub timeout_ms: Option<u64>,
 }
 
-#[derive(Deserialize, Debug, Clone)]
+#[derive(Deserialize, Debug, Clone, PartialEq)]
 pub struct FunctionCallOutputPayload {
     pub content: String,
-    #[expect(dead_code)]
+    #[allow(dead_code)]
     pub success: Option<bool>,
 }
 
diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs
index b233d4f27b..c14b2e190a 100644
--- a/codex-rs/core/src/protocol.rs
+++ b/codex-rs/core/src/protocol.rs
@@ -332,7 +332,7 @@ pub struct TaskCompleteEvent {
     pub last_agent_message: Option<String>,
 }
 
-#[derive(Debug, Clone, Deserialize, Serialize, Default)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)]
 pub struct TokenUsage {
     pub input_tokens: u64,
     pub cached_input_tokens: Option<u64>,
```
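
Worth noting: the `PartialEq` derives added across `models.rs`, `client_common.rs`, and `protocol.rs` are what let the new tests compare whole event vectors with a single `assert_eq!`. A self-contained illustration of the pattern (the types here are stand-ins, not the crate's):

```rust
// Stand-in event type: without PartialEq, assert_eq! over whole
// values (and vectors of them) would not compile.
#[derive(Debug, PartialEq)]
enum Event {
    Text(String),
    Completed { id: String },
}

fn main() {
    let collected = vec![
        Event::Text("Hello, world".into()),
        Event::Completed { id: "r1".into() },
    ];
    let expected = vec![
        Event::Text("Hello, world".into()),
        Event::Completed { id: "r1".into() },
    ];
    // One structural comparison replaces a length check plus per-element matches.
    assert_eq!(collected, expected);
}
```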

## Review Comments

### codex-rs/core/src/chat_completions.rs

- Created: 2025-07-12 19:43:31 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202890339

```diff
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
 }
 
 impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+
+    use super::*;
+    use crate::models::FunctionCallOutputPayload;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellExecAction;
+    use crate::models::LocalShellStatus;
+    use futures::StreamExt;
+    use futures::stream;
+
+    /// Helper constructing a minimal assistant text chunk.
+    fn text_chunk(txt: &str) -> ResponseEvent {
+        ResponseEvent::OutputItemDone(ResponseItem::Message {
+            role: "assistant".to_string(),
+            content: vec![ContentItem::OutputText { text: txt.into() }],
+        })
+    }
+
+    #[tokio::test]
+    async fn aggregates_consecutive_message_chunks() {
+        let events = vec![
+            Ok(text_chunk("Hello")),
+            Ok(text_chunk(", world")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r1".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 2, "only final message and Completed");
```

> just `assert_eq!()` on all of `collected`?
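
The final revision of the diff above adopts exactly this: once `PartialEq` is derived on the event types, the length check plus per-element matching collapses into one structural assertion. A sketch of the suggested shape:

```rust
// One assert_eq! over the whole collected vector (requires PartialEq,
// which this PR derives on ResponseEvent and the model types).
let expected = vec![
    ResponseEvent::OutputItemDone(ResponseItem::Message {
        role: "assistant".into(),
        content: vec![ContentItem::OutputText { text: "Hello, world".into() }],
    }),
    ResponseEvent::Completed { response_id: "r1".into(), token_usage: None },
];
assert_eq!(collected, expected);
```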

- Created: 2025-07-12 19:44:08 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202890464

```diff
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
 }
 
 impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+
+    use super::*;
+    use crate::models::FunctionCallOutputPayload;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellExecAction;
+    use crate::models::LocalShellStatus;
+    use futures::StreamExt;
+    use futures::stream;
+
+    /// Helper constructing a minimal assistant text chunk.
+    fn text_chunk(txt: &str) -> ResponseEvent {
+        ResponseEvent::OutputItemDone(ResponseItem::Message {
+            role: "assistant".to_string(),
+            content: vec![ContentItem::OutputText { text: txt.into() }],
+        })
+    }
+
+    #[tokio::test]
+    async fn aggregates_consecutive_message_chunks() {
+        let events = vec![
+            Ok(text_chunk("Hello")),
+            Ok(text_chunk(", world")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r1".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 2, "only final message and Completed");
+
+        match &collected[0] {
+            ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+                let text = match &content[0] {
+                    ContentItem::OutputText { text } => text,
+                    _ => panic!("unexpected content item"),
+                };
+                assert_eq!(text, "Hello, world");
+            }
+            other => panic!("unexpected first event: {other:?}"),
+        }
+
+        assert!(matches!(
+            collected[1],
+            ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1"
+        ));
+    }
+
+    #[tokio::test]
+    async fn forwards_non_text_items_without_merging() {
+        let func_call = ResponseItem::FunctionCall {
+            name: "shell".to_string(),
+            arguments: "{}".to_string(),
+            call_id: "call1".to_string(),
+        };
+
+        let events = vec![
+            Ok(text_chunk("foo")),
+            Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+            Ok(text_chunk("bar")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r2".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 3);
```

> same here

- Created: 2025-07-12 19:45:40 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202891601

```diff
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
 }
 
 impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+
+    use super::*;
+    use crate::models::FunctionCallOutputPayload;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellExecAction;
+    use crate::models::LocalShellStatus;
+    use futures::StreamExt;
+    use futures::stream;
+
+    /// Helper constructing a minimal assistant text chunk.
+    fn text_chunk(txt: &str) -> ResponseEvent {
+        ResponseEvent::OutputItemDone(ResponseItem::Message {
+            role: "assistant".to_string(),
+            content: vec![ContentItem::OutputText { text: txt.into() }],
+        })
+    }
+
+    #[tokio::test]
+    async fn aggregates_consecutive_message_chunks() {
+        let events = vec![
+            Ok(text_chunk("Hello")),
+            Ok(text_chunk(", world")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r1".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 2, "only final message and Completed");
+
+        match &collected[0] {
+            ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+                let text = match &content[0] {
+                    ContentItem::OutputText { text } => text,
+                    _ => panic!("unexpected content item"),
+                };
+                assert_eq!(text, "Hello, world");
+            }
+            other => panic!("unexpected first event: {other:?}"),
+        }
+
+        assert!(matches!(
+            collected[1],
+            ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1"
+        ));
+    }
+
+    #[tokio::test]
+    async fn forwards_non_text_items_without_merging() {
+        let func_call = ResponseItem::FunctionCall {
+            name: "shell".to_string(),
+            arguments: "{}".to_string(),
+            call_id: "call1".to_string(),
+        };
+
+        let events = vec![
+            Ok(text_chunk("foo")),
+            Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+            Ok(text_chunk("bar")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r2".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 3);
+
+        // First event should be the function call forwarded directly.
+        assert!(matches!(
+            collected[0],
+            ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. })
+        ));
+
+        // Second is the combined assistant message.
+        match &collected[1] {
+            ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+                let text = match &content[0] {
+                    ContentItem::OutputText { text } => text,
+                    _ => panic!("unexpected content item"),
+                };
+                assert_eq!(text, "foobar");
+            }
+            other => panic!("unexpected second event: {other:?}"),
+        }
+
+        // Final Completed event.
+        assert!(matches!(
+            collected[2],
+            ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r2"
+        ));
+    }
+
+    #[tokio::test]
+    async fn formats_tool_calls_in_chat_payload() {
+        use serde_json::Value;
+        use std::sync::Arc;
+        use std::sync::Mutex;
+        use wiremock::Mock;
+        use wiremock::MockServer;
+        use wiremock::Request;
+        use wiremock::Respond;
+        use wiremock::ResponseTemplate;
+        use wiremock::matchers::method;
+        use wiremock::matchers::path;
+
+        struct CaptureResponder(Arc<Mutex<Option<Value>>>);
+        impl Respond for CaptureResponder {
+            fn respond(&self, req: &Request) -> ResponseTemplate {
+                let v: Value = serde_json::from_slice(&req.body).unwrap();
+                *self.0.lock().unwrap() = Some(v);
+                ResponseTemplate::new(200)
+                    .insert_header("content-type", "text/event-stream")
+                    .set_body_raw(
+                        "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n",
+                        "text/event-stream",
+                    )
+            }
+        }
+
+        let server = MockServer::start().await;
+        let captured = Arc::new(Mutex::new(None));
+
+        Mock::given(method("POST"))
+            .and(path("/v1/chat/completions"))
+            .respond_with(CaptureResponder(captured.clone()))
+            .expect(1)
+            .mount(&server)
+            .await;
+
+        unsafe {
+            std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
```

> At some point, we should really find another way to thread this through so we can eliminate all these `unsafe` blocks.
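
(The final diff sidesteps this particular instance by pointing the provider at the mock server instead of mutating env vars.) One general way to eliminate such blocks, sketched under assumptions: thread the retry limit through an explicit config value rather than a process-global env var. `RetryConfig` and `effective_retries` are hypothetical names for illustration, not the crate's actual API:

```rust
// Hypothetical: tests would construct a RetryConfig directly instead of
// calling unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0") }.
#[derive(Clone, Copy)]
pub struct RetryConfig {
    pub max_retries: u64,
}

pub fn effective_retries(explicit: Option<RetryConfig>) -> u64 {
    // Fall back to the env var only when no explicit config is supplied,
    // so production behavior is unchanged while tests stay safe.
    explicit.map(|c| c.max_retries).unwrap_or_else(|| {
        std::env::var("OPENAI_REQUEST_MAX_RETRIES")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(3) // illustrative default, not the crate's real one
    })
}
```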

- Created: 2025-07-12 19:46:18 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202892586

```diff
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
 }
 
 impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+
+    use super::*;
+    use crate::models::FunctionCallOutputPayload;
+    use crate::models::LocalShellAction;
+    use crate::models::LocalShellExecAction;
+    use crate::models::LocalShellStatus;
+    use futures::StreamExt;
+    use futures::stream;
+
+    /// Helper constructing a minimal assistant text chunk.
+    fn text_chunk(txt: &str) -> ResponseEvent {
+        ResponseEvent::OutputItemDone(ResponseItem::Message {
+            role: "assistant".to_string(),
+            content: vec![ContentItem::OutputText { text: txt.into() }],
+        })
+    }
+
+    #[tokio::test]
+    async fn aggregates_consecutive_message_chunks() {
+        let events = vec![
+            Ok(text_chunk("Hello")),
+            Ok(text_chunk(", world")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r1".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 2, "only final message and Completed");
+
+        match &collected[0] {
+            ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+                let text = match &content[0] {
+                    ContentItem::OutputText { text } => text,
+                    _ => panic!("unexpected content item"),
+                };
+                assert_eq!(text, "Hello, world");
+            }
+            other => panic!("unexpected first event: {other:?}"),
+        }
+
+        assert!(matches!(
+            collected[1],
+            ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r1"
+        ));
+    }
+
+    #[tokio::test]
+    async fn forwards_non_text_items_without_merging() {
+        let func_call = ResponseItem::FunctionCall {
+            name: "shell".to_string(),
+            arguments: "{}".to_string(),
+            call_id: "call1".to_string(),
+        };
+
+        let events = vec![
+            Ok(text_chunk("foo")),
+            Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+            Ok(text_chunk("bar")),
+            Ok(ResponseEvent::Completed {
+                response_id: "r2".to_string(),
+                token_usage: None,
+            }),
+        ];
+
+        let stream = stream::iter(events).aggregate();
+        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+        assert_eq!(collected.len(), 3);
+
+        // First event should be the function call forwarded directly.
+        assert!(matches!(
+            collected[0],
+            ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. })
+        ));
+
+        // Second is the combined assistant message.
+        match &collected[1] {
+            ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => {
+                let text = match &content[0] {
+                    ContentItem::OutputText { text } => text,
+                    _ => panic!("unexpected content item"),
+                };
+                assert_eq!(text, "foobar");
+            }
+            other => panic!("unexpected second event: {other:?}"),
+        }
+
+        // Final Completed event.
+        assert!(matches!(
+            collected[2],
+            ResponseEvent::Completed { response_id: ref id, token_usage: None } if id == "r2"
+        ));
+    }
+
+    #[tokio::test]
+    async fn formats_tool_calls_in_chat_payload() {
+        use serde_json::Value;
+        use std::sync::Arc;
+        use std::sync::Mutex;
+        use wiremock::Mock;
+        use wiremock::MockServer;
+        use wiremock::Request;
+        use wiremock::Respond;
+        use wiremock::ResponseTemplate;
+        use wiremock::matchers::method;
+        use wiremock::matchers::path;
+
+        struct CaptureResponder(Arc<Mutex<Option<Value>>>);
+        impl Respond for CaptureResponder {
+            fn respond(&self, req: &Request) -> ResponseTemplate {
+                let v: Value = serde_json::from_slice(&req.body).unwrap();
+                *self.0.lock().unwrap() = Some(v);
+                ResponseTemplate::new(200)
+                    .insert_header("content-type", "text/event-stream")
+                    .set_body_raw(
+                        "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n",
+                        "text/event-stream",
+                    )
+            }
+        }
+
+        let server = MockServer::start().await;
+        let captured = Arc::new(Mutex::new(None));
+
+        Mock::given(method("POST"))
+            .and(path("/v1/chat/completions"))
+            .respond_with(CaptureResponder(captured.clone()))
+            .expect(1)
+            .mount(&server)
+            .await;
+
+        unsafe {
+            std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
+        }
+
+        let provider = ModelProviderInfo {
+            name: "openai".into(),
+            base_url: format!("{}/v1", server.uri()),
+            env_key: Some("PATH".into()),
+            env_key_instructions: None,
+            wire_api: crate::WireApi::Chat,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+        };
+
+        let mut prompt = Prompt::default();
+        prompt.input.push(ResponseItem::Message {
+            role: "user".into(),
+            content: vec![ContentItem::InputText { text: "hi".into() }],
+        });
+        prompt.input.push(ResponseItem::FunctionCall {
+            name: "shell".into(),
+            arguments: "[]".into(),
+            call_id: "call123".into(),
+        });
+        prompt.input.push(ResponseItem::FunctionCallOutput {
+            call_id: "call123".into(),
+            output: FunctionCallOutputPayload {
+                content: "ok".into(),
+                success: Some(true),
+            },
+        });
+        prompt.input.push(ResponseItem::LocalShellCall {
+            id: Some("ls1".into()),
+            call_id: Some("call456".into()),
+            status: LocalShellStatus::Completed,
+            action: LocalShellAction::Exec(LocalShellExecAction {
+                command: vec!["echo".into(), "hi".into()],
+                timeout_ms: Some(1),
+                working_directory: None,
+                env: None,
+                user: None,
+            }),
+        });
+
+        let client = reqwest::Client::new();
+        let _ = stream_chat_completions(&prompt, "model", &client, &provider)
+            .await
+            .unwrap();
+
+        let body = captured.lock().unwrap().take().unwrap();
+        let messages = body.get("messages").unwrap().as_array().unwrap();
```

> `assert_eq!()` for `body`
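
This is what the final diff above does: build the complete expected JSON body and compare it in one assertion, rather than probing individual fields out of `messages`. The adopted shape, lifted from the final revision:

```rust
// From the final revision: one assert_eq! over the entire request body.
let expected_body = json!({
    "model": "model",
    "messages": expected_messages,
    "stream": true,
    "tools": tools_json,
});
assert_eq!(body, expected_body, "chat payload encoded incorrectly");
```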