Add unit tests for OpenAI request helpers and retry logic
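
Covers chat completions request assembly (message ordering, tool calls,
and tool outputs), retry behavior in the model client (retry on 500,
retry-after header delays, default backoff, and permanent errors
surfacing the response body), and tools JSON construction for both the
responses and chat completions APIs.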

aibrahim-oai
2025-07-11 14:37:40 -07:00
parent bfeb8c92a5
commit bdc60ef6c7
3 changed files with 437 additions and 0 deletions


@@ -462,3 +462,106 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;
    use crate::WireApi;
    use crate::client_common::Prompt;
    use crate::config::{Config, ConfigOverrides, ConfigToml};
    use crate::models::{ContentItem, FunctionCallOutputPayload, ResponseItem};
    use pretty_assertions::assert_eq;
    use std::sync::{Arc, Mutex};
    use tempfile::TempDir;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};
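
    /// Wiremock responder that captures the incoming JSON request body so
    /// the test can assert on the exact payload sent to the server.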
    struct CaptureResponder {
        body: Arc<Mutex<Option<serde_json::Value>>>,
    }

    impl Respond for CaptureResponder {
        fn respond(&self, req: &Request) -> ResponseTemplate {
            let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap();
            *self.body.lock().unwrap() = Some(v);
            ResponseTemplate::new(200).insert_header("content-type", "text/event-stream")
        }
    }
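
    /// Sends a prompt containing a user message, an assistant message, and a
    /// function call round trip, then checks that the chat completions
    /// payload preserves their order and fields.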
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn assembles_messages_correctly() {
        let server = MockServer::start().await;
        let capture = Arc::new(Mutex::new(None));
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(CaptureResponder {
                body: capture.clone(),
            })
            .mount(&server)
            .await;

        let provider = ModelProviderInfo {
            name: "test".into(),
            base_url: format!("{}/v1", server.uri()),
            env_key: None,
            env_key_instructions: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
        };

        let codex_home = TempDir::new().unwrap();
        let mut config = Config::load_from_base_config_with_overrides(
            ConfigToml::default(),
            ConfigOverrides::default(),
            codex_home.path().to_path_buf(),
        )
        .unwrap();
        config.model_provider = provider.clone();
        config.model = "gpt-4".into();

        let client = reqwest::Client::new();
        let prompt = Prompt {
            input: vec![
                ResponseItem::Message {
                    role: "user".into(),
                    content: vec![ContentItem::InputText { text: "hi".into() }],
                },
                ResponseItem::Message {
                    role: "assistant".into(),
                    content: vec![ContentItem::OutputText { text: "ok".into() }],
                },
                ResponseItem::FunctionCall {
                    name: "foo".into(),
                    arguments: "{}".into(),
                    call_id: "c1".into(),
                },
                ResponseItem::FunctionCallOutput {
                    call_id: "c1".into(),
                    output: FunctionCallOutputPayload {
                        content: "out".into(),
                        success: Some(true),
                    },
                },
            ],
            ..Default::default()
        };

        let _ = stream_chat_completions(&prompt, &config.model, &client, &provider)
            .await
            .unwrap();

        let body = capture.lock().unwrap().take().unwrap();
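        // The conversation items land at indices 1..=4; index 0 presumably
        // carries the instructions message the client prepends.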
        let messages = body.get("messages").unwrap().as_array().unwrap();
        assert_eq!(messages[1]["role"], "user");
        assert_eq!(messages[1]["content"], "hi");
        assert_eq!(messages[2]["role"], "assistant");
        assert_eq!(messages[2]["content"], "ok");
        assert_eq!(messages[3]["tool_calls"][0]["function"]["name"], "foo");
        assert_eq!(messages[4]["role"], "tool");
        assert_eq!(messages[4]["tool_call_id"], "c1");
        assert_eq!(messages[4]["content"], "out");
    }
}


@@ -391,3 +391,269 @@ async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
    tokio::spawn(process_sse(stream, tx_event));
    Ok(ResponseStream { rx_event })
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::print_stdout)]

    use super::*;
    use crate::client_common::Prompt;
    use crate::config::{Config, ConfigOverrides, ConfigToml};
    use futures::StreamExt;
    use std::sync::{Arc, Mutex};
    use std::time::{Duration, Instant};
    use tempfile::TempDir;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};
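
    /// Builds a minimal Config that points at the given provider, using a
    /// temporary directory as the codex home.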
    fn default_config(provider: ModelProviderInfo) -> Arc<Config> {
        let codex_home = TempDir::new().unwrap();
        let mut cfg = Config::load_from_base_config_with_overrides(
            ConfigToml::default(),
            ConfigOverrides::default(),
            codex_home.path().to_path_buf(),
        )
        .unwrap();
        cfg.model_provider = provider.clone();
        cfg.model = "gpt-test".into();
        Arc::new(cfg)
    }
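
    /// Renders a minimal `response.completed` SSE event with the given id,
    /// which is enough for the stream to finish cleanly.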
    fn sse_completed(id: &str) -> String {
        format!(
            "event: response.completed\n\
             data: {{\"type\":\"response.completed\",\"response\":{{\"id\":\"{id}\",\"output\":[]}}}}\n\n\n"
        )
    }
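
    /// The first request fails with a 500; with OPENAI_REQUEST_MAX_RETRIES=1
    /// the client should retry once and complete on the second attempt.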
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn retries_once_on_server_error() {
        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            println!("Skipping test due to sandbox network restriction");
            return;
        }

        let server = MockServer::start().await;

        struct SeqResponder;
        impl Respond for SeqResponder {
            fn respond(&self, _req: &Request) -> ResponseTemplate {
                use std::sync::atomic::{AtomicUsize, Ordering};
                static CALLS: AtomicUsize = AtomicUsize::new(0);
                let n = CALLS.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    ResponseTemplate::new(500)
                } else {
                    ResponseTemplate::new(200)
                        .insert_header("content-type", "text/event-stream")
                        .set_body_raw(sse_completed("ok"), "text/event-stream")
                }
            }
        }

        Mock::given(method("POST"))
            .and(path("/v1/responses"))
            .respond_with(SeqResponder)
            .expect(2)
            .mount(&server)
            .await;

        unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") };

        let provider = ModelProviderInfo {
            name: "openai".into(),
            base_url: format!("{}/v1", server.uri()),
            // PATH is used as the key env var because it is always set.
            env_key: Some("PATH".into()),
            env_key_instructions: None,
            wire_api: WireApi::Responses,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
        };

        let config = default_config(provider.clone());
        let client = ModelClient::new(
            config,
            provider,
            ReasoningEffortConfig::None,
            ReasoningSummaryConfig::None,
        );

        let prompt = Prompt::default();
        let mut stream = client.stream(&prompt).await.unwrap();
        while let Some(ev) = stream.next().await {
            if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
                break;
            }
        }
    }
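
    /// A 429 carrying `retry-after: 1` should delay the retry by at least
    /// one second.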
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn retry_after_header_delay() {
        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            println!("Skipping test due to sandbox network restriction");
            return;
        }

        let server = MockServer::start().await;
        let times = Arc::new(Mutex::new(Vec::new()));

        struct SeqResponder {
            times: Arc<Mutex<Vec<Instant>>>,
        }
        impl Respond for SeqResponder {
            fn respond(&self, _req: &Request) -> ResponseTemplate {
                let mut t = self.times.lock().unwrap();
                t.push(Instant::now());
                if t.len() == 1 {
                    ResponseTemplate::new(429).insert_header("retry-after", "1")
                } else {
                    ResponseTemplate::new(200)
                        .insert_header("content-type", "text/event-stream")
                        .set_body_raw(sse_completed("ok"), "text/event-stream")
                }
            }
        }

        Mock::given(method("POST"))
            .and(path("/v1/responses"))
            .respond_with(SeqResponder {
                times: times.clone(),
            })
            .expect(2)
            .mount(&server)
            .await;

        unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") };

        let provider = ModelProviderInfo {
            name: "openai".into(),
            base_url: format!("{}/v1", server.uri()),
            env_key: Some("PATH".into()),
            env_key_instructions: None,
            wire_api: WireApi::Responses,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
        };

        let config = default_config(provider.clone());
        let client = ModelClient::new(
            config,
            provider,
            ReasoningEffortConfig::None,
            ReasoningSummaryConfig::None,
        );

        let prompt = Prompt::default();
        let mut stream = client.stream(&prompt).await.unwrap();
        while let Some(ev) = stream.next().await {
            if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
                break;
            }
        }

        let times = times.lock().unwrap();
        assert_eq!(times.len(), 2);
        assert!(times[1] - times[0] >= Duration::from_secs(1));
    }
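
    /// Without a retry-after header the client should apply its own backoff;
    /// the test only asserts a conservative lower bound of 100ms.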
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn retry_backoff_no_header() {
        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            println!("Skipping test due to sandbox network restriction");
            return;
        }

        let server = MockServer::start().await;
        let times = Arc::new(Mutex::new(Vec::new()));

        struct SeqResponder {
            times: Arc<Mutex<Vec<Instant>>>,
        }
        impl Respond for SeqResponder {
            fn respond(&self, _req: &Request) -> ResponseTemplate {
                let mut t = self.times.lock().unwrap();
                t.push(Instant::now());
                if t.len() == 1 {
                    ResponseTemplate::new(429)
                } else {
                    ResponseTemplate::new(200)
                        .insert_header("content-type", "text/event-stream")
                        .set_body_raw(sse_completed("ok"), "text/event-stream")
                }
            }
        }

        Mock::given(method("POST"))
            .and(path("/v1/responses"))
            .respond_with(SeqResponder {
                times: times.clone(),
            })
            .expect(2)
            .mount(&server)
            .await;

        unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") };

        let provider = ModelProviderInfo {
            name: "openai".into(),
            base_url: format!("{}/v1", server.uri()),
            env_key: Some("PATH".into()),
            env_key_instructions: None,
            wire_api: WireApi::Responses,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
        };

        let config = default_config(provider.clone());
        let client = ModelClient::new(
            config,
            provider,
            ReasoningEffortConfig::None,
            ReasoningSummaryConfig::None,
        );

        let prompt = Prompt::default();
        let mut stream = client.stream(&prompt).await.unwrap();
        while let Some(ev) = stream.next().await {
            if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
                break;
            }
        }

        let times = times.lock().unwrap();
        assert_eq!(times.len(), 2);
        assert!(times[1] - times[0] >= Duration::from_millis(100));
    }
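
    /// A 400 response must not be retried; the error should surface both the
    /// status code and the response body verbatim.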
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn permanent_error_bubbles_body() {
        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            println!("Skipping test due to sandbox network restriction");
            return;
        }

        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/responses"))
            .respond_with(ResponseTemplate::new(400).set_body_string("bad"))
            .expect(1)
            .mount(&server)
            .await;

        unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0") };

        let provider = ModelProviderInfo {
            name: "openai".into(),
            base_url: format!("{}/v1", server.uri()),
            env_key: Some("PATH".into()),
            env_key_instructions: None,
            wire_api: WireApi::Responses,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
        };

        let config = default_config(provider.clone());
        let client = ModelClient::new(
            config,
            provider,
            ReasoningEffortConfig::None,
            ReasoningSummaryConfig::None,
        );

        let prompt = Prompt::default();
        let res = client.stream(&prompt).await;
        match res {
            Ok(_) => panic!("expected error"),
            Err(err) => match err {
                CodexErr::UnexpectedStatus(code, body) => {
                    assert_eq!(code, StatusCode::BAD_REQUEST);
                    assert_eq!(body, "bad");
                }
                other => panic!("unexpected error: {other:?}"),
            },
        }
    }
}


@@ -155,3 +155,71 @@ fn mcp_tool_to_openai_tool(
"type": "function",
})
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;
    use crate::client_common::Prompt;
    use mcp_types::{Tool, ToolInputSchema};
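
    /// Returns a fully-qualified tool name and a minimal MCP tool definition
    /// to register as an extra tool.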
    fn dummy_tool() -> (String, Tool) {
        (
            "srv.dummy".to_string(),
            Tool {
                annotations: None,
                description: Some("dummy".into()),
                input_schema: ToolInputSchema {
                    properties: None,
                    required: None,
                    r#type: "object".to_string(),
                },
                name: "dummy".into(),
            },
        )
    }
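
    /// The responses API tool list should start with the built-in shell tool
    /// and also include any extra MCP tools from the prompt.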
    #[test]
    fn responses_includes_default_and_extra() {
        let mut prompt = Prompt::default();
        let (name, tool) = dummy_tool();
        prompt.extra_tools.insert(name.clone(), tool);

        let tools = create_tools_json_for_responses_api(&prompt, "gpt-4").unwrap();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0]["type"], "function");
        assert_eq!(tools[0]["name"], "shell");
        assert!(
            tools
                .iter()
                .any(|t| t.get("name") == Some(&name.clone().into()))
        );
    }
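
    /// Codex models get the local_shell tool in place of the generic shell
    /// function.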
    #[test]
    fn responses_codex_model_uses_local_shell() {
        let mut prompt = Prompt::default();
        let (name, tool) = dummy_tool();
        prompt.extra_tools.insert(name, tool);

        let tools = create_tools_json_for_responses_api(&prompt, "codex-model").unwrap();
        assert_eq!(tools[0]["type"], "local_shell");
    }
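
    /// The chat completions API wraps each tool as
    /// `{"type": "function", "function": {...}}`, with no stray "type" field
    /// left inside the function object.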
    #[test]
    fn chat_completions_tool_format() {
        let mut prompt = Prompt::default();
        let (name, tool) = dummy_tool();
        prompt.extra_tools.insert(name.clone(), tool);

        let tools = create_tools_json_for_chat_completions_api(&prompt, "gpt-4").unwrap();
        assert_eq!(tools.len(), 2);
        for t in tools {
            assert_eq!(
                t.get("type"),
                Some(&serde_json::Value::String("function".into()))
            );
            let inner = t.get("function").and_then(|v| v.as_object()).unwrap();
            assert!(!inner.contains_key("type"));
        }
    }
}