diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 8ec68d02e8..ae7904b8ff 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -15,6 +15,7 @@ use tokio_util::io::ReaderStream; use tracing::debug; use tracing::trace; use tracing::warn; +use uuid::Uuid; use crate::chat_completions::AggregateStreamExt; use crate::chat_completions::stream_chat_completions; @@ -44,6 +45,7 @@ pub struct ModelClient { config: Arc, client: reqwest::Client, provider: ModelProviderInfo, + session_id: Uuid, effort: ReasoningEffortConfig, summary: ReasoningSummaryConfig, } @@ -54,11 +56,13 @@ impl ModelClient { provider: ModelProviderInfo, effort: ReasoningEffortConfig, summary: ReasoningSummaryConfig, + session_id: Uuid, ) -> Self { Self { config, client: reqwest::Client::new(), provider, + session_id, effort, summary, } @@ -143,6 +147,7 @@ impl ModelClient { .provider .create_request_builder(&self.client)? .header("OpenAI-Beta", "responses=experimental") + .header("session_id", self.session_id.to_string()) .header(reqwest::header::ACCEPT, "text/event-stream") .json(&payload); diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index d4e73b2ebf..246198c006 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -591,6 +591,7 @@ async fn submission_loop( provider.clone(), model_reasoning_effort, model_reasoning_summary, + session_id, ); // abort any current running session and clone its state diff --git a/codex-rs/core/tests/client.rs b/codex-rs/core/tests/client.rs new file mode 100644 index 0000000000..f4fb58f5a4 --- /dev/null +++ b/codex-rs/core/tests/client.rs @@ -0,0 +1,117 @@ +use std::time::Duration; + +use codex_core::Codex; +use codex_core::ModelProviderInfo; +use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use codex_core::protocol::SessionConfiguredEvent; +mod test_support; +use tempfile::TempDir; +use test_support::load_default_config_for_test; +use test_support::load_sse_fixture_with_id; +use tokio::time::timeout; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::ResponseTemplate; +use wiremock::matchers::method; +use wiremock::matchers::path; + +/// Build minimal SSE stream with completed marker using the JSON fixture. +fn sse_completed(id: &str) -> String { + load_sse_fixture_with_id("tests/fixtures/completed_template.json", id) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn includes_session_id_and_model_headers_in_request() { + #![allow(clippy::unwrap_used)] + + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + // Mock server + let server = MockServer::start().await; + + // First request – must NOT include `previous_response_id`. + let first = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("resp1"), "text/event-stream"); + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(first) + .expect(1) + .mount(&server) + .await; + + // Environment + // Update environment – `set_var` is `unsafe` starting with the 2024 + // edition so we group the calls into a single `unsafe { … }` block. + unsafe { + std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0"); + std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0"); + } + let model_provider = ModelProviderInfo { + name: "openai".into(), + base_url: format!("{}/v1", server.uri()), + // Environment variable that should exist in the test environment. + // ModelClient will return an error if the environment variable for the + // provider is not set. + env_key: Some("PATH".into()), + env_key_instructions: None, + wire_api: codex_core::WireApi::Responses, + query_params: None, + http_headers: Some( + [("originator".to_string(), "codex_cli_rs".to_string())] + .into_iter() + .collect(), + ), + env_http_headers: None, + }; + + // Init session + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + config.model_provider = model_provider; + let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); + let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap(); + + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "hello".into(), + }], + }) + .await + .unwrap(); + + let mut current_session_id = None; + // Wait for TaskComplete + loop { + let ev = timeout(Duration::from_secs(1), codex.next_event()) + .await + .unwrap() + .unwrap(); + + if let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) = ev.msg { + current_session_id = Some(session_id.to_string()); + } + if matches!(ev.msg, EventMsg::TaskComplete(_)) { + break; + } + } + + // get request from the server + let request = &server.received_requests().await.unwrap()[0]; + let request_body = request.headers.get("session_id").unwrap(); + let originator = request.headers.get("originator").unwrap(); + + assert!(current_session_id.is_some()); + assert_eq!(request_body.to_str().unwrap(), ¤t_session_id.unwrap()); + assert_eq!(originator.to_str().unwrap(), "codex_cli_rs"); +} diff --git a/codex-rs/tui/src/app.rs b/codex-rs/tui/src/app.rs index d8af5d33be..37c2616d5b 100644 --- a/codex-rs/tui/src/app.rs +++ b/codex-rs/tui/src/app.rs @@ -19,7 +19,8 @@ use crossterm::event::MouseEvent; use crossterm::event::MouseEventKind; use std::path::PathBuf; use std::sync::Arc; -use std::sync::Mutex; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; use std::sync::mpsc::Receiver; use std::sync::mpsc::channel; use std::thread; @@ -54,7 +55,7 @@ pub(crate) struct App<'a> { file_search: FileSearchManager, /// True when a redraw has been scheduled but not yet executed. - pending_redraw: Arc>, + pending_redraw: Arc, /// Stored parameters needed to instantiate the ChatWidget later, e.g., /// after dismissing the Git-repo warning. @@ -80,7 +81,7 @@ impl App<'_> { ) -> Self { let (app_event_tx, app_event_rx) = channel(); let app_event_tx = AppEventSender::new(app_event_tx); - let pending_redraw = Arc::new(Mutex::new(false)); + let pending_redraw = Arc::new(AtomicBool::new(false)); let scroll_event_helper = ScrollEventHelper::new(app_event_tx.clone()); // Spawn a dedicated thread for reading the crossterm event loop and @@ -177,13 +178,14 @@ impl App<'_> { /// Schedule a redraw if one is not already pending. #[allow(clippy::unwrap_used)] fn schedule_redraw(&self) { + // Attempt to set the flag to `true`. If it was already `true`, another + // redraw is already pending so we can return early. + if self + .pending_redraw + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_err() { - #[allow(clippy::unwrap_used)] - let mut flag = self.pending_redraw.lock().unwrap(); - if *flag { - return; - } - *flag = true; + return; } let tx = self.app_event_tx.clone(); @@ -191,9 +193,7 @@ impl App<'_> { thread::spawn(move || { thread::sleep(REDRAW_DEBOUNCE); tx.send(AppEvent::Redraw); - #[allow(clippy::unwrap_used)] - let mut f = pending_redraw.lock().unwrap(); - *f = false; + pending_redraw.store(false, Ordering::SeqCst); }); }