mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
Add unit tests for OpenAI request helpers and retry logic
This commit is contained in:
@@ -462,3 +462,106 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
|
||||
}
|
||||
|
||||
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
use super::*;
|
||||
use crate::WireApi;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::config::{Config, ConfigOverrides, ConfigToml};
|
||||
use crate::models::{ContentItem, FunctionCallOutputPayload, ResponseItem};
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tempfile::TempDir;
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};
|
||||
|
||||
struct CaptureResponder {
|
||||
body: Arc<Mutex<Option<serde_json::Value>>>,
|
||||
}
|
||||
|
||||
impl Respond for CaptureResponder {
|
||||
fn respond(&self, req: &Request) -> ResponseTemplate {
|
||||
let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap();
|
||||
*self.body.lock().unwrap() = Some(v);
|
||||
ResponseTemplate::new(200).insert_header("content-type", "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn assembles_messages_correctly() {
|
||||
let server = MockServer::start().await;
|
||||
let capture = Arc::new(Mutex::new(None));
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/chat/completions"))
|
||||
.respond_with(CaptureResponder {
|
||||
body: capture.clone(),
|
||||
})
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "test".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Chat,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = Config::load_from_base_config_with_overrides(
|
||||
ConfigToml::default(),
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)
|
||||
.unwrap();
|
||||
config.model_provider = provider.clone();
|
||||
config.model = "gpt-4".into();
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let prompt = Prompt {
|
||||
input: vec![
|
||||
ResponseItem::Message {
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText { text: "hi".into() }],
|
||||
},
|
||||
ResponseItem::Message {
|
||||
role: "assistant".into(),
|
||||
content: vec![ContentItem::OutputText { text: "ok".into() }],
|
||||
},
|
||||
ResponseItem::FunctionCall {
|
||||
name: "foo".into(),
|
||||
arguments: "{}".into(),
|
||||
call_id: "c1".into(),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "c1".into(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "out".into(),
|
||||
success: Some(true),
|
||||
},
|
||||
},
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let _ = stream_chat_completions(&prompt, &config.model, &client, &provider)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let body = capture.lock().unwrap().take().unwrap();
|
||||
let messages = body.get("messages").unwrap().as_array().unwrap();
|
||||
assert_eq!(messages[1]["role"], "user");
|
||||
assert_eq!(messages[1]["content"], "hi");
|
||||
assert_eq!(messages[2]["role"], "assistant");
|
||||
assert_eq!(messages[2]["content"], "ok");
|
||||
assert_eq!(messages[3]["tool_calls"][0]["function"]["name"], "foo");
|
||||
assert_eq!(messages[4]["role"], "tool");
|
||||
assert_eq!(messages[4]["tool_call_id"], "c1");
|
||||
assert_eq!(messages[4]["content"], "out");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,3 +391,269 @@ async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
|
||||
tokio::spawn(process_sse(stream, tx_event));
|
||||
Ok(ResponseStream { rx_event })
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used, clippy::print_stdout)]
|
||||
use super::*;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::config::{Config, ConfigOverrides, ConfigToml};
|
||||
use futures::StreamExt;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
use tempfile::TempDir;
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};
|
||||
|
||||
fn default_config(provider: ModelProviderInfo) -> Arc<Config> {
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut cfg = Config::load_from_base_config_with_overrides(
|
||||
ConfigToml::default(),
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)
|
||||
.unwrap();
|
||||
cfg.model_provider = provider.clone();
|
||||
cfg.model = "gpt-test".into();
|
||||
Arc::new(cfg)
|
||||
}
|
||||
|
||||
fn sse_completed(id: &str) -> String {
|
||||
format!(
|
||||
"event: response.completed\n\
|
||||
data: {{\"type\":\"response.completed\",\"response\":{{\"id\":\"{id}\",\"output\":[]}}}}\n\n\n"
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn retries_once_on_server_error() {
|
||||
if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!("Skipping test due to sandbox network restriction");
|
||||
return;
|
||||
}
|
||||
let server = MockServer::start().await;
|
||||
struct SeqResponder;
|
||||
impl Respond for SeqResponder {
|
||||
fn respond(&self, _req: &Request) -> ResponseTemplate {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
static CALLS: AtomicUsize = AtomicUsize::new(0);
|
||||
let n = CALLS.fetch_add(1, Ordering::SeqCst);
|
||||
if n == 0 {
|
||||
ResponseTemplate::new(500)
|
||||
} else {
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("ok"), "text/event-stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(SeqResponder)
|
||||
.expect(2)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") };
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
};
|
||||
|
||||
let config = default_config(provider.clone());
|
||||
let client = ModelClient::new(
|
||||
config,
|
||||
provider,
|
||||
ReasoningEffortConfig::None,
|
||||
ReasoningSummaryConfig::None,
|
||||
);
|
||||
let prompt = Prompt::default();
|
||||
let mut stream = client.stream(&prompt).await.unwrap();
|
||||
while let Some(ev) = stream.next().await {
|
||||
if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn retry_after_header_delay() {
|
||||
if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!("Skipping test due to sandbox network restriction");
|
||||
return;
|
||||
}
|
||||
let server = MockServer::start().await;
|
||||
let times = Arc::new(Mutex::new(Vec::new()));
|
||||
struct SeqResponder {
|
||||
times: Arc<Mutex<Vec<Instant>>>,
|
||||
}
|
||||
impl Respond for SeqResponder {
|
||||
fn respond(&self, _req: &Request) -> ResponseTemplate {
|
||||
let mut t = self.times.lock().unwrap();
|
||||
t.push(Instant::now());
|
||||
if t.len() == 1 {
|
||||
ResponseTemplate::new(429).insert_header("retry-after", "1")
|
||||
} else {
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("ok"), "text/event-stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(SeqResponder {
|
||||
times: times.clone(),
|
||||
})
|
||||
.expect(2)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") };
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
};
|
||||
let config = default_config(provider.clone());
|
||||
let client = ModelClient::new(
|
||||
config,
|
||||
provider,
|
||||
ReasoningEffortConfig::None,
|
||||
ReasoningSummaryConfig::None,
|
||||
);
|
||||
let prompt = Prompt::default();
|
||||
let mut stream = client.stream(&prompt).await.unwrap();
|
||||
while let Some(ev) = stream.next().await {
|
||||
if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let times = times.lock().unwrap();
|
||||
assert!(times.len() == 2);
|
||||
assert!(times[1] - times[0] >= Duration::from_secs(1));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn retry_backoff_no_header() {
|
||||
if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!("Skipping test due to sandbox network restriction");
|
||||
return;
|
||||
}
|
||||
let server = MockServer::start().await;
|
||||
let times = Arc::new(Mutex::new(Vec::new()));
|
||||
struct SeqResponder {
|
||||
times: Arc<Mutex<Vec<Instant>>>,
|
||||
}
|
||||
impl Respond for SeqResponder {
|
||||
fn respond(&self, _req: &Request) -> ResponseTemplate {
|
||||
let mut t = self.times.lock().unwrap();
|
||||
t.push(Instant::now());
|
||||
if t.len() == 1 {
|
||||
ResponseTemplate::new(429)
|
||||
} else {
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("ok"), "text/event-stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(SeqResponder {
|
||||
times: times.clone(),
|
||||
})
|
||||
.expect(2)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") };
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
};
|
||||
let config = default_config(provider.clone());
|
||||
let client = ModelClient::new(
|
||||
config,
|
||||
provider,
|
||||
ReasoningEffortConfig::None,
|
||||
ReasoningSummaryConfig::None,
|
||||
);
|
||||
let prompt = Prompt::default();
|
||||
let mut stream = client.stream(&prompt).await.unwrap();
|
||||
while let Some(ev) = stream.next().await {
|
||||
if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let times = times.lock().unwrap();
|
||||
assert!(times.len() == 2);
|
||||
assert!(times[1] - times[0] >= Duration::from_millis(100));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn permanent_error_bubbles_body() {
|
||||
if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!("Skipping test due to sandbox network restriction");
|
||||
return;
|
||||
}
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(ResponseTemplate::new(400).set_body_string("bad"))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0") };
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
};
|
||||
let config = default_config(provider.clone());
|
||||
let client = ModelClient::new(
|
||||
config,
|
||||
provider,
|
||||
ReasoningEffortConfig::None,
|
||||
ReasoningSummaryConfig::None,
|
||||
);
|
||||
let prompt = Prompt::default();
|
||||
let res = client.stream(&prompt).await;
|
||||
match res {
|
||||
Ok(_) => panic!("expected error"),
|
||||
Err(err) => match err {
|
||||
CodexErr::UnexpectedStatus(code, body) => {
|
||||
assert_eq!(code, StatusCode::BAD_REQUEST);
|
||||
assert_eq!(body, "bad");
|
||||
}
|
||||
other => panic!("unexpected error: {other:?}"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,3 +155,71 @@ fn mcp_tool_to_openai_tool(
|
||||
"type": "function",
|
||||
})
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
use super::*;
|
||||
use crate::client_common::Prompt;
|
||||
use mcp_types::{Tool, ToolInputSchema};
|
||||
|
||||
fn dummy_tool() -> (String, Tool) {
|
||||
(
|
||||
"srv.dummy".to_string(),
|
||||
Tool {
|
||||
annotations: None,
|
||||
description: Some("dummy".into()),
|
||||
input_schema: ToolInputSchema {
|
||||
properties: None,
|
||||
required: None,
|
||||
r#type: "object".to_string(),
|
||||
},
|
||||
name: "dummy".into(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_includes_default_and_extra() {
|
||||
let mut prompt = Prompt::default();
|
||||
let (name, tool) = dummy_tool();
|
||||
prompt.extra_tools.insert(name.clone(), tool);
|
||||
|
||||
let tools = create_tools_json_for_responses_api(&prompt, "gpt-4").unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0]["type"], "function");
|
||||
assert_eq!(tools[0]["name"], "shell");
|
||||
assert!(
|
||||
tools
|
||||
.iter()
|
||||
.any(|t| t.get("name") == Some(&name.clone().into()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_codex_model_uses_local_shell() {
|
||||
let mut prompt = Prompt::default();
|
||||
let (name, tool) = dummy_tool();
|
||||
prompt.extra_tools.insert(name, tool);
|
||||
|
||||
let tools = create_tools_json_for_responses_api(&prompt, "codex-model").unwrap();
|
||||
assert_eq!(tools[0]["type"], "local_shell");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_completions_tool_format() {
|
||||
let mut prompt = Prompt::default();
|
||||
let (name, tool) = dummy_tool();
|
||||
prompt.extra_tools.insert(name.clone(), tool);
|
||||
|
||||
let tools = create_tools_json_for_chat_completions_api(&prompt, "gpt-4").unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
for t in tools {
|
||||
assert_eq!(
|
||||
t.get("type"),
|
||||
Some(&serde_json::Value::String("function".into()))
|
||||
);
|
||||
let inner = t.get("function").and_then(|v| v.as_object()).unwrap();
|
||||
assert!(!inner.contains_key("type"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user