diff --git a/codex-rs/core/src/agent_identity.rs b/codex-rs/core/src/agent_identity.rs index 5ac40f2fe5..85670f50cd 100644 --- a/codex-rs/core/src/agent_identity.rs +++ b/codex-rs/core/src/agent_identity.rs @@ -21,7 +21,6 @@ use ed25519_dalek::pkcs8::DecodePrivateKey; use ed25519_dalek::pkcs8::EncodePrivateKey; use rand::TryRngCore; use rand::rngs::OsRng; -use reqwest::StatusCode; use serde::Deserialize; use serde::Serialize; use tokio::sync::Mutex; @@ -108,6 +107,10 @@ impl AgentIdentityManager { } } + pub(crate) fn is_enabled(&self) -> bool { + self.feature_enabled + } + pub(crate) async fn ensure_registered_identity(&self) -> Result> { if !self.feature_enabled { return Ok(None); @@ -153,69 +156,45 @@ impl AgentIdentityManager { }; let client = create_client(); - let urls = agent_registration_urls(&self.chatgpt_base_url); + let url = agent_registration_url(&self.chatgpt_base_url); + let response = client + .post(&url) + .bearer_auth(&binding.access_token) + .header("chatgpt-account-id", &binding.chatgpt_account_id) + .json(&request_body) + .timeout(AGENT_REGISTRATION_TIMEOUT) + .send() + .await + .with_context(|| { + format!("failed to send agent identity registration request to {url}") + })?; - for (index, url) in urls.iter().enumerate() { - let response = client - .post(url) - .bearer_auth(&binding.access_token) - .header("chatgpt-account-id", &binding.chatgpt_account_id) - .json(&request_body) - .timeout(AGENT_REGISTRATION_TIMEOUT) - .send() + if response.status().is_success() { + let response_body = response + .json::() .await - .with_context(|| { - format!("failed to send agent identity registration request to {url}") - })?; - - if response.status().is_success() { - let response_body = response - .json::() - .await - .with_context(|| { - format!("failed to parse agent identity response from {url}") - })?; - let stored_identity = StoredAgentIdentity { - binding_id: binding.binding_id.clone(), - chatgpt_account_id: binding.chatgpt_account_id.clone(), - chatgpt_user_id: binding.chatgpt_user_id.clone(), - agent_runtime_id: response_body.agent_runtime_id, - private_key_pkcs8_base64: key_material.private_key_pkcs8_base64, - public_key_ssh: key_material.public_key_ssh, - registered_at: Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true), - abom: self.abom.clone(), - }; - info!( - agent_runtime_id = %stored_identity.agent_runtime_id, - binding_id = %binding.binding_id, - "registered agent identity" - ); - return Ok(stored_identity); - } - - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - let is_last_candidate = index + 1 == urls.len(); - if !is_last_candidate - && matches!( - status, - StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED - ) - { - debug!( - url = %url, - status = %status, - "agent identity registration endpoint unavailable at candidate URL; trying fallback" - ); - continue; - } - - anyhow::bail!( - "agent identity registration failed with status {status} from {url}: {body}" + .with_context(|| format!("failed to parse agent identity response from {url}"))?; + let stored_identity = StoredAgentIdentity { + binding_id: binding.binding_id.clone(), + chatgpt_account_id: binding.chatgpt_account_id.clone(), + chatgpt_user_id: binding.chatgpt_user_id.clone(), + agent_runtime_id: response_body.agent_runtime_id, + private_key_pkcs8_base64: key_material.private_key_pkcs8_base64, + public_key_ssh: key_material.public_key_ssh, + registered_at: Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true), + abom: self.abom.clone(), + }; + info!( + agent_runtime_id = %stored_identity.agent_runtime_id, + binding_id = %binding.binding_id, + "registered agent identity" ); + return Ok(stored_identity); } - anyhow::bail!("no candidate URLs were available for agent identity registration") + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("agent identity registration failed with status {status} from {url}: {body}") } fn load_stored_identity( @@ -410,15 +389,12 @@ fn secret_scope(binding: &AgentIdentityBinding) -> Result { .context("agent identity binding must be a valid secrets scope") } -fn agent_registration_urls(chatgpt_base_url: &str) -> Vec { +fn agent_registration_url(chatgpt_base_url: &str) -> String { let trimmed = chatgpt_base_url.trim_end_matches('/'); if let Some(root) = trimmed.strip_suffix("/backend-api") { - return vec![ - format!("{root}/v1/agent/register"), - format!("{trimmed}/v1/agent/register"), - ]; + return format!("{root}/v1/agent/register"); } - vec![format!("{trimmed}/v1/agent/register")] + format!("{trimmed}/v1/agent/register") } #[cfg(test)] @@ -589,18 +565,12 @@ mod tests { } #[tokio::test] - async fn ensure_registered_identity_falls_back_to_backend_api_v1() { + async fn ensure_registered_identity_uses_canonical_agent_registration_url() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/v1/agent/register")) - .respond_with(ResponseTemplate::new(404)) - .expect(1) - .mount(&server) - .await; - Mock::given(method("POST")) - .and(path("/backend-api/v1/agent/register")) .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "agent_runtime_id": "agent_fallback", + "agent_runtime_id": "agent_canonical", }))) .expect(1) .mount(&server) @@ -628,7 +598,7 @@ mod tests { .await .unwrap() .expect("identity should be registered"); - assert_eq!(stored.agent_runtime_id, "agent_fallback"); + assert_eq!(stored.agent_runtime_id, "agent_canonical"); } #[test] diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 9846f275cc..8533552020 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1479,22 +1479,55 @@ impl Session { } fn start_agent_identity_registration(self: &Arc) { + if !self.services.agent_identity_manager.is_enabled() { + return; + } + let weak_sess = Arc::downgrade(self); + let mut auth_state_rx = self.services.auth_manager.subscribe_auth_state(); tokio::spawn(async move { - let Some(sess) = weak_sess.upgrade() else { - return; - }; - if let Err(error) = sess - .services - .agent_identity_manager - .ensure_registered_identity() - .await - { - warn!(error = %error, "agent identity registration failed"); + loop { + let Some(sess) = weak_sess.upgrade() else { + return; + }; + match sess + .services + .agent_identity_manager + .ensure_registered_identity() + .await + { + Ok(Some(_)) => return, + Ok(None) => { + drop(sess); + if auth_state_rx.changed().await.is_err() { + return; + } + } + Err(error) => { + sess.fail_agent_identity_registration(error).await; + return; + } + } } }); } + async fn fail_agent_identity_registration(self: &Arc, error: anyhow::Error) { + warn!(error = %error, "agent identity registration failed"); + let message = format!( + "Agent identity registration failed. Codex cannot continue while `features.use_agent_identity` is enabled: {error}" + ); + self.send_event_raw(Event { + id: self.next_internal_sub_id(), + msg: EventMsg::Error(ErrorEvent { + message, + codex_error_info: Some(CodexErrorInfo::Other), + }), + }) + .await; + handlers::shutdown(self, self.next_internal_sub_id()).await; + } + #[allow(clippy::too_many_arguments)] fn make_turn_context( conversation_id: ThreadId, diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index 669d99245f..c4c5000749 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -110,6 +110,7 @@ use opentelemetry::trace::TraceId; use std::path::Path; use std::time::Duration; use tokio::time::sleep; +use tokio::time::timeout; use tracing_opentelemetry::OpenTelemetrySpanExt; use codex_protocol::mcp::CallToolResult as McpCallToolResult; @@ -3818,6 +3819,42 @@ pub(crate) async fn make_session_and_context_with_rx() -> ( make_session_and_context_with_dynamic_tools_and_rx(Vec::new()).await } +#[tokio::test] +async fn fail_agent_identity_registration_emits_error_and_shutdown() { + let (session, _turn_context, rx_event) = make_session_and_context_with_rx().await; + + session + .fail_agent_identity_registration(anyhow::anyhow!("registration exploded")) + .await; + + let error_event = timeout(Duration::from_secs(1), rx_event.recv()) + .await + .expect("error event should arrive") + .expect("error event should be readable"); + match error_event.msg { + EventMsg::Error(ErrorEvent { + message, + codex_error_info, + }) => { + assert_eq!( + message, + "Agent identity registration failed. Codex cannot continue while `features.use_agent_identity` is enabled: registration exploded".to_string() + ); + assert_eq!(codex_error_info, Some(CodexErrorInfo::Other)); + } + other => panic!("expected error event, got {other:?}"), + } + + let shutdown_event = timeout(Duration::from_secs(1), rx_event.recv()) + .await + .expect("shutdown event should arrive") + .expect("shutdown event should be readable"); + match shutdown_event.msg { + EventMsg::ShutdownComplete => {} + other => panic!("expected shutdown event, got {other:?}"), + } +} + #[tokio::test] async fn refresh_mcp_servers_is_deferred_until_next_turn() { let (session, turn_context) = make_session_and_context().await; diff --git a/codex-rs/login/src/auth/auth_tests.rs b/codex-rs/login/src/auth/auth_tests.rs index 27064b6831..d5cf30ca13 100644 --- a/codex-rs/login/src/auth/auth_tests.rs +++ b/codex-rs/login/src/auth/auth_tests.rs @@ -16,6 +16,8 @@ use serde_json::json; use std::sync::Arc; use tempfile::TempDir; use tempfile::tempdir; +use tokio::time::Duration; +use tokio::time::timeout; #[tokio::test] async fn refresh_without_id_token() { @@ -474,6 +476,65 @@ exit 1 } } +#[tokio::test] +async fn auth_manager_notifies_when_auth_state_changes() { + let dir = tempdir().unwrap(); + let manager = AuthManager::shared( + dir.path().to_path_buf(), + false, + AuthCredentialsStoreMode::File, + ); + let mut auth_state_rx = manager.subscribe_auth_state(); + + save_auth( + dir.path(), + &AuthDotJson { + auth_mode: Some(ApiAuthMode::ApiKey), + openai_api_key: Some("sk-test-key".to_string()), + tokens: None, + last_refresh: None, + }, + AuthCredentialsStoreMode::File, + ) + .expect("save auth"); + + assert!( + manager.reload(), + "reload should report a changed auth state" + ); + timeout(Duration::from_secs(1), auth_state_rx.changed()) + .await + .expect("auth change notification should arrive") + .expect("auth state watch should remain open"); + + save_auth( + dir.path(), + &AuthDotJson { + auth_mode: Some(ApiAuthMode::ApiKey), + openai_api_key: Some("sk-updated-key".to_string()), + tokens: None, + last_refresh: None, + }, + AuthCredentialsStoreMode::File, + ) + .expect("save updated auth"); + + assert!( + !manager.reload(), + "reload remains mode-stable even when the underlying credentials change" + ); + timeout(Duration::from_secs(1), auth_state_rx.changed()) + .await + .expect("auth reload notification should still arrive") + .expect("auth state watch should remain open"); + + manager.set_forced_chatgpt_workspace_id(Some("workspace-123".to_string())); + timeout(Duration::from_secs(1), auth_state_rx.changed()) + .await + .expect("workspace change notification should arrive") + .expect("auth state watch should remain open"); +} + struct AuthFileParams { openai_api_key: Option, chatgpt_plan_type: Option, diff --git a/codex-rs/login/src/auth/manager.rs b/codex-rs/login/src/auth/manager.rs index 71857c9700..9066fc3271 100644 --- a/codex-rs/login/src/auth/manager.rs +++ b/codex-rs/login/src/auth/manager.rs @@ -13,6 +13,7 @@ use std::sync::Arc; use std::sync::Mutex; use std::sync::RwLock; use tokio::sync::Mutex as AsyncMutex; +use tokio::sync::watch; use codex_app_server_protocol::AuthMode; use codex_app_server_protocol::AuthMode as ApiAuthMode; @@ -1106,6 +1107,7 @@ pub struct AuthManager { forced_chatgpt_workspace_id: RwLock>, refresh_lock: AsyncMutex<()>, external_auth: RwLock>>, + auth_state_tx: watch::Sender<()>, } /// Configuration view required to construct a shared [`AuthManager`]. @@ -1154,6 +1156,7 @@ impl AuthManager { enable_codex_api_key_env: bool, auth_credentials_store_mode: AuthCredentialsStoreMode, ) -> Self { + let (auth_state_tx, _) = watch::channel(()); let managed_auth = load_auth( &codex_home, enable_codex_api_key_env, @@ -1172,11 +1175,13 @@ impl AuthManager { forced_chatgpt_workspace_id: RwLock::new(None), refresh_lock: AsyncMutex::new(()), external_auth: RwLock::new(None), + auth_state_tx, } } /// Create an AuthManager with a specific CodexAuth, for testing only. pub fn from_auth_for_testing(auth: CodexAuth) -> Arc { + let (auth_state_tx, _) = watch::channel(()); let cached = CachedAuth { auth: Some(auth), permanent_refresh_failure: None, @@ -1190,11 +1195,13 @@ impl AuthManager { forced_chatgpt_workspace_id: RwLock::new(None), refresh_lock: AsyncMutex::new(()), external_auth: RwLock::new(None), + auth_state_tx, }) } /// Create an AuthManager with a specific CodexAuth and codex home, for testing only. pub fn from_auth_for_testing_with_home(auth: CodexAuth, codex_home: PathBuf) -> Arc { + let (auth_state_tx, _) = watch::channel(()); let cached = CachedAuth { auth: Some(auth), permanent_refresh_failure: None, @@ -1207,10 +1214,12 @@ impl AuthManager { forced_chatgpt_workspace_id: RwLock::new(None), refresh_lock: AsyncMutex::new(()), external_auth: RwLock::new(None), + auth_state_tx, }) } pub fn external_bearer_only(config: ModelProviderAuthInfo) -> Arc { + let (auth_state_tx, _) = watch::channel(()); Arc::new(Self { codex_home: PathBuf::from("non-existent"), inner: RwLock::new(CachedAuth { @@ -1224,6 +1233,7 @@ impl AuthManager { external_auth: RwLock::new(Some( Arc::new(BearerTokenRefresher::new(config)) as Arc )), + auth_state_tx, }) } @@ -1363,6 +1373,7 @@ impl AuthManager { } tracing::info!("Reloaded auth, changed: {changed}"); guard.auth = new_auth; + self.auth_state_tx.send_replace(()); changed } else { false @@ -1372,18 +1383,23 @@ impl AuthManager { pub fn set_external_auth(&self, external_auth: Arc) { if let Ok(mut guard) = self.external_auth.write() { *guard = Some(external_auth); + self.auth_state_tx.send_replace(()); } } pub fn clear_external_auth(&self) { if let Ok(mut guard) = self.external_auth.write() { *guard = None; + self.auth_state_tx.send_replace(()); } } pub fn set_forced_chatgpt_workspace_id(&self, workspace_id: Option) { - if let Ok(mut guard) = self.forced_chatgpt_workspace_id.write() { + if let Ok(mut guard) = self.forced_chatgpt_workspace_id.write() + && *guard != workspace_id + { *guard = workspace_id; + self.auth_state_tx.send_replace(()); } } @@ -1394,6 +1410,10 @@ impl AuthManager { .and_then(|guard| guard.clone()) } + pub fn subscribe_auth_state(&self) -> watch::Receiver<()> { + self.auth_state_tx.subscribe() + } + pub fn has_external_auth(&self) -> bool { self.external_auth().is_some() }