diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index 816fc80f9b..842794340c 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -462,3 +462,106 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized
 }
 
 impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+    use super::*;
+    use crate::WireApi;
+    use crate::client_common::Prompt;
+    use crate::config::{Config, ConfigOverrides, ConfigToml};
+    use crate::models::{ContentItem, FunctionCallOutputPayload, ResponseItem};
+    use pretty_assertions::assert_eq;
+    use std::sync::{Arc, Mutex};
+    use tempfile::TempDir;
+    use wiremock::matchers::{method, path};
+    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};
+
+    struct CaptureResponder {
+        body: Arc<Mutex<Option<serde_json::Value>>>,
+    }
+
+    impl Respond for CaptureResponder {
+        fn respond(&self, req: &Request) -> ResponseTemplate {
+            let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap();
+            *self.body.lock().unwrap() = Some(v);
+            ResponseTemplate::new(200).insert_header("content-type", "text/event-stream")
+        }
+    }
+
+    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+    async fn assembles_messages_correctly() {
+        let server = MockServer::start().await;
+        let capture = Arc::new(Mutex::new(None));
+        Mock::given(method("POST"))
+            .and(path("/v1/chat/completions"))
+            .respond_with(CaptureResponder {
+                body: capture.clone(),
+            })
+            .mount(&server)
+            .await;
+
+        let provider = ModelProviderInfo {
+            name: "test".into(),
+            base_url: format!("{}/v1", server.uri()),
+            env_key: None,
+            env_key_instructions: None,
+            wire_api: WireApi::Chat,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+        };
+
+        let codex_home = TempDir::new().unwrap();
+        let mut config = Config::load_from_base_config_with_overrides(
+            ConfigToml::default(),
+            ConfigOverrides::default(),
+            codex_home.path().to_path_buf(),
+        )
+        .unwrap();
+        config.model_provider = provider.clone();
+        config.model = "gpt-4".into();
+
+        let client = reqwest::Client::new();
+
+        let prompt = Prompt {
+            input: vec![
+                ResponseItem::Message {
+                    role: "user".into(),
+                    content: vec![ContentItem::InputText { text: "hi".into() }],
+                },
+                ResponseItem::Message {
+                    role: "assistant".into(),
+                    content: vec![ContentItem::OutputText { text: "ok".into() }],
+                },
+                ResponseItem::FunctionCall {
+                    name: "foo".into(),
+                    arguments: "{}".into(),
+                    call_id: "c1".into(),
+                },
+                ResponseItem::FunctionCallOutput {
+                    call_id: "c1".into(),
+                    output: FunctionCallOutputPayload {
+                        content: "out".into(),
+                        success: Some(true),
+                    },
+                },
+            ],
+            ..Default::default()
+        };
+
+        let _ = stream_chat_completions(&prompt, &config.model, &client, &provider)
+            .await
+            .unwrap();
+
+        let body = capture.lock().unwrap().take().unwrap();
+        let messages = body.get("messages").unwrap().as_array().unwrap();
+        assert_eq!(messages[1]["role"], "user");
+        assert_eq!(messages[1]["content"], "hi");
+        assert_eq!(messages[2]["role"], "assistant");
+        assert_eq!(messages[2]["content"], "ok");
+        assert_eq!(messages[3]["tool_calls"][0]["function"]["name"], "foo");
+        assert_eq!(messages[4]["role"], "tool");
+        assert_eq!(messages[4]["tool_call_id"], "c1");
+        assert_eq!(messages[4]["content"], "out");
+    }
+}
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index bd2eeb9457..dfb14219e9 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -391,3 +391,269 @@ async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
     tokio::spawn(process_sse(stream, tx_event));
     Ok(ResponseStream { rx_event })
 }
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used, clippy::print_stdout)]
+    use super::*;
+    use crate::client_common::Prompt;
+    use crate::config::{Config, ConfigOverrides, ConfigToml};
+    use futures::StreamExt;
+    use std::sync::{Arc, Mutex};
+    use std::time::{Duration, Instant};
+    use tempfile::TempDir;
+    use wiremock::matchers::{method, path};
+    use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};
+
+    fn default_config(provider: ModelProviderInfo) -> Arc<Config> {
+        let codex_home = TempDir::new().unwrap();
+        let mut cfg = Config::load_from_base_config_with_overrides(
+            ConfigToml::default(),
+            ConfigOverrides::default(),
+            codex_home.path().to_path_buf(),
+        )
+        .unwrap();
+        cfg.model_provider = provider.clone();
+        cfg.model = "gpt-test".into();
+        Arc::new(cfg)
+    }
+
+    fn sse_completed(id: &str) -> String {
+        format!(
+            "event: response.completed\n\
+            data: {{\"type\":\"response.completed\",\"response\":{{\"id\":\"{id}\",\"output\":[]}}}}\n\n\n"
+        )
+    }
+
+    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+    async fn retries_once_on_server_error() {
+        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
+            println!("Skipping test due to sandbox network restriction");
+            return;
+        }
+        let server = MockServer::start().await;
+        struct SeqResponder;
+        impl Respond for SeqResponder {
+            fn respond(&self, _req: &Request) -> ResponseTemplate {
+                use std::sync::atomic::{AtomicUsize, Ordering};
+                static CALLS: AtomicUsize = AtomicUsize::new(0);
+                let n = CALLS.fetch_add(1, Ordering::SeqCst);
+                if n == 0 {
+                    ResponseTemplate::new(500)
+                } else {
+                    ResponseTemplate::new(200)
+                        .insert_header("content-type", "text/event-stream")
+                        .set_body_raw(sse_completed("ok"), "text/event-stream")
+                }
+            }
+        }
+        Mock::given(method("POST"))
+            .and(path("/v1/responses"))
+            .respond_with(SeqResponder)
+            .expect(2)
+            .mount(&server)
+            .await;
+
+        unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") };
+
+        let provider = ModelProviderInfo {
+            name: "openai".into(),
+            base_url: format!("{}/v1", server.uri()),
+            env_key: Some("PATH".into()),
+            env_key_instructions: None,
+            wire_api: WireApi::Responses,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+        };
+
+        let config = default_config(provider.clone());
+        let client = ModelClient::new(
+            config,
+            provider,
+            ReasoningEffortConfig::None,
+            ReasoningSummaryConfig::None,
+        );
+        let prompt = Prompt::default();
+        let mut stream = client.stream(&prompt).await.unwrap();
+        while let Some(ev) = stream.next().await {
+            if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
+                break;
+            }
+        }
+    }
+
+    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+    async fn retry_after_header_delay() {
+        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
+            println!("Skipping test due to sandbox network restriction");
+            return;
+        }
+        let server = MockServer::start().await;
+        let times = Arc::new(Mutex::new(Vec::new()));
+        struct SeqResponder {
+            times: Arc<Mutex<Vec<Instant>>>,
+        }
+        impl Respond for SeqResponder {
+            fn respond(&self, _req: &Request) -> ResponseTemplate {
+                let mut t = self.times.lock().unwrap();
+                t.push(Instant::now());
+                if t.len() == 1 {
+                    ResponseTemplate::new(429).insert_header("retry-after", "1")
+                } else {
+                    ResponseTemplate::new(200)
+                        .insert_header("content-type", "text/event-stream")
+                        .set_body_raw(sse_completed("ok"), "text/event-stream")
+                }
+            }
+        }
+        Mock::given(method("POST"))
+            .and(path("/v1/responses"))
+            .respond_with(SeqResponder {
+                times: times.clone(),
+            })
+            .expect(2)
+            .mount(&server)
+            .await;
+
+        unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") };
+
+        let provider = ModelProviderInfo {
+            name: "openai".into(),
+            base_url: format!("{}/v1", server.uri()),
+            env_key: Some("PATH".into()),
+            env_key_instructions: None,
+            wire_api: WireApi::Responses,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+        };
+        let config = default_config(provider.clone());
+        let client = ModelClient::new(
+            config,
+            provider,
+            ReasoningEffortConfig::None,
+            ReasoningSummaryConfig::None,
+        );
+        let prompt = Prompt::default();
+        let mut stream = client.stream(&prompt).await.unwrap();
+        while let Some(ev) = stream.next().await {
+            if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
+                break;
+            }
+        }
+        let times = times.lock().unwrap();
+        assert!(times.len() == 2);
+        assert!(times[1] - times[0] >= Duration::from_secs(1));
+    }
+
+    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+    async fn retry_backoff_no_header() {
+        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
+            println!("Skipping test due to sandbox network restriction");
+            return;
+        }
+        let server = MockServer::start().await;
+        let times = Arc::new(Mutex::new(Vec::new()));
+        struct SeqResponder {
+            times: Arc<Mutex<Vec<Instant>>>,
+        }
+        impl Respond for SeqResponder {
+            fn respond(&self, _req: &Request) -> ResponseTemplate {
+                let mut t = self.times.lock().unwrap();
+                t.push(Instant::now());
+                if t.len() == 1 {
+                    ResponseTemplate::new(429)
+                } else {
+                    ResponseTemplate::new(200)
+                        .insert_header("content-type", "text/event-stream")
+                        .set_body_raw(sse_completed("ok"), "text/event-stream")
+                }
+            }
+        }
+        Mock::given(method("POST"))
+            .and(path("/v1/responses"))
+            .respond_with(SeqResponder {
+                times: times.clone(),
+            })
+            .expect(2)
+            .mount(&server)
+            .await;
+
+        unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") };
+
+        let provider = ModelProviderInfo {
+            name: "openai".into(),
+            base_url: format!("{}/v1", server.uri()),
+            env_key: Some("PATH".into()),
+            env_key_instructions: None,
+            wire_api: WireApi::Responses,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+        };
+        let config = default_config(provider.clone());
+        let client = ModelClient::new(
+            config,
+            provider,
+            ReasoningEffortConfig::None,
+            ReasoningSummaryConfig::None,
+        );
+        let prompt = Prompt::default();
+        let mut stream = client.stream(&prompt).await.unwrap();
+        while let Some(ev) = stream.next().await {
+            if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
+                break;
+            }
+        }
+        let times = times.lock().unwrap();
+        assert!(times.len() == 2);
+        assert!(times[1] - times[0] >= Duration::from_millis(100));
+    }
+
+    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+    async fn permanent_error_bubbles_body() {
+        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
+            println!("Skipping test due to sandbox network restriction");
+            return;
+        }
+        let server = MockServer::start().await;
+        Mock::given(method("POST"))
+            .and(path("/v1/responses"))
+            .respond_with(ResponseTemplate::new(400).set_body_string("bad"))
+            .expect(1)
+            .mount(&server)
+            .await;
+
+        unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0") };
+
+        let provider = ModelProviderInfo {
+            name: "openai".into(),
+            base_url: format!("{}/v1", server.uri()),
+            env_key: Some("PATH".into()),
+            env_key_instructions: None,
+            wire_api: WireApi::Responses,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+        };
+        let config = default_config(provider.clone());
+        let client = ModelClient::new(
+            config,
+            provider,
+            ReasoningEffortConfig::None,
+            ReasoningSummaryConfig::None,
+        );
+        let prompt = Prompt::default();
+        let res = client.stream(&prompt).await;
+        match res {
+            Ok(_) => panic!("expected error"),
+            Err(err) => match err {
+                CodexErr::UnexpectedStatus(code, body) => {
+                    assert_eq!(code, StatusCode::BAD_REQUEST);
+                    assert_eq!(body, "bad");
+                }
+                other => panic!("unexpected error: {other:?}"),
+            },
+        }
+    }
+}
diff --git a/codex-rs/core/src/openai_tools.rs b/codex-rs/core/src/openai_tools.rs
index ef12a629b6..8215335c0d 100644
--- a/codex-rs/core/src/openai_tools.rs
+++ b/codex-rs/core/src/openai_tools.rs
@@ -155,3 +155,71 @@ fn mcp_tool_to_openai_tool(
         "type": "function",
     })
 }
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+    use super::*;
+    use crate::client_common::Prompt;
+    use mcp_types::{Tool, ToolInputSchema};
+
+    fn dummy_tool() -> (String, Tool) {
+        (
+            "srv.dummy".to_string(),
+            Tool {
+                annotations: None,
+                description: Some("dummy".into()),
+                input_schema: ToolInputSchema {
+                    properties: None,
+                    required: None,
+                    r#type: "object".to_string(),
+                },
+                name: "dummy".into(),
+            },
+        )
+    }
+
+    #[test]
+    fn responses_includes_default_and_extra() {
+        let mut prompt = Prompt::default();
+        let (name, tool) = dummy_tool();
+        prompt.extra_tools.insert(name.clone(), tool);
+
+        let tools = create_tools_json_for_responses_api(&prompt, "gpt-4").unwrap();
+        assert_eq!(tools.len(), 2);
+        assert_eq!(tools[0]["type"], "function");
+        assert_eq!(tools[0]["name"], "shell");
+        assert!(
+            tools
+                .iter()
+                .any(|t| t.get("name") == Some(&name.clone().into()))
+        );
+    }
+
+    #[test]
+    fn responses_codex_model_uses_local_shell() {
+        let mut prompt = Prompt::default();
+        let (name, tool) = dummy_tool();
+        prompt.extra_tools.insert(name, tool);
+
+        let tools = create_tools_json_for_responses_api(&prompt, "codex-model").unwrap();
+        assert_eq!(tools[0]["type"], "local_shell");
+    }
+
+    #[test]
+    fn chat_completions_tool_format() {
+        let mut prompt = Prompt::default();
+        let (name, tool) = dummy_tool();
+        prompt.extra_tools.insert(name.clone(), tool);
+
+        let tools = create_tools_json_for_chat_completions_api(&prompt, "gpt-4").unwrap();
+        assert_eq!(tools.len(), 2);
+        for t in tools {
+            assert_eq!(
+                t.get("type"),
+                Some(&serde_json::Value::String("function".into()))
+            );
+            let inner = t.get("function").and_then(|v| v.as_object()).unwrap();
+            assert!(!inner.contains_key("type"));
+        }
+    }
+}