Mirror of https://github.com/openai/codex.git, synced 2026-02-02 23:13:37 +00:00

Compare commits: alpha-cli...codex/impl (16 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 874a6a951c | |
| | b4b255eda5 | |
| | 514b3cd8a8 | |
| | b66180a20c | |
| | 86dcd0b4ad | |
| | 3330466ea3 | |
| | 75a1e4b768 | |
| | e4f6b76eca | |
| | 86be2a6b58 | |
| | aeb12fc569 | |
| | 39f88bc2f8 | |
| | 650e08fe49 | |
| | 7d316c9f1e | |
| | 3d85eabff9 | |
| | 3168d214e1 | |
| | bdc60ef6c7 | |
@@ -21,7 +21,6 @@ use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::models::ContentItem;
use crate::models::ResponseItem;

@@ -34,6 +33,7 @@ pub(crate) async fn stream_chat_completions(
    model: &str,
    client: &reqwest::Client,
    provider: &ModelProviderInfo,
    max_retries: u64,
) -> Result<ResponseStream> {
    // Build messages array
    let mut messages = Vec::<serde_json::Value>::new();

@@ -146,7 +146,7 @@ pub(crate) async fn stream_chat_completions(
                return Err(CodexErr::UnexpectedStatus(status, body));
            }

            if attempt > *OPENAI_REQUEST_MAX_RETRIES {
            if attempt > max_retries {
                return Err(CodexErr::RetryLimit(status));
            }

@@ -162,7 +162,7 @@ pub(crate) async fn stream_chat_completions(
                tokio::time::sleep(delay).await;
            }
            Err(e) => {
                if attempt > *OPENAI_REQUEST_MAX_RETRIES {
                if attempt > max_retries {
                    return Err(e.into());
                }
                let delay = backoff(attempt);
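Taken together, these hunks drop the global `*OPENAI_REQUEST_MAX_RETRIES` flag inside `stream_chat_completions` and thread a `max_retries` argument in from the caller instead. A minimal sketch of the retry shape the call sites converge on, under simplified types: `backoff` here is an illustrative stand-in for the crate's helper, the error type stands in for `CodexErr`, and the real loop also honors a `Retry-After` header, as the tests later in this diff exercise.

```rust
use std::time::Duration;

// Illustrative stand-in for the crate's backoff helper: exponential-ish delay.
fn backoff(attempt: u64) -> Duration {
    Duration::from_millis(100u64.saturating_mul(1u64 << attempt.min(6)))
}

// Sketch of the per-request retry loop; not the actual codex implementation.
async fn send_with_retries(
    client: &reqwest::Client,
    url: &str,
    max_retries: u64,
) -> Result<reqwest::Response, Box<dyn std::error::Error>> {
    let mut attempt = 0;
    loop {
        attempt += 1;
        match client.post(url).send().await {
            Ok(resp) if resp.status().is_success() => return Ok(resp),
            Ok(resp) => {
                // Non-2xx: give up once the configured budget is spent
                // (the real code maps this to CodexErr::RetryLimit).
                if attempt > max_retries {
                    return Err(format!("retry limit reached: {}", resp.status()).into());
                }
                tokio::time::sleep(backoff(attempt)).await;
            }
            Err(e) => {
                // Transport error: same budget check, then back off and retry.
                if attempt > max_retries {
                    return Err(e.into());
                }
                tokio::time::sleep(backoff(attempt)).await;
            }
        }
    }
}
```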
@@ -462,3 +462,134 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized
}

impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::WireApi;
    use crate::client_common::Prompt;
    use crate::config::Config;
    use crate::config::ConfigOverrides;
    use crate::config::ConfigToml;
    use crate::models::ContentItem;
    use crate::models::FunctionCallOutputPayload;
    use crate::models::ResponseItem;
    use pretty_assertions::assert_eq;
    use std::sync::Arc;
    use std::sync::Mutex;
    use tempfile::TempDir;
    use wiremock::Mock;
    use wiremock::MockServer;
    use wiremock::Request;
    use wiremock::Respond;
    use wiremock::ResponseTemplate;
    use wiremock::matchers::method;
    use wiremock::matchers::path;

    struct CaptureResponder {
        body: Arc<Mutex<Option<serde_json::Value>>>,
    }

    impl Respond for CaptureResponder {
        fn respond(&self, req: &Request) -> ResponseTemplate {
            let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap();
            *self.body.lock().unwrap() = Some(v);
            ResponseTemplate::new(200).insert_header("content-type", "text/event-stream")
        }
    }

    /// Validate that `stream_chat_completions` converts our internal `Prompt` into the exact
    /// Chat Completions JSON payload expected by OpenAI. We build a prompt containing user and
    /// assistant turns, a function call and its output, issue the request against a
    /// `wiremock::MockServer`, capture the JSON body, and assert that the full `messages` array
    /// matches a golden value. The test is a pure unit-test; it is skipped automatically when
    /// the sandbox disables networking.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn assembles_messages_correctly() {
        // Skip when sandbox networking is disabled (e.g. on CI).
        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            return;
        }
        let server = MockServer::start().await;
        let capture = Arc::new(Mutex::new(None));
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(CaptureResponder {
                body: capture.clone(),
            })
            .mount(&server)
            .await;

        let provider = ModelProviderInfo {
            name: "test".into(),
            base_url: format!("{}/v1", server.uri()),
            env_key: None,
            env_key_instructions: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
        };

        let codex_home = TempDir::new().unwrap();
        let mut config = Config::load_from_base_config_with_overrides(
            ConfigToml::default(),
            ConfigOverrides::default(),
            codex_home.path().to_path_buf(),
        )
        .unwrap();
        config.model_provider = provider.clone();
        config.model = "gpt-4".into();

        let client = reqwest::Client::new();

        let prompt = Prompt {
            input: vec![
                ResponseItem::Message {
                    role: "user".into(),
                    content: vec![ContentItem::InputText { text: "hi".into() }],
                },
                ResponseItem::Message {
                    role: "assistant".into(),
                    content: vec![ContentItem::OutputText { text: "ok".into() }],
                },
                ResponseItem::FunctionCall {
                    name: "foo".into(),
                    arguments: "{}".into(),
                    call_id: "c1".into(),
                },
                ResponseItem::FunctionCallOutput {
                    call_id: "c1".into(),
                    output: FunctionCallOutputPayload {
                        content: "out".into(),
                        success: Some(true),
                    },
                },
            ],
            ..Default::default()
        };

        let _ = stream_chat_completions(
            &prompt,
            &config.model,
            &client,
            &provider,
            config.openai_request_max_retries,
        )
        .await
        .unwrap();

        let body = capture.lock().unwrap().take().unwrap();
        let messages = body.get("messages").unwrap();

        let expected = serde_json::json!([
            {"role":"system","content":prompt.get_full_instructions(&config.model)},
            {"role":"user","content":"hi"},
            {"role":"assistant","content":"ok"},
            {"role":"assistant", "content": null, "tool_calls":[{"id":"c1","type":"function","function":{"name":"foo","arguments":"{}"}}]},
            {"role":"tool","tool_call_id":"c1","content":"out"}
        ]);

        assert_eq!(messages, &expected);
    }
}

@@ -29,7 +29,6 @@ use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;

@@ -77,6 +76,7 @@ impl ModelClient {
                &self.config.model,
                &self.client,
                &self.provider,
                self.config.openai_request_max_retries,
            )
            .await?;

@@ -135,6 +135,7 @@ impl ModelClient {
        );

        let mut attempt = 0;
        let max_retries = self.config.openai_request_max_retries;
        loop {
            attempt += 1;

@@ -171,7 +172,7 @@ impl ModelClient {
                    return Err(CodexErr::UnexpectedStatus(status, body));
                }

                if attempt > *OPENAI_REQUEST_MAX_RETRIES {
                if attempt > max_retries {
                    return Err(CodexErr::RetryLimit(status));
                }

@@ -188,7 +189,7 @@ impl ModelClient {
                    tokio::time::sleep(delay).await;
                }
                Err(e) => {
                    if attempt > *OPENAI_REQUEST_MAX_RETRIES {
                    if attempt > max_retries {
                        return Err(e.into());
                    }
                    let delay = backoff(attempt);

@@ -315,7 +316,7 @@ where
        // duplicated `output` array embedded in the `response.completed`
        // payload. That produced two concrete issues:
        //   1. No real‑time streaming – the user only saw output after the
        //      entire turn had finished, which broke the “typing” UX and
        //      entire turn had finished, which broke the "typing" UX and
        //      made long‑running turns look stalled.
        //   2. Duplicate `function_call_output` items – both the
        //      individual *and* the completed array were forwarded, which
@@ -394,17 +395,76 @@ async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {

#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]
    #![allow(clippy::unwrap_used, clippy::print_stdout, clippy::expect_used)]

    use super::*;
    use crate::client_common::Prompt;
    use crate::config::Config;
    use crate::config::ConfigOverrides;
    use crate::config::ConfigToml;
    use crate::config_types::ReasoningEffort as ReasoningEffortConfig;
    use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
    use futures::StreamExt;
    use reqwest::StatusCode;
    use serde_json::json;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::time::Duration;
    use std::time::Instant;
    use tempfile::TempDir;
    use tokio::sync::mpsc;
    use tokio_test::io::Builder as IoBuilder;
    use tokio_util::io::ReaderStream;
    use wiremock::Mock;
    use wiremock::MockServer;
    use wiremock::Request;
    use wiremock::Respond;
    use wiremock::ResponseTemplate;
    use wiremock::matchers::method;
    use wiremock::matchers::path;

    // ────────────────────────────
    // Helpers
    // ────────────────────────────
    // ─────────────────────────── Helpers ───────────────────────────

    fn default_config(provider: ModelProviderInfo, max_retries: u64) -> Arc<Config> {
        let codex_home = TempDir::new().unwrap();
        let mut cfg = Config::load_from_base_config_with_overrides(
            ConfigToml::default(),
            ConfigOverrides::default(),
            codex_home.path().to_path_buf(),
        )
        .unwrap();
        cfg.model_provider = provider.clone();
        cfg.model = "gpt-test".into();
        cfg.openai_request_max_retries = max_retries;
        Arc::new(cfg)
    }

    fn create_test_client(server: &MockServer, max_retries: u64) -> ModelClient {
        let provider = ModelProviderInfo {
            name: "openai".into(),
            base_url: format!("{}/v1", server.uri()),
            env_key: Some("PATH".into()),
            env_key_instructions: None,
            wire_api: WireApi::Responses,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
        };
        let config = default_config(provider.clone(), max_retries);
        ModelClient::new(
            config,
            provider,
            ReasoningEffortConfig::None,
            ReasoningSummaryConfig::None,
        )
    }

    fn sse_completed(id: &str) -> String {
        format!(
            "event: response.completed\n\
             data: {{\"type\":\"response.completed\",\"response\":{{\"id\":\"{id}\",\"output\":[]}}}}\n\n\n"
        )
    }

    /// Runs the SSE parser on pre-chunked byte slices and returns every event
    /// (including any final `Err` from a stream-closure check).
@@ -453,9 +513,172 @@ mod tests {
        out
    }

    // ────────────────────────────
    // Tests from `implement-test-for-responses-api-sse-parser`
    // ────────────────────────────
    // ───────────── Retry / back-off behaviour tests ─────────────

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn retries_once_on_server_error() {
        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            return;
        }
        let server = MockServer::start().await;

        struct SeqResponder;
        impl Respond for SeqResponder {
            fn respond(&self, _req: &Request) -> ResponseTemplate {
                use std::sync::atomic::AtomicUsize;
                use std::sync::atomic::Ordering;
                static CALLS: AtomicUsize = AtomicUsize::new(0);
                let n = CALLS.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    ResponseTemplate::new(500)
                } else {
                    ResponseTemplate::new(200)
                        .insert_header("content-type", "text/event-stream")
                        .set_body_raw(sse_completed("ok"), "text/event-stream")
                }
            }
        }

        Mock::given(method("POST"))
            .and(path("/v1/responses"))
            .respond_with(SeqResponder)
            .expect(2)
            .mount(&server)
            .await;

        let client = create_test_client(&server, 1);
        let prompt = Prompt::default();
        let mut stream = client.stream(&prompt).await.unwrap();
        while let Some(ev) = stream.next().await {
            if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
                break;
            }
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn retry_after_header_delay() {
        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            return;
        }
        let server = MockServer::start().await;
        let times = Arc::new(Mutex::new(Vec::new()));

        struct SeqResponder {
            times: Arc<Mutex<Vec<Instant>>>,
        }
        impl Respond for SeqResponder {
            fn respond(&self, _req: &Request) -> ResponseTemplate {
                let mut t = self.times.lock().unwrap();
                t.push(Instant::now());
                if t.len() == 1 {
                    ResponseTemplate::new(429).insert_header("retry-after", "1")
                } else {
                    ResponseTemplate::new(200)
                        .insert_header("content-type", "text/event-stream")
                        .set_body_raw(sse_completed("ok"), "text/event-stream")
                }
            }
        }

        Mock::given(method("POST"))
            .and(path("/v1/responses"))
            .respond_with(SeqResponder {
                times: times.clone(),
            })
            .expect(2)
            .mount(&server)
            .await;

        let client = create_test_client(&server, 1);
        let prompt = Prompt::default();
        let mut stream = client.stream(&prompt).await.unwrap();
        while let Some(ev) = stream.next().await {
            if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
                break;
            }
        }
        let times = times.lock().unwrap();
        assert_eq!(times.len(), 2);
        assert!(times[1] - times[0] >= Duration::from_secs(1));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn retry_backoff_no_header() {
        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            return;
        }
        let server = MockServer::start().await;
        let times = Arc::new(Mutex::new(Vec::new()));

        struct SeqResponder {
            times: Arc<Mutex<Vec<Instant>>>,
        }
        impl Respond for SeqResponder {
            fn respond(&self, _req: &Request) -> ResponseTemplate {
                let mut t = self.times.lock().unwrap();
                t.push(Instant::now());
                if t.len() == 1 {
                    ResponseTemplate::new(429)
                } else {
                    ResponseTemplate::new(200)
                        .insert_header("content-type", "text/event-stream")
                        .set_body_raw(sse_completed("ok"), "text/event-stream")
                }
            }
        }

        Mock::given(method("POST"))
            .and(path("/v1/responses"))
            .respond_with(SeqResponder {
                times: times.clone(),
            })
            .expect(2)
            .mount(&server)
            .await;

        let client = create_test_client(&server, 1);
        let prompt = Prompt::default();
        let mut stream = client.stream(&prompt).await.unwrap();
        while let Some(ev) = stream.next().await {
            if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) {
                break;
            }
        }
        let times = times.lock().unwrap();
        assert_eq!(times.len(), 2);
        assert!(times[1] - times[0] >= Duration::from_millis(100));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn permanent_error_bubbles_body() {
        if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            return;
        }
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/v1/responses"))
            .respond_with(ResponseTemplate::new(400).set_body_string("bad"))
            .expect(1)
            .mount(&server)
            .await;

        let client = create_test_client(&server, 0);
        let prompt = Prompt::default();
        match client.stream(&prompt).await {
            Ok(_) => panic!("expected error"),
            Err(CodexErr::UnexpectedStatus(code, body)) => {
                assert_eq!(code, StatusCode::BAD_REQUEST);
                assert_eq!(body, "bad");
            }
            Err(other) => panic!("unexpected error: {other:?}"),
        }
    }

    // ───────────────────────────
    // SSE-parser tests
    // ───────────────────────────

    #[tokio::test]
    async fn parses_items_and_completed() {

@@ -493,17 +716,17 @@ mod tests {

        assert_eq!(events.len(), 3);

        matches!(
        assert!(matches!(
            &events[0],
            Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
            if role == "assistant"
        );
        ));

        matches!(
        assert!(matches!(
            &events[1],
            Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
            if role == "assistant"
        );
        ));

        match &events[2] {
            Ok(ResponseEvent::Completed {

@@ -535,7 +758,7 @@ mod tests {

        assert_eq!(events.len(), 2);

        matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)));
        assert!(matches!(events[0], Ok(ResponseEvent::OutputItemDone(_))));

        match &events[1] {
            Err(CodexErr::Stream(msg)) => {

@@ -545,12 +768,10 @@ mod tests {
        }
    }

    // ────────────────────────────
    // Table-driven test from `main`
    // ────────────────────────────
    // ───────────────────────────
    // Table-driven event-kind test
    // ───────────────────────────

    /// Verifies that the adapter produces the right `ResponseEvent` for a
    /// variety of incoming `type` values.
    #[tokio::test]
    async fn table_driven_event_kinds() {
        struct TestCase {

@@ -10,6 +10,7 @@ use crate::config_types::ShellEnvironmentPolicyToml;
use crate::config_types::Tui;
use crate::config_types::UriBasedFileOpener;
use crate::flags::OPENAI_DEFAULT_MODEL;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::built_in_model_providers;
use crate::openai_model_info::get_model_info;

@@ -137,6 +138,9 @@ pub struct Config {

    /// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
    pub chatgpt_base_url: String,

    /// Max number of retries for a request to the model.
    pub openai_request_max_retries: u64,
}

impl Config {

@@ -321,6 +325,9 @@ pub struct ConfigToml {

    /// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
    pub chatgpt_base_url: Option<String>,

    /// Max number of retries for a request to the model.
    pub openai_request_max_retries: Option<u64>,
}

impl ConfigToml {

@@ -353,6 +360,7 @@ pub struct ConfigOverrides {
    pub model_provider: Option<String>,
    pub config_profile: Option<String>,
    pub codex_linux_sandbox_exe: Option<PathBuf>,
    pub openai_request_max_retries: Option<u64>,
}

impl Config {

@@ -374,6 +382,7 @@ impl Config {
            model_provider,
            config_profile: config_profile_key,
            codex_linux_sandbox_exe,
            openai_request_max_retries,
        } = overrides;

        let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) {

@@ -448,6 +457,12 @@ impl Config {
                .as_ref()
                .map(|info| info.max_output_tokens)
        });

        // Resolve the max-retry setting (CLI override > config.toml > env flag default).
        let resolved_openai_request_max_retries = openai_request_max_retries
            .or(cfg.openai_request_max_retries)
            .unwrap_or_else(|| *OPENAI_REQUEST_MAX_RETRIES);

        let config = Self {
            model,
            model_context_window,

@@ -494,6 +509,8 @@ impl Config {
                .chatgpt_base_url
                .or(cfg.chatgpt_base_url)
                .unwrap_or("https://chatgpt.com/backend-api/".to_string()),

            openai_request_max_retries: resolved_openai_request_max_retries,
        };
        Ok(config)
    }

@@ -559,6 +576,7 @@ pub fn log_dir(cfg: &Config) -> std::io::Result<PathBuf> {
mod tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]
    use crate::config_types::HistoryPersistence;
    use crate::flags::OPENAI_REQUEST_MAX_RETRIES;

    use super::*;
    use pretty_assertions::assert_eq;

@@ -800,6 +818,7 @@ disable_response_storage = true
            model_reasoning_summary: ReasoningSummary::Detailed,
            model_supports_reasoning_summaries: false,
            chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
            openai_request_max_retries: *OPENAI_REQUEST_MAX_RETRIES,
        },
        o3_profile_config
    );

@@ -846,6 +865,7 @@ disable_response_storage = true
        model_reasoning_summary: ReasoningSummary::default(),
        model_supports_reasoning_summaries: false,
        chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
        openai_request_max_retries: *OPENAI_REQUEST_MAX_RETRIES,
    };

    assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);

@@ -907,6 +927,7 @@ disable_response_storage = true
        model_reasoning_summary: ReasoningSummary::default(),
        model_supports_reasoning_summaries: false,
        chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
        openai_request_max_retries: *OPENAI_REQUEST_MAX_RETRIES,
    };

    assert_eq!(expected_zdr_profile_config, zdr_profile_config);

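The practical upshot of the config changes above: the retry budget is now an ordinary config value, resolved as typed override, then `config.toml`, then the `OPENAI_REQUEST_MAX_RETRIES` env-flag default. A minimal sketch of the corresponding `config.toml` entry, assuming the key name matches the `ConfigToml` field added above:

```toml
# Sketch of a config.toml entry under CODEX_HOME; other settings omitted.
# Max number of retries for a request to the model.
openai_request_max_retries = 3
```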
@@ -155,3 +155,89 @@ fn mcp_tool_to_openai_tool(
        "type": "function",
    })
}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::client_common::Prompt;
    use mcp_types::Tool;
    use mcp_types::ToolInputSchema;

    fn dummy_tool() -> (String, Tool) {
        (
            "srv.dummy".to_string(),
            Tool {
                annotations: None,
                description: Some("dummy".into()),
                input_schema: ToolInputSchema {
                    properties: None,
                    required: None,
                    r#type: "object".to_string(),
                },
                name: "dummy".into(),
            },
        )
    }

    /// Ensure that the default `shell` tool plus any prompt-supplied extra tool are encoded
    /// correctly for the Responses API. We compare against a golden JSON value rather than
    /// asserting individual fields so that future refactors will intentionally update the test.
    #[test]
    fn responses_includes_default_and_extra() {
        let mut prompt = Prompt::default();
        let (name, tool) = dummy_tool();
        prompt.extra_tools.insert(name.clone(), tool);

        let tools = create_tools_json_for_responses_api(&prompt, "gpt-4").unwrap();

        // Verify presence & order: builtin `shell` first, then our extra tool.
        assert_eq!(
            tools[0].get("name"),
            Some(&serde_json::Value::String("shell".into()))
        );

        let dummy = tools
            .iter()
            .find(|t| t.get("name") == Some(&serde_json::Value::String(name.clone())))
            .unwrap_or_else(|| panic!("dummy tool not found in tools list"));

        // The dummy tool should match what `mcp_tool_to_openai_tool` produces.
        let expected_dummy =
            mcp_tool_to_openai_tool(name, prompt.extra_tools.remove("srv.dummy").unwrap());
        assert_eq!(dummy, &expected_dummy);
    }

    /// When the model name starts with `codex-`, the built-in shell tool should be encoded
    /// as `local_shell` rather than `shell`. Verify that the first tool in the JSON list has
    /// the adjusted type in that scenario.
    #[test]
    fn responses_codex_model_uses_local_shell() {
        let mut prompt = Prompt::default();
        let (name, tool) = dummy_tool();
        prompt.extra_tools.insert(name, tool);

        let tools = create_tools_json_for_responses_api(&prompt, "codex-model").unwrap();
        assert_eq!(tools[0]["type"], "local_shell");
    }

    /// Chat-Completions API expects the V2 tool schema (`{"type":"function","function":{..}}`).
    /// Confirm that every entry is shaped accordingly and the wrapper does not leak the internal
    /// `type` field inside the nested `function` object.
    #[test]
    fn chat_completions_tool_format() {
        let mut prompt = Prompt::default();
        let (name, tool) = dummy_tool();
        prompt.extra_tools.insert(name.clone(), tool);

        let tools = create_tools_json_for_chat_completions_api(&prompt, "gpt-4").unwrap();
        assert_eq!(tools.len(), 2);
        for t in tools {
            assert_eq!(
                t.get("type"),
                Some(&serde_json::Value::String("function".into()))
            );
            let inner = t.get("function").and_then(|v| v.as_object()).unwrap();
            assert!(!inner.contains_key("type"));
        }
    }
}

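For reference, a sketch of the two tool encodings the tests above distinguish. Only the shape and the placement of `name`/`type` are taken from the assertions; the remaining fields are illustrative:

```rust
use serde_json::json;

fn main() {
    // Responses API: a flat entry with `name` and `type` at the top level
    // (with `local_shell` swapped in for the builtin tool on `codex-` models).
    let responses_entry = json!({
        "name": "srv.dummy",
        "type": "function",
        "description": "dummy",
        "parameters": { "type": "object" }
    });

    // Chat Completions API (V2 schema): the same tool wrapped as
    // {"type":"function","function":{..}}, with `type` not repeated inside
    // the nested `function` object.
    let chat_entry = json!({
        "type": "function",
        "function": {
            "name": "srv.dummy",
            "description": "dummy",
            "parameters": { "type": "object" }
        }
    });

    println!("{responses_entry}\n{chat_entry}");
}
```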
@@ -66,13 +66,10 @@ async fn chat_mode_stream_cli() {
        .env("OPENAI_BASE_URL", format!("{}/v1", server.uri()));

    let output = cmd.output().unwrap();
    println!("Status: {}", output.status);
    println!("Stdout:\n{}", String::from_utf8_lossy(&output.stdout));
    println!("Stderr:\n{}", String::from_utf8_lossy(&output.stderr));
    assert!(output.status.success());
    let stdout = String::from_utf8_lossy(&output.stdout);
    assert!(stdout.contains("hi"));
    assert_eq!(stdout.matches("hi").count(), 1);
    let hi_lines = stdout.lines().filter(|line| line.trim() == "hi").count();
    assert_eq!(hi_lines, 1, "Expected exactly one line with 'hi'");

    server.verify().await;
}

@@ -55,12 +55,13 @@ async fn spawn_codex() -> Result<Codex, CodexErr> {
    // beginning of the test, before we spawn any background tasks that could
    // observe the environment.
    unsafe {
        std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "2");
        std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "2");
    }

    let codex_home = TempDir::new().unwrap();
    let config = load_default_config_for_test(&codex_home);
    let mut config = load_default_config_for_test(&codex_home);
    // Live tests keep retries low to avoid slow backoffs on flaky networks.
    config.openai_request_max_retries = 2;
    let (agent, _init_id) = Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?;

    Ok(agent)

@@ -79,7 +80,7 @@ async fn live_streaming_and_prev_id_reset() {

    let codex = spawn_codex().await.unwrap();

    // ---------- Task 1 ----------
    // ---------- Task 1 ----------
    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {

@@ -113,7 +114,7 @@ async fn live_streaming_and_prev_id_reset() {
        "Agent did not stream any AgentMessage before TaskComplete"
    );

    // ---------- Task 2 (same session) ----------
    // ---------- Task 2 (same session) ----------
    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {

@@ -91,8 +91,8 @@ async fn keeps_previous_response_id_between_tasks() {
    // Environment
    // Update environment – `set_var` is `unsafe` starting with the 2024
    // edition so we group the calls into a single `unsafe { … }` block.
    // NOTE: per-request retry count is now configured directly on the Config.
    unsafe {
        std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
        std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0");
    }
    let model_provider = ModelProviderInfo {

@@ -113,6 +113,8 @@ async fn keeps_previous_response_id_between_tasks() {
    let codex_home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&codex_home);
    config.model_provider = model_provider;
    // No per-request retries so each new user input triggers exactly one HTTP request.
    config.openai_request_max_retries = 0;
    let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
    let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();

@@ -74,12 +74,11 @@ async fn retries_on_early_close() {
    //
    // As of Rust 2024 `std::env::set_var` has been made `unsafe` because
    // mutating the process environment is inherently racy when other threads
    // are running. We therefore have to wrap every call in an explicit
    // `unsafe` block. These are limited to the test-setup section so the
    // scope is very small and clearly delineated.
    // are running. We used to tweak the per-request retry counts via the
    // `OPENAI_REQUEST_MAX_RETRIES` env var but that caused data races in
    // multi-threaded tests. Configure the value directly on the Config instead.

    unsafe {
        std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
        std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "1");
        std::env::set_var("OPENAI_STREAM_IDLE_TIMEOUT_MS", "2000");
    }

@@ -102,6 +101,8 @@ async fn retries_on_early_close() {
    let codex_home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&codex_home);
    config.model_provider = model_provider;
    // Disable per-request retries (we want to exercise stream-level retries).
    config.openai_request_max_retries = 0;
    let (codex, _init_id) = Codex::spawn(config, ctrl_c).await.unwrap();

    codex

@@ -104,6 +104,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
        cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
        model_provider: None,
        codex_linux_sandbox_exe,
        openai_request_max_retries: None,
    };
    // Parse `-c` overrides.
    let cli_kv_overrides = match config_overrides.parse_overrides() {

@@ -142,6 +142,7 @@ impl CodexToolCallParam {
            sandbox_mode: sandbox.map(Into::into),
            model_provider: None,
            codex_linux_sandbox_exe,
            openai_request_max_retries: None,
        };

        let cli_overrides = cli_overrides

@@ -75,6 +75,7 @@ pub fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> std::io::
        model_provider: None,
        config_profile: cli.config_profile.clone(),
        codex_linux_sandbox_exe,
        openai_request_max_retries: None,
    };
    // Parse `-c` overrides from the CLI.
    let cli_kv_overrides = match cli.config_overrides.parse_overrides() {