From a87e178824ef8ec55f6b2d7a2a6e2518e0c69795 Mon Sep 17 00:00:00 2001 From: adrian Date: Tue, 14 Apr 2026 12:25:09 -0700 Subject: [PATCH] Refresh stale agent tasks lazily --- codex-rs/core/src/agent_identity.rs | 4 + codex-rs/core/src/agent_identity/assertion.rs | 58 +++++++++- codex-rs/core/src/client.rs | 76 ++++++++++--- codex-rs/core/src/client_tests.rs | 107 ++++++++++++++++++ codex-rs/core/src/codex.rs | 53 ++++++++- codex-rs/core/src/session_startup_prewarm.rs | 10 ++ codex-rs/protocol/src/error.rs | 3 + 7 files changed, 286 insertions(+), 25 deletions(-) diff --git a/codex-rs/core/src/agent_identity.rs b/codex-rs/core/src/agent_identity.rs index 1061f45fba..e16ea49ee0 100644 --- a/codex-rs/core/src/agent_identity.rs +++ b/codex-rs/core/src/agent_identity.rs @@ -34,6 +34,7 @@ mod task_registration; #[cfg(test)] pub(crate) use assertion::AgentAssertionEnvelope; +pub(crate) use assertion::AgentTaskRuntimeMismatch; pub(crate) use task_registration::RegisteredAgentTask; const AGENT_REGISTRATION_TIMEOUT: Duration = Duration::from_secs(15); @@ -451,6 +452,9 @@ impl AgentIdentityBinding { } fn from_auth(auth: &CodexAuth, forced_workspace_id: Option) -> Option { + // AgentAssertion is currently supported only for ChatGPT-backed Codex sessions. API-key + // sessions keep using their API key until the registration service supports API-key + // identity binding. if !auth.is_chatgpt_auth() { return None; } diff --git a/codex-rs/core/src/agent_identity/assertion.rs b/codex-rs/core/src/agent_identity/assertion.rs index 147de72c15..6b8ead71d9 100644 --- a/codex-rs/core/src/agent_identity/assertion.rs +++ b/codex-rs/core/src/agent_identity/assertion.rs @@ -7,10 +7,21 @@ use base64::engine::general_purpose::URL_SAFE_NO_PAD; use ed25519_dalek::Signer as _; use serde::Deserialize; use serde::Serialize; +use thiserror::Error; use tracing::debug; use super::*; +#[derive(Debug, Error)] +#[error( + "agent task runtime {agent_runtime_id} does not match stored agent identity {stored_agent_runtime_id}" +)] +pub(crate) struct AgentTaskRuntimeMismatch { + pub(crate) agent_runtime_id: String, + pub(crate) task_id: String, + pub(crate) stored_agent_runtime_id: String, +} + #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub(crate) struct AgentAssertionEnvelope { pub(crate) agent_runtime_id: String, @@ -31,12 +42,14 @@ impl AgentIdentityManager { let Some(stored_identity) = self.ensure_registered_identity().await? else { return Ok(None); }; - anyhow::ensure!( - stored_identity.agent_runtime_id == agent_task.agent_runtime_id, - "agent task runtime {} does not match stored agent identity {}", - agent_task.agent_runtime_id, - stored_identity.agent_runtime_id - ); + if stored_identity.agent_runtime_id != agent_task.agent_runtime_id { + return Err(AgentTaskRuntimeMismatch { + agent_runtime_id: agent_task.agent_runtime_id.clone(), + task_id: agent_task.task_id.clone(), + stored_agent_runtime_id: stored_identity.agent_runtime_id, + } + .into()); + } let timestamp = Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true); let envelope = AgentAssertionEnvelope { @@ -176,6 +189,39 @@ mod tests { .expect("signature should verify"); } + #[tokio::test] + async fn authorization_header_for_task_reports_runtime_mismatch() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let auth = make_chatgpt_auth(codex_home.path(), "account-123", Some("user-123")); + let auth_manager = AuthManager::from_auth_for_testing(auth); + let manager = AgentIdentityManager::new_for_tests( + auth_manager, + /*feature_enabled*/ true, + "https://chatgpt.com/backend-api/".to_string(), + SessionSource::Cli, + ); + manager + .seed_generated_identity_for_tests("agent-current") + .await + .expect("seed test identity"); + let agent_task = RegisteredAgentTask { + agent_runtime_id: "agent-stale".to_string(), + task_id: "task-123".to_string(), + registered_at: "2026-03-23T12:00:00Z".to_string(), + }; + + let error = manager + .authorization_header_for_task(&agent_task) + .await + .expect_err("stale task should be reported"); + let mismatch = error + .downcast_ref::() + .expect("runtime mismatch error"); + assert_eq!(mismatch.agent_runtime_id, "agent-stale"); + assert_eq!(mismatch.task_id, "task-123"); + assert_eq!(mismatch.stored_agent_runtime_id, "agent-current"); + } + fn make_chatgpt_auth( codex_home: &std::path::Path, account_id: &str, diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 1ec7ca83bb..7a24f43765 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -32,6 +32,7 @@ use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; use crate::agent_identity::AgentIdentityManager; +use crate::agent_identity::AgentTaskRuntimeMismatch; use crate::agent_identity::RegisteredAgentTask; use codex_api::ApiError; use codex_api::CompactClient as ApiCompactClient; @@ -159,7 +160,7 @@ struct ModelClientState { include_timing_metrics: bool, beta_features_header: Option, disable_websockets: AtomicBool, - cached_websocket_session: StdMutex, + cached_websocket_session: StdMutex, } /// Resolved API client setup for a single request attempt. @@ -244,6 +245,12 @@ struct WebsocketSession { connection_reused: StdMutex, } +#[derive(Debug, Default)] +struct CachedWebsocketSession { + agent_task: Option, + websocket_session: WebsocketSession, +} + impl WebsocketSession { fn set_connection_reused(&self, connection_reused: bool) { *self @@ -360,7 +367,7 @@ impl ModelClient { include_timing_metrics, beta_features_header, disable_websockets: AtomicBool::new(false), - cached_websocket_session: StdMutex::new(WebsocketSession::default()), + cached_websocket_session: StdMutex::new(CachedWebsocketSession::default()), }), } } @@ -377,18 +384,15 @@ impl ModelClient { &self, agent_task: Option, ) -> ModelClientSession { - let cache_websocket_session_on_drop = agent_task.is_none(); - let websocket_session = if agent_task.is_some() { - drop(self.take_cached_websocket_session()); - WebsocketSession::default() - } else { - self.take_cached_websocket_session() - }; + // WebSocket auth is bound to the task that opened the connection. Reuse only when the + // cached connection was created for the same task, and drop mismatched taskless/task-scoped + // sessions rather than mixing auth contexts. + let websocket_session = self.take_cached_websocket_session(agent_task.as_ref()); ModelClientSession { client: self.clone(), websocket_session, agent_task, - cache_websocket_session_on_drop, + cache_websocket_session_on_drop: true, turn_state: Arc::new(OnceLock::new()), } } @@ -401,12 +405,12 @@ impl ModelClient { self.state .window_generation .store(window_generation, Ordering::Relaxed); - self.store_cached_websocket_session(WebsocketSession::default()); + self.clear_cached_websocket_session(); } pub(crate) fn advance_window_generation(&self) { self.state.window_generation.fetch_add(1, Ordering::Relaxed); - self.store_cached_websocket_session(WebsocketSession::default()); + self.clear_cached_websocket_session(); } fn current_window_id(&self) -> String { @@ -415,21 +419,44 @@ impl ModelClient { format!("{conversation_id}:{window_generation}") } - fn take_cached_websocket_session(&self) -> WebsocketSession { + fn take_cached_websocket_session( + &self, + agent_task: Option<&RegisteredAgentTask>, + ) -> WebsocketSession { let mut cached_websocket_session = self .state .cached_websocket_session .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); - std::mem::take(&mut *cached_websocket_session) + if cached_websocket_session.agent_task.as_ref() == agent_task { + return std::mem::take(&mut *cached_websocket_session).websocket_session; + } + + *cached_websocket_session = CachedWebsocketSession::default(); + WebsocketSession::default() } - fn store_cached_websocket_session(&self, websocket_session: WebsocketSession) { + fn store_cached_websocket_session( + &self, + agent_task: Option, + websocket_session: WebsocketSession, + ) { *self .state .cached_websocket_session .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) = websocket_session; + .unwrap_or_else(std::sync::PoisonError::into_inner) = CachedWebsocketSession { + agent_task, + websocket_session, + }; + } + + fn clear_cached_websocket_session(&self) { + *self + .state + .cached_websocket_session + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = CachedWebsocketSession::default(); } pub(crate) fn force_http_fallback( @@ -449,7 +476,7 @@ impl ModelClient { ); } - self.store_cached_websocket_session(WebsocketSession::default()); + self.clear_cached_websocket_session(); activated } @@ -727,6 +754,15 @@ impl ModelClient { .authorization_header_for_task(agent_task) .await .map_err(|err| { + if let Some(mismatch) = err.downcast_ref::() { + debug!( + agent_runtime_id = %mismatch.agent_runtime_id, + task_id = %mismatch.task_id, + stored_agent_runtime_id = %mismatch.stored_agent_runtime_id, + "agent task no longer matches stored identity" + ); + return CodexErr::AgentTaskStale; + } CodexErr::Stream( format!("failed to build agent assertion authorization: {err}"), None, @@ -883,12 +919,16 @@ impl Drop for ModelClientSession { let websocket_session = std::mem::take(&mut self.websocket_session); if self.cache_websocket_session_on_drop { self.client - .store_cached_websocket_session(websocket_session); + .store_cached_websocket_session(self.agent_task.clone(), websocket_session); } } } impl ModelClientSession { + pub(crate) fn agent_task(&self) -> Option<&RegisteredAgentTask> { + self.agent_task.as_ref() + } + pub(crate) fn disable_cached_websocket_session_on_drop(&mut self) { self.cache_websocket_session_on_drop = false; } diff --git a/codex-rs/core/src/client_tests.rs b/codex-rs/core/src/client_tests.rs index ae77ab456a..9f83df60d1 100644 --- a/codex-rs/core/src/client_tests.rs +++ b/codex-rs/core/src/client_tests.rs @@ -33,6 +33,7 @@ use codex_model_provider_info::create_oss_provider_with_base_url; use codex_otel::SessionTelemetry; use codex_protocol::ThreadId; use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::error::CodexErr; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; use codex_protocol::openai_models::ModelInfo; @@ -393,6 +394,35 @@ async fn responses_http_uses_agent_assertion_when_agent_task_is_present() { assert_eq!(request.header("chatgpt-account-id"), None); } +#[tokio::test] +async fn responses_http_reports_stale_agent_task_when_identity_changed() { + let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses); + let (_codex_home, client, mut agent_task, _stored_identity) = + model_client_with_agent_task(provider).await; + agent_task.agent_runtime_id = "agent-stale".to_string(); + let model_info = test_model_info(); + let session_telemetry = test_session_telemetry(); + let mut client_session = client.new_session_with_agent_task(Some(agent_task)); + + let error = match client_session + .stream( + &test_prompt("hello"), + &model_info, + &session_telemetry, + /*effort*/ None, + ReasoningSummary::Auto, + /*service_tier*/ None, + /*turn_metadata_header*/ None, + ) + .await + { + Ok(_) => panic!("stale task should be reported before sending a request"), + Err(error) => error, + }; + + assert!(matches!(error, CodexErr::AgentTaskStale)); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn websocket_agent_task_bypasses_cached_bearer_prewarm() { core_test_support::skip_if_no_network!(); @@ -469,3 +499,80 @@ async fn websocket_agent_task_bypasses_cached_bearer_prewarm() { server.shutdown().await; } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn websocket_agent_task_reuses_cached_connection_for_same_task() { + core_test_support::skip_if_no_network!(); + + let server = responses::start_websocket_server(vec![vec![ + vec![ + responses::ev_response_created("resp-1"), + responses::ev_completed("resp-1"), + ], + vec![ + responses::ev_response_created("resp-2"), + responses::ev_completed("resp-2"), + ], + ]]) + .await; + let mut provider = + create_oss_provider_with_base_url(&format!("{}/v1", server.uri()), WireApi::Responses); + provider.supports_websockets = true; + provider.websocket_connect_timeout_ms = Some(5_000); + let (_codex_home, client, agent_task, stored_identity) = + model_client_with_agent_task(provider).await; + let model_info = test_model_info(); + let session_telemetry = test_session_telemetry(); + let prompt = test_prompt("hello"); + + { + let mut first_session = client.new_session_with_agent_task(Some(agent_task.clone())); + let mut stream = first_session + .stream( + &prompt, + &model_info, + &session_telemetry, + /*effort*/ None, + ReasoningSummary::Auto, + /*service_tier*/ None, + /*turn_metadata_header*/ None, + ) + .await + .expect("first agent task stream should succeed"); + drain_stream_to_completion(&mut stream) + .await + .expect("first agent task websocket stream should complete"); + } + + let mut second_session = client.new_session_with_agent_task(Some(agent_task.clone())); + let mut stream = second_session + .stream( + &prompt, + &model_info, + &session_telemetry, + /*effort*/ None, + ReasoningSummary::Auto, + /*service_tier*/ None, + /*turn_metadata_header*/ None, + ) + .await + .expect("second agent task stream should succeed"); + drain_stream_to_completion(&mut stream) + .await + .expect("second agent task websocket stream should complete"); + + let handshakes = server.handshakes(); + assert_eq!(handshakes.len(), 1); + let agent_authorization = handshakes[0] + .header("authorization") + .expect("agent handshake should include authorization"); + assert_agent_assertion_header( + &agent_authorization, + &stored_identity, + &agent_task.agent_runtime_id, + &agent_task.task_id, + ); + assert_eq!(server.single_connection().len(), 2); + + server.shutdown().await; +} diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 38be995a38..21bf19971d 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -6354,6 +6354,14 @@ pub(crate) async fn run_turn( .await; user_prompt_submit_outcome.additional_contexts }; + let agent_task_registration = if sess.services.agent_identity_manager.is_enabled() { + let sess = Arc::clone(&sess); + Some(tokio::spawn(async move { + sess.ensure_agent_task_registered().await + })) + } else { + None + }; sess.services .analytics_events_client .track_app_mentioned(tracking.clone(), mentioned_app_invocations); @@ -6375,7 +6383,15 @@ pub(crate) async fn run_turn( })) .await; } - let agent_task = match sess.ensure_agent_task_registered().await { + let agent_task_result = match agent_task_registration { + Some(registration) => registration.await.unwrap_or_else(|error| { + Err(anyhow::anyhow!( + "agent task registration task failed: {error}" + )) + }), + None => sess.ensure_agent_task_registered().await, + }; + let agent_task = match agent_task_result { Ok(agent_task) => agent_task, Err(error) => { warn!(error = %error, "agent task registration failed"); @@ -7031,6 +7047,7 @@ async fn run_sampling_request( ) .await; let mut retries = 0; + let mut stale_agent_task_refreshed = false; loop { let err = match try_run_sampling_request( tool_runtime.clone(), @@ -7052,6 +7069,40 @@ async fn run_sampling_request( sess.set_total_tokens_full(&turn_context).await; return Err(CodexErr::ContextWindowExceeded); } + Err(CodexErr::AgentTaskStale) => { + if stale_agent_task_refreshed { + return Err(CodexErr::AgentTaskStale); + } + stale_agent_task_refreshed = true; + let stale_agent_task = client_session.agent_task().cloned(); + client_session.disable_cached_websocket_session_on_drop(); + if let Some(stale_agent_task) = stale_agent_task.as_ref() { + sess.clear_cached_agent_task(stale_agent_task).await; + } + match sess.ensure_agent_task_registered().await { + Ok(Some(agent_task)) => { + *client_session = sess + .services + .model_client + .new_session_with_agent_task(Some(agent_task)); + retries = 0; + continue; + } + Ok(None) => { + return Err(CodexErr::Stream( + "agent assertion task became unavailable after identity changed" + .to_string(), + None, + )); + } + Err(error) => { + return Err(CodexErr::Stream( + format!("failed to refresh stale agent task: {error}"), + None, + )); + } + } + } Err(CodexErr::UsageLimitReached(e)) => { let rate_limits = e.rate_limits.clone(); if let Some(rate_limits) = rate_limits { diff --git a/codex-rs/core/src/session_startup_prewarm.rs b/codex-rs/core/src/session_startup_prewarm.rs index acd9232c09..cbce8e2baa 100644 --- a/codex-rs/core/src/session_startup_prewarm.rs +++ b/codex-rs/core/src/session_startup_prewarm.rs @@ -13,6 +13,7 @@ use crate::codex::INITIAL_SUBMIT_ID; use crate::codex::Session; use crate::codex::build_prompt; use crate::codex::built_tools; +use codex_app_server_protocol::AuthMode; use codex_otel::STARTUP_PREWARM_AGE_AT_FIRST_TURN_METRIC; use codex_otel::STARTUP_PREWARM_DURATION_METRIC; use codex_otel::SessionTelemetry; @@ -157,6 +158,15 @@ impl SessionStartupPrewarmHandle { impl Session { pub(crate) async fn schedule_startup_prewarm(self: &Arc, base_instructions: String) { + if self.services.agent_identity_manager.is_enabled() + && self.services.auth_manager.auth_mode() != Some(AuthMode::ApiKey) + { + info!( + "skipping startup websocket prewarm because agent identity requires task-scoped auth" + ); + return; + } + let session_telemetry = self.services.session_telemetry.clone(); let websocket_connect_timeout = self.provider().await.websocket_connect_timeout(); let started_at = Instant::now(); diff --git a/codex-rs/protocol/src/error.rs b/codex-rs/protocol/src/error.rs index 3ffdeb48e0..1cb13c9b14 100644 --- a/codex-rs/protocol/src/error.rs +++ b/codex-rs/protocol/src/error.rs @@ -76,6 +76,8 @@ pub enum CodexErr { /// Optionally includes the requested delay before retrying the turn. #[error("stream disconnected before completion: {0}")] Stream(String, Option), + #[error("agent task no longer matches the current agent identity")] + AgentTaskStale, #[error( "Codex ran out of room in the model's context window. Start a new thread or clear earlier history before retrying." )] @@ -183,6 +185,7 @@ impl CodexErr { | CodexErr::ContextWindowExceeded | CodexErr::ThreadNotFound(_) | CodexErr::AgentLimitReached { .. } + | CodexErr::AgentTaskStale | CodexErr::Spawn | CodexErr::SessionConfiguredNotFirstEvent | CodexErr::UsageLimitReached(_)