Compare commits

...

3 Commits

Author        SHA1        Message                                               Date
aibrahim-oai  77da2843a9  Fix retry tests by setting env once                   2025-07-11 13:23:00 -07:00
aibrahim-oai  e92f464830  Fix tests compile warnings                            2025-07-11 12:08:56 -07:00
aibrahim-oai  555cddb0d6  test: add unit tests for OpenAI payloads and retries  2025-07-11 11:49:56 -07:00
5 changed files with 364 additions and 0 deletions

codex-rs/Cargo.lock (generated): 35 lines changed
View File

@@ -641,6 +641,8 @@ dependencies = [
"maplit",
"mcp-types",
"mime_guess",
"mockito",
"once_cell",
"openssl-sys",
"predicates",
"pretty_assertions",
@@ -848,6 +850,15 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "colored"
version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e"
dependencies = [
"windows-sys 0.59.0",
]

[[package]]
name = "compact_str"
version = "0.8.1"
@@ -2545,6 +2556,30 @@ dependencies = [
"windows-sys 0.59.0",
]

[[package]]
name = "mockito"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7760e0e418d9b7e5777c0374009ca4c93861b9066f18cb334a20ce50ab63aa48"
dependencies = [
"assert-json-diff",
"bytes",
"colored",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-util",
"log",
"rand 0.9.1",
"regex",
"serde_json",
"serde_urlencoded",
"similar",
"tokio",
]

[[package]]
name = "multimap"
version = "0.10.1"

View File

@@ -65,3 +65,5 @@ predicates = "3"
pretty_assertions = "1.4.1"
tempfile = "3"
wiremock = "0.6"
mockito = "1"
once_cell = "1"

View File

@@ -462,3 +462,101 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
}

impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;
    use crate::config::{Config, ConfigOverrides, ConfigToml};
    use crate::{ModelProviderInfo, WireApi};
    use mockito::Server;
    use once_cell::sync::Lazy;
    use serde_json::Value;
    use std::sync::Mutex;
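
    // Captures the JSON body of the most recent request seen by the mock
    // endpoint so the test can inspect it after the call returns.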
    static LAST_PAYLOAD: Lazy<Mutex<Option<Value>>> = Lazy::new(|| Mutex::new(None));
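
    // Builds a Config whose model provider points at the local mockito
    // server; "PATH" is used as the API-key env var, presumably because it
    // is always set, so the key lookup succeeds.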
    fn sample_config(server: &Server, wire_api: WireApi) -> Config {
        use tempfile::TempDir;

        let codex_home = TempDir::new().unwrap();
        let mut cfg = Config::load_from_base_config_with_overrides(
            ConfigToml::default(),
            ConfigOverrides::default(),
            codex_home.path().to_path_buf(),
        )
        .unwrap();
        cfg.model_provider = ModelProviderInfo {
            name: "openai".into(),
            base_url: format!("{}/v1", server.url()),
            env_key: Some("PATH".into()),
            env_key_instructions: None,
            wire_api,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
        };
        cfg.model_provider_id = "openai".into();
        cfg
    }
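
    // Drives stream_chat_completions against the mock server and checks the
    // captured Chat Completions payload.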
    #[tokio::test]
    async fn assembles_chat_payload_correctly() {
        let mut server = Server::new_async().await;
        let _m = server
            .mock("POST", "/v1/chat/completions")
            .match_request(|req| {
                let body = req.body().unwrap();
                let v: Value = serde_json::from_slice(body).unwrap();
                *LAST_PAYLOAD.lock().unwrap() = Some(v);
                true
            })
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_body("data: [DONE]\n\n")
            .create_async()
            .await;

        // Pin the retry limit for the remainder of this test binary. This
        // test runs first alphabetically, so the variable is set before
        // anything reads the flag.
        unsafe {
            std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1");
        }

        let config = sample_config(&server, WireApi::Chat);
        let provider = config.model_provider.clone();

        let mut prompt = Prompt::default();
        prompt.input.push(ResponseItem::Message {
            role: "user".into(),
            content: vec![ContentItem::InputText { text: "hi".into() }],
        });
        prompt.input.push(ResponseItem::FunctionCall {
            name: "do".into(),
            arguments: "{}".into(),
            call_id: "call1".into(),
        });
        prompt.input.push(ResponseItem::FunctionCallOutput {
            call_id: "call1".into(),
            output: crate::models::FunctionCallOutputPayload {
                content: "ok".into(),
                success: None,
            },
        });

        // Fire the request; ignore the stream result.
        let http_client = reqwest::Client::new();
        let _ = stream_chat_completions(&prompt, "gpt-4", &http_client, &provider).await;

        let payload = LAST_PAYLOAD.lock().unwrap().take().unwrap();
        let msgs = payload.get("messages").unwrap().as_array().unwrap();
        assert_eq!(msgs[0]["role"], "system");
        assert_eq!(msgs[1]["role"], "user");
        assert_eq!(msgs[1]["content"], "hi");
        assert_eq!(msgs[2]["role"], "assistant");
        assert!(msgs[2]["tool_calls"].is_array());
        assert_eq!(msgs[3]["role"], "tool");
        assert_eq!(msgs[3]["tool_call_id"], "call1");
    }
}
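
For reference, the messages array asserted above has roughly the following shape. The tool_calls entry is shown in the standard OpenAI format; the test only checks that it is an array, so its inner fields here are assumed rather than asserted:

[
  { "role": "system",    "content": "<instructions>" },
  { "role": "user",      "content": "hi" },
  { "role": "assistant", "tool_calls": [ { "id": "call1", "type": "function",
                                           "function": { "name": "do", "arguments": "{}" } } ] },
  { "role": "tool",      "tool_call_id": "call1", "content": "ok" }
]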

View File

@@ -391,3 +391,159 @@ async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
    tokio::spawn(process_sse(stream, tx_event));
    Ok(ResponseStream { rx_event })
}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;
    use crate::client_common::{Prompt, ResponseEvent};
    use crate::config::{Config, ConfigOverrides, ConfigToml};
    use crate::config_types::{
        ReasoningEffort as ReasoningEffortConfig, ReasoningSummary as ReasoningSummaryConfig,
    };
    use crate::{ModelProviderInfo, WireApi};
    use mockito::Server;
    use once_cell::sync::Lazy;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tempfile::TempDir;
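
    // Same provider setup as the chat-completions tests, but wired to the
    // Responses API.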
    fn sample_config(server: &Server) -> Config {
        let codex_home = TempDir::new().unwrap();
        let mut cfg = Config::load_from_base_config_with_overrides(
            ConfigToml::default(),
            ConfigOverrides::default(),
            codex_home.path().to_path_buf(),
        )
        .unwrap();
        cfg.model_provider = ModelProviderInfo {
            name: "openai".into(),
            base_url: format!("{}/v1", server.url()),
            env_key: Some("PATH".into()),
            env_key_instructions: None,
            wire_api: WireApi::Responses,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
        };
        cfg.model_provider_id = "openai".into();
        cfg
    }
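
    // Minimal SSE body that marks a Responses stream as completed.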
    fn sse_completed(id: &str) -> String {
        format!(
            "event: response.completed\n\
             data: {{\"type\":\"response.completed\",\"response\":{{\"id\":\"{id}\",\"output\":[]}}}}\n\n\n"
        )
    }
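
    // A 500 followed by a successful stream: the client should retry and
    // then complete, hitting both mocks.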
    #[tokio::test]
    async fn retries_then_succeeds() {
        let mut server = Server::new_async().await;
        let fail = server
            .mock("POST", "/v1/responses")
            .with_status(500)
            .create_async()
            .await;
        let success = server
            .mock("POST", "/v1/responses")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_body(sse_completed("ok"))
            .create_async()
            .await;

        let config = sample_config(&server);
        let provider = config.model_provider.clone();
        let client = ModelClient::new(
            Arc::new(config),
            provider,
            ReasoningEffortConfig::None,
            ReasoningSummaryConfig::None,
        );

        let prompt = Prompt::default();
        let mut stream = client.stream(&prompt).await.unwrap();
        while let Some(event) = stream.rx_event.recv().await {
            if matches!(event.unwrap(), ResponseEvent::Completed { .. }) {
                break;
            }
        }

        fail.assert_async().await;
        success.assert_async().await;
    }
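
    // A 429 carrying a Retry-After header should be retried; with the retry
    // limit pinned to 1 (set once above), exactly two requests reach the
    // server.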
    #[tokio::test]
    async fn retry_after_header_respected() {
        let mut server = Server::new_async().await;
        static CALLS: Lazy<AtomicUsize> = Lazy::new(|| AtomicUsize::new(0));
        let _m1 = server
            .mock("POST", "/v1/responses")
            .match_request(|_| {
                CALLS.fetch_add(1, Ordering::SeqCst);
                true
            })
            .with_status(429)
            .with_header("Retry-After", "1")
            .create_async()
            .await;
        let _m2 = server
            .mock("POST", "/v1/responses")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_body(sse_completed("ok"))
            .create_async()
            .await;

        let config = sample_config(&server);
        let provider = config.model_provider.clone();
        let client = ModelClient::new(
            Arc::new(config),
            provider,
            ReasoningEffortConfig::None,
            ReasoningSummaryConfig::None,
        );

        let prompt = Prompt::default();
        let _ = client.stream(&prompt).await.unwrap();
        assert_eq!(CALLS.load(Ordering::SeqCst), 2);
    }
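
    // A 400 should surface as CodexErr::UnexpectedStatus with the response
    // body preserved verbatim.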
    #[tokio::test]
    async fn returns_error_body_for_client_error() {
        let mut server = Server::new_async().await;
        let _m = server
            .mock("POST", "/v1/responses")
            .with_status(400)
            .with_body("bad request")
            .create_async()
            .await;

        let config = sample_config(&server);
        let provider = config.model_provider.clone();
        let client = ModelClient::new(
            Arc::new(config),
            provider,
            ReasoningEffortConfig::None,
            ReasoningSummaryConfig::None,
        );

        let prompt = Prompt::default();
        let err = match client.stream(&prompt).await {
            Ok(_) => panic!("expected error"),
            Err(e) => e,
        };
        match err {
            CodexErr::UnexpectedStatus(code, body) => {
                assert_eq!(code, StatusCode::BAD_REQUEST);
                assert_eq!(body, "bad request");
            }
            e => panic!("unexpected error: {e:?}"),
        }
    }
}
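
Taken together, these tests pin down the retry contract: 5xx and 429 responses are retried up to OPENAI_REQUEST_MAX_RETRIES times, a Retry-After header overrides the backoff delay, and a 400 surfaces as CodexErr::UnexpectedStatus. A hand-wavy sketch of such a loop (illustrative only; send_with_retries and its backoff policy are assumptions, not the code under test):

// Illustrative retry loop; names and backoff policy are assumptions.
async fn send_with_retries(
    client: &reqwest::Client,
    url: &str,
    max_retries: u64,
) -> reqwest::Result<reqwest::Response> {
    let mut attempt = 0;
    loop {
        let resp = client.post(url).send().await?;
        let retryable = resp.status().is_server_error() || resp.status().as_u16() == 429;
        if !retryable || attempt >= max_retries {
            return Ok(resp);
        }
        // Honour Retry-After (seconds) when present, else back off exponentially.
        let delay = resp
            .headers()
            .get("retry-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(1 << attempt);
        tokio::time::sleep(std::time::Duration::from_secs(delay)).await;
        attempt += 1;
    }
}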

View File

@@ -155,3 +155,76 @@ fn mcp_tool_to_openai_tool(
"type": "function",
})
}
#[cfg(test)]
mod tests {
#![allow(clippy::unwrap_used)]
use super::*;
use maplit::hashmap;
use serde_json::Value;
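
    // Minimal MCP tool definition used as the "extra" tool in these tests.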
    fn sample_tool() -> mcp_types::Tool {
        mcp_types::Tool {
            annotations: None,
            description: Some("echo".into()),
            input_schema: mcp_types::ToolInputSchema {
                properties: Some(json!({"msg": {"type": "string"}})),
                required: None,
                r#type: "object".into(),
            },
            name: "echo".into(),
        }
    }
    #[test]
    fn responses_api_includes_default_and_extra_tools() {
        let prompt = Prompt {
            extra_tools: hashmap! {
                "srv/echo".into() => sample_tool(),
            },
            ..Default::default()
        };
        let tools = create_tools_json_for_responses_api(&prompt, "gpt-4").unwrap();
        assert_eq!(tools.len(), 2);
        assert!(
            tools
                .iter()
                .any(|t| t.get("name") == Some(&Value::String("shell".into())))
        );
        assert!(
            tools
                .iter()
                .any(|t| t.get("name") == Some(&Value::String("srv/echo".into())))
        );
    }

    #[test]
    fn codex_models_use_local_shell() {
        let prompt = Prompt::default();
        let tools = create_tools_json_for_responses_api(&prompt, "codex-test").unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].get("type").unwrap(), "local_shell");
    }

    #[test]
    fn chat_api_tools_are_wrapped_correctly() {
        let prompt = Prompt {
            extra_tools: hashmap! {
                "srv/echo".into() => sample_tool(),
            },
            ..Default::default()
        };
        let tools = create_tools_json_for_chat_completions_api(&prompt, "gpt-4").unwrap();
        assert_eq!(tools.len(), 2);
        for tool in tools {
            assert_eq!(tool.get("type").unwrap(), "function");
            let func = tool.get("function").unwrap().as_object().unwrap();
            assert!(func.get("name").is_some());
            assert!(func.get("parameters").is_some());
            assert!(!func.contains_key("type"));
        }
    }
}