Files
codex/codex-rs/codex-api/tests/sse_end_to_end.rs
Andrey Mishchenko 355c40ad7e Support end_turn in response.completed (#19610)
Some providers of the Responses API forward a model-defined `end_turn`
boolean that explicitly indicates whether the model would like to end the
turn or to be inferenced again. In this PR, we update the sampling loop to
use this field correctly if it's set. If the field is not set by the
provider, we fall back to the existing sampling logic.
2026-04-25 21:57:42 -07:00

172 lines
4.6 KiB
Rust

use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use codex_api::AuthProvider;
use codex_api::Compression;
use codex_api::Provider;
use codex_api::ResponseEvent;
use codex_api::ResponsesClient;
use codex_client::HttpTransport;
use codex_client::Request;
use codex_client::Response;
use codex_client::StreamResponse;
use codex_client::TransportError;
use codex_protocol::models::ResponseItem;
use futures::StreamExt;
use http::HeaderMap;
use http::StatusCode;
use pretty_assertions::assert_eq;
use serde_json::Value;
/// Test transport that holds a canned SSE payload in memory.
#[derive(Clone)]
struct FixtureSseTransport {
    /// Raw SSE text served as the response body.
    body: String,
}

impl FixtureSseTransport {
    /// Wraps `body` so it can be served as the fixture payload.
    fn new(body: String) -> Self {
        FixtureSseTransport { body }
    }
}
#[async_trait]
impl HttpTransport for FixtureSseTransport {
    /// Always fails with `TransportError::Build`; this fixture only supports
    /// the streaming path.
    async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
        let message = String::from("execute should not run");
        Err(TransportError::Build(message))
    }

    /// Ignores the request and returns a `200 OK` response with empty headers
    /// whose byte stream yields the whole fixture body as a single chunk.
    async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
        let payload = Bytes::from(self.body.clone());
        // Exactly one successful chunk, then the stream ends.
        let chunks = futures::stream::iter(std::iter::once(Ok::<Bytes, TransportError>(payload)));
        let response = StreamResponse {
            status: StatusCode::OK,
            headers: HeaderMap::new(),
            bytes: Box::pin(chunks),
        };
        Ok(response)
    }
}
/// Placeholder auth provider for this test file; its `AuthProvider`
/// implementation adds no headers.
#[derive(Clone, Default)]
struct NoAuth;
impl AuthProvider for NoAuth {
/// No-op: leaves `_headers` untouched, so this provider contributes no
/// authentication headers.
fn add_auth_headers(&self, _headers: &mut HeaderMap) {}
}
/// Builds a `Provider` called `name` that points at a placeholder base URL.
///
/// The retry config uses `max_attempts: 1` with a 1 ms base delay and only
/// `retry_transport` enabled; the stream idle timeout is 50 ms. No query
/// parameters or extra headers are set.
fn provider(name: &str) -> Provider {
    let retry = codex_api::RetryConfig {
        max_attempts: 1,
        base_delay: Duration::from_millis(1),
        retry_429: false,
        retry_5xx: false,
        retry_transport: true,
    };
    let base_url = String::from("https://example.com/v1");

    Provider {
        name: name.to_owned(),
        base_url,
        query_params: None,
        headers: HeaderMap::new(),
        retry,
        stream_idle_timeout: Duration::from_millis(50),
    }
}
fn build_responses_body(events: Vec<Value>) -> String {
let mut body = String::new();
for e in events {
let kind = e
.get("type")
.and_then(|v| v.as_str())
.unwrap_or_else(|| panic!("fixture event missing type in SSE fixture: {e}"));
if e.as_object().map(|o| o.len() == 1).unwrap_or(false) {
body.push_str(&format!("event: {kind}\n\n"));
} else {
body.push_str(&format!("event: {kind}\ndata: {e}\n\n"));
}
}
body
}
/// Streams a three-event SSE fixture (two assistant messages followed by
/// `response.completed`) through `ResponsesClient::stream` and checks the
/// decoded `ResponseEvent`s.
///
/// `RateLimits` events are ignored. The completed event is expected to carry
/// the response id `resp1` and to have neither token usage nor `end_turn`,
/// since the fixture supplies only an id.
#[tokio::test]
async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()> {
    // Both message fixtures differ only in their text, so build them from one
    // template.
    let assistant_message_done = |text: &str| {
        serde_json::json!({
            "type": "response.output_item.done",
            "item": {
                "type": "message",
                "role": "assistant",
                "content": [{"type": "output_text", "text": text}]
            }
        })
    };
    let fixture_events = vec![
        assistant_message_done("Hello"),
        assistant_message_done("World"),
        serde_json::json!({
            "type": "response.completed",
            "response": { "id": "resp1" }
        }),
    ];

    let transport = FixtureSseTransport::new(build_responses_body(fixture_events));
    let client = ResponsesClient::new(transport, provider("openai"), Arc::new(NoAuth));
    let request_body = serde_json::json!({"echo": true});
    let mut stream = client
        .stream(
            request_body,
            HeaderMap::new(),
            Compression::None,
            /*turn_state*/ None,
        )
        .await?;

    // Drain the stream, bailing out on the first error and skipping
    // rate-limit events, which are not part of what this test checks.
    let mut events: Vec<ResponseEvent> = Vec::new();
    while let Some(next) = stream.next().await {
        let ev = next?;
        if !matches!(ev, ResponseEvent::RateLimits(_)) {
            events.push(ev);
        }
    }

    assert_eq!(events.len(), 3);

    match &events[0] {
        ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }) => {
            assert_eq!(role, "assistant")
        }
        other => panic!("unexpected first event: {other:?}"),
    }

    match &events[1] {
        ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }) => {
            assert_eq!(role, "assistant")
        }
        other => panic!("unexpected second event: {other:?}"),
    }

    match &events[2] {
        ResponseEvent::Completed {
            response_id,
            token_usage,
            end_turn,
        } => {
            assert_eq!(response_id, "resp1");
            assert!(token_usage.is_none());
            assert!(end_turn.is_none());
        }
        other => panic!("unexpected third event: {other:?}"),
    }

    Ok(())
}