Compare commits

...

12 Commits

Author SHA1 Message Date
aibrahim-oai
99972b7586 Merge branch 'main' into codex/implement-unit-tests-for-event-aggregation-and-tool-calls 2025-07-16 22:32:28 -07:00
Ahmed Ibrahim
60cf8eba0c adressing reviews 2025-07-14 15:13:01 -07:00
Ahmed Ibrahim
dd879fedb1 adressing reviews 2025-07-14 15:05:49 -07:00
aibrahim-oai
4aff7927fc Merge branch 'main' into codex/implement-unit-tests-for-event-aggregation-and-tool-calls 2025-07-14 14:53:30 -07:00
Ahmed Ibrahim
37bacad560 remove expect dead code 2025-07-11 15:44:09 -07:00
Ahmed Ibrahim
b9b3d2505c remove expect dead code 2025-07-11 15:33:27 -07:00
aibrahim-oai
f13b6fca9c Update models.rs 2025-07-11 15:24:53 -07:00
aibrahim-oai
8ec612275a Annotate test helper with dead_code allow 2025-07-11 15:24:14 -07:00
aibrahim-oai
d55cb449a1 doc: reference aggregation test 2025-07-11 14:38:51 -07:00
aibrahim-oai
7340a07ab1 Format imports 2025-07-11 13:56:11 -07:00
aibrahim-oai
cda7c164d5 Fix clippy by allowing unused success field 2025-07-11 13:22:27 -07:00
aibrahim-oai
fc9f1d171f test: cover chat stream aggregation and tool calls 2025-07-11 12:03:01 -07:00
5 changed files with 323 additions and 12 deletions

View File

@@ -458,6 +458,9 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
/// // event now contains cumulative text
/// }
/// ```
///
/// See [`tests::aggregates_consecutive_message_chunks`] for an example.
fn aggregate(self) -> AggregatedChatStream<Self> {
AggregatedChatStream {
inner: self,
@@ -468,3 +471,237 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;
    use crate::models::FunctionCallOutputPayload;
    use crate::models::LocalShellAction;
    use crate::models::LocalShellExecAction;
    use crate::models::LocalShellStatus;
    use crate::openai_tools::create_tools_json_for_chat_completions_api;
    use futures::StreamExt;
    use futures::stream;
    use serde_json::json;

    /// Helper constructing a minimal assistant text chunk.
    fn text_chunk(txt: &str) -> ResponseEvent {
        ResponseEvent::OutputItemDone(ResponseItem::Message {
            role: "assistant".to_string(),
            content: vec![ContentItem::OutputText { text: txt.into() }],
        })
    }

    /// Consecutive assistant text chunks must be merged into a single
    /// `OutputItemDone` message; the trailing `Completed` marker is passed
    /// through unchanged.
    #[tokio::test]
    async fn aggregates_consecutive_message_chunks() {
        let events = vec![
            Ok(text_chunk("Hello")),
            Ok(text_chunk(", world")),
            Ok(ResponseEvent::Completed {
                response_id: "r1".to_string(),
                token_usage: None,
            }),
        ];

        let stream = stream::iter(events).aggregate();
        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;

        // "Hello" + ", world" are concatenated into one cumulative message.
        let expected = vec![
            ResponseEvent::OutputItemDone(ResponseItem::Message {
                role: "assistant".into(),
                content: vec![ContentItem::OutputText {
                    text: "Hello, world".into(),
                }],
            }),
            ResponseEvent::Completed {
                response_id: "r1".into(),
                token_usage: None,
            },
        ];

        assert_eq!(
            collected, expected,
            "aggregated assistant message + Completed"
        );
    }

    /// Non-text items (here a `FunctionCall`) must be forwarded intact, while
    /// the text chunks around them are still merged into one message.
    #[tokio::test]
    async fn forwards_non_text_items_without_merging() {
        let func_call = ResponseItem::FunctionCall {
            name: "shell".to_string(),
            arguments: "{}".to_string(),
            call_id: "call1".to_string(),
        };

        let events = vec![
            Ok(text_chunk("foo")),
            Ok(ResponseEvent::OutputItemDone(func_call.clone())),
            Ok(text_chunk("bar")),
            Ok(ResponseEvent::Completed {
                response_id: "r2".to_string(),
                token_usage: None,
            }),
        ];

        let stream = stream::iter(events).aggregate();
        let collected: Vec<_> = stream.map(Result::unwrap).collect().await;

        // The expected order shows that text is buffered across the
        // interleaved non-text item and flushed later: the function call is
        // observed first, then the merged "foobar" message.
        let expected = vec![
            ResponseEvent::OutputItemDone(func_call.clone()),
            ResponseEvent::OutputItemDone(ResponseItem::Message {
                role: "assistant".into(),
                content: vec![ContentItem::OutputText {
                    text: "foobar".into(),
                }],
            }),
            ResponseEvent::Completed {
                response_id: "r2".into(),
                token_usage: None,
            },
        ];

        assert_eq!(
            collected, expected,
            "non-text items forwarded intact; text merged"
        );
    }

    /// End-to-end check that `stream_chat_completions()` encodes the prompt —
    /// including function-call, function-call-output and local-shell items —
    /// into the exact JSON payload sent to the Chat Completions endpoint.
    #[tokio::test]
    async fn formats_tool_calls_in_chat_payload() {
        use std::sync::Arc;
        use std::sync::Mutex;
        use wiremock::Mock;
        use wiremock::MockServer;
        use wiremock::Request;
        use wiremock::Respond;
        use wiremock::ResponseTemplate;
        use wiremock::matchers::method;
        use wiremock::matchers::path;

        /// Responder that captures the JSON request body for later inspection
        /// and replies with a minimal, well-formed SSE "completed" event.
        struct CaptureResponder(Arc<Mutex<Option<serde_json::Value>>>);

        impl Respond for CaptureResponder {
            fn respond(&self, req: &Request) -> ResponseTemplate {
                // Stash the parsed request body so the test can assert on it.
                let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap();
                *self.0.lock().unwrap() = Some(v);
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_raw(
                        "event: response.completed\n\
                         data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n",
                        "text/event-stream",
                    )
            }
        }

        let server = MockServer::start().await;
        let captured = Arc::new(Mutex::new(None));
        // Expect exactly one POST to the chat completions route.
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(CaptureResponder(captured.clone()))
            .expect(1)
            .mount(&server)
            .await;

        // Build provider pointing at mock server; no need to mutate global env vars.
        let provider = ModelProviderInfo {
            name: "openai".into(),
            base_url: format!("{}/v1", server.uri()),
            // NOTE(review): any env var that exists would do here — PATH is
            // presumably used because it is always set; its value is unused.
            env_key: Some("PATH".into()),
            env_key_instructions: None,
            wire_api: crate::WireApi::Chat,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
        };

        // Prompt covering every item kind that needs special encoding:
        // user message, function call, function-call output, local shell call.
        let mut prompt = Prompt::default();
        prompt.input.push(ResponseItem::Message {
            role: "user".into(),
            content: vec![ContentItem::InputText { text: "hi".into() }],
        });
        prompt.input.push(ResponseItem::FunctionCall {
            name: "shell".into(),
            arguments: "[]".into(),
            call_id: "call123".into(),
        });
        prompt.input.push(ResponseItem::FunctionCallOutput {
            call_id: "call123".into(),
            output: FunctionCallOutputPayload {
                content: "ok".into(),
                success: Some(true),
            },
        });
        prompt.input.push(ResponseItem::LocalShellCall {
            id: Some("ls1".into()),
            call_id: Some("call456".into()),
            status: LocalShellStatus::Completed,
            action: LocalShellAction::Exec(LocalShellExecAction {
                command: vec!["echo".into(), "hi".into()],
                timeout_ms: Some(1),
                working_directory: None,
                env: None,
                user: None,
            }),
        });

        let client = reqwest::Client::new();
        // The returned stream is irrelevant; only the captured request body
        // matters for this test.
        let _ = stream_chat_completions(&prompt, "model", &client, &provider)
            .await
            .unwrap();

        let body = captured.lock().unwrap().take().unwrap();

        // Build the expected payload exactly as stream_chat_completions() should.
        let full_instructions = prompt.get_full_instructions("model");
        let expected_messages = vec![
            json!({"role":"system","content":full_instructions}),
            json!({"role":"user","content":"hi"}),
            // A FunctionCall becomes an assistant message carrying `tool_calls`.
            json!({
                "role":"assistant",
                "content":null,
                "tool_calls":[{
                    "id":"call123",
                    "type":"function",
                    "function":{
                        "name":"shell",
                        "arguments":"[]"
                    }
                }]
            }),
            // Its output comes back as a `tool` role message keyed by call id.
            json!({
                "role":"tool",
                "tool_call_id":"call123",
                "content":"ok"
            }),
            // A LocalShellCall is encoded as a `local_shell_call` tool call;
            // note the id used is the item `id` ("ls1"), not the `call_id`.
            json!({
                "role":"assistant",
                "content":null,
                "tool_calls":[{
                    "id":"ls1",
                    "type":"local_shell_call",
                    "status":"completed",
                    "action":{
                        "type":"exec",
                        "command":["echo","hi"],
                        "timeout_ms":1,
                        "working_directory":null,
                        "env":null,
                        "user":null
                    }
                }]
            }),
        ];

        let tools_json = create_tools_json_for_chat_completions_api(&prompt, "model").unwrap();
        let expected_body = json!({
            "model":"model",
            "messages": expected_messages,
            "stream": true,
            "tools": tools_json,
        });

        assert_eq!(body, expected_body, "chat payload encoded incorrectly");
    }
}

View File

@@ -317,7 +317,7 @@ where
// duplicated `output` array embedded in the `response.completed`
// payload. That produced two concrete issues:
// 1. No realtime streaming – the user only saw output after the
// entire turn had finished, which broke the typing UX and
// entire turn had finished, which broke the "typing" UX and
// made long-running turns look stalled.
// 2. Duplicate `function_call_output` items – both the
// individual *and* the completed array were forwarded, which
@@ -390,6 +390,7 @@ where
}
/// used in tests to stream from a text SSE file
#[allow(dead_code)]
async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let f = std::fs::File::open(path.as_ref())?;
@@ -413,6 +414,8 @@ mod tests {
#![allow(clippy::expect_used, clippy::unwrap_used)]
use super::*;
use crate::models::LocalShellAction;
use crate::models::LocalShellStatus;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_test::io::Builder as IoBuilder;
@@ -422,6 +425,17 @@ mod tests {
// Helpers
// ────────────────────────────
/// Build a tiny SSE string with the provided *raw* event chunks (already formatted as
/// `"event: ...\ndata: ..."` lines). Each chunk is separated by a blank line.
fn build_sse(chunks: &[&str]) -> String {
    // Each chunk gets the SSE event terminator (blank line) appended; the
    // iterator of owned strings is then collected straight into one String.
    chunks
        .iter()
        .map(|chunk| format!("{chunk}\n\n"))
        .collect()
}
/// Runs the SSE parser on pre-chunked byte slices and returns every event
/// (including any final `Err` from a stream-closure check).
async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent>> {
@@ -469,6 +483,65 @@ mod tests {
out
}
// ────────────────────────────
// Tests from `implement-unit-tests-for-event-aggregation-and-tool-calls`
// ────────────────────────────
/// The SSE parser must surface `function_call_output` and `local_shell_call`
/// items as individual `OutputItemDone` events, followed by `Completed`.
#[tokio::test]
async fn parses_function_and_local_shell_items() {
    // Raw SSE chunks fed to the parser: one function_call_output item, one
    // local_shell_call item, then the terminating completed event.
    let func = "event: response.output_item.done\n\
                data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call_output\",\"call_id\":\"call1\",\"output\":{\"content\":\"ok\",\"success\":true}}}";
    let shell = "event: response.output_item.done\n\
                 data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"local_shell_call\",\"id\":\"ls1\",\"call_id\":\"call2\",\"status\":\"in_progress\",\"action\":{\"type\":\"exec\",\"command\":[\"echo\",\"hi\"],\"timeout_ms\":123,\"working_directory\":null,\"env\":null,\"user\":null}}}";
    let done = "event: response.completed\n\
                data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}";
    let content = build_sse(&[func, shell, done]);

    // Drive process_sse over the in-memory SSE bytes; events arrive on rx.
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<ResponseEvent>>(8);
    let stream = ReaderStream::new(std::io::Cursor::new(content)).map_err(CodexErr::Io);
    tokio::spawn(super::process_sse(stream, tx));

    // 1) function_call_output: call id, content and success flag deserialize.
    match rx.recv().await.unwrap().unwrap() {
        ResponseEvent::OutputItemDone(ResponseItem::FunctionCallOutput { call_id, output }) => {
            assert_eq!(call_id, "call1");
            assert_eq!(output.content, "ok");
            assert_eq!(output.success, Some(true));
        }
        other => panic!("unexpected first event: {other:?}"),
    }

    // 2) local_shell_call: ids, status and the exec action survive intact.
    match rx.recv().await.unwrap().unwrap() {
        ResponseEvent::OutputItemDone(ResponseItem::LocalShellCall {
            id,
            call_id,
            status,
            action,
        }) => {
            assert_eq!(id.as_deref(), Some("ls1"));
            assert_eq!(call_id.as_deref(), Some("call2"));
            if !matches!(status, LocalShellStatus::InProgress) {
                panic!("unexpected status: {status:?}");
            }
            match action {
                LocalShellAction::Exec(act) => {
                    assert_eq!(act.command, vec!["echo".to_string(), "hi".to_string()]);
                    assert_eq!(act.timeout_ms, Some(123));
                }
            }
        }
        other => panic!("unexpected second event: {other:?}"),
    }

    // 3) completed: the response id is carried through.
    assert!(matches!(
        rx.recv().await.unwrap().unwrap(),
        ResponseEvent::Completed { response_id, .. } if response_id == "resp"
    ));
}
// ────────────────────────────
// Tests from `implement-test-for-responses-api-sse-parser`
// ────────────────────────────
@@ -549,6 +622,7 @@ mod tests {
let events = collect_events(&[sse1.as_bytes()]).await;
// We expect the item + a final Err complaining about the missing completed event.
assert_eq!(events.len(), 2);
matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)));

View File

@@ -49,7 +49,7 @@ impl Prompt {
}
}
#[derive(Debug)]
#[derive(Debug, Clone, PartialEq)]
pub enum ResponseEvent {
Created,
OutputItemDone(ResponseItem),

View File

@@ -8,7 +8,7 @@ use serde::ser::Serializer;
use crate::protocol::InputItem;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseInputItem {
Message {
@@ -25,7 +25,7 @@ pub enum ResponseInputItem {
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentItem {
InputText { text: String },
@@ -33,7 +33,7 @@ pub enum ContentItem {
OutputText { text: String },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseItem {
Message {
@@ -99,7 +99,7 @@ impl From<ResponseInputItem> for ResponseItem {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LocalShellStatus {
Completed,
@@ -107,13 +107,13 @@ pub enum LocalShellStatus {
Incomplete,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum LocalShellAction {
Exec(LocalShellExecAction),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LocalShellExecAction {
pub command: Vec<String>,
pub timeout_ms: Option<u64>,
@@ -122,7 +122,7 @@ pub struct LocalShellExecAction {
pub user: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ReasoningItemReasoningSummary {
SummaryText { text: String },
@@ -177,10 +177,10 @@ pub struct ShellToolCallParams {
pub timeout_ms: Option<u64>,
}
#[derive(Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionCallOutputPayload {
pub content: String,
#[expect(dead_code)]
#[allow(dead_code)]
pub success: Option<bool>,
}

View File

@@ -332,7 +332,7 @@ pub struct TaskCompleteEvent {
pub last_agent_message: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)]
pub struct TokenUsage {
pub input_tokens: u64,
pub cached_input_tokens: Option<u64>,