diff --git a/codex-rs/login/src/auth/auth_tests.rs b/codex-rs/login/src/auth/auth_tests.rs index 0d473d2fc2..a3c714ae4c 100644 --- a/codex-rs/login/src/auth/auth_tests.rs +++ b/codex-rs/login/src/auth/auth_tests.rs @@ -4,6 +4,7 @@ use crate::auth::storage::get_auth_file; use crate::token_data::IdTokenInfo; use crate::token_data::KnownPlan as InternalKnownPlan; use crate::token_data::PlanType as InternalPlanType; +use async_trait::async_trait; use codex_protocol::account::PlanType as AccountPlanType; use base64::Engine; @@ -12,6 +13,7 @@ use pretty_assertions::assert_eq; use serde::Serialize; use serde_json::json; use std::sync::Arc; +use std::sync::Mutex; use tempfile::tempdir; #[tokio::test] @@ -265,6 +267,122 @@ fn external_auth_tokens_without_chatgpt_metadata_cannot_seed_chatgpt_auth() { ); } +#[tokio::test] +async fn auth_manager_with_external_bearer_refresher_returns_provider_token_only_for_derived_manager() + { + let base_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("base-token")); + let derived_manager = + base_manager.with_external_bearer_refresher(Arc::new(StaticExternalAuthRefresher::new( + Some(ExternalAuthTokens::access_token_only("provider-token")), + ExternalAuthTokens::access_token_only("refreshed-provider-token"), + ))); + + assert_eq!( + base_manager + .auth() + .await + .and_then(|auth| auth.api_key().map(str::to_string)), + Some("base-token".to_string()) + ); + assert_eq!( + derived_manager + .auth() + .await + .and_then(|auth| auth.api_key().map(str::to_string)), + Some("provider-token".to_string()) + ); +} + +#[tokio::test] +async fn auth_manager_with_external_bearer_refresher_does_not_fallback_to_base_auth_when_resolve_fails() + { + let base_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("base-token")); + + let none_manager = + base_manager.with_external_bearer_refresher(Arc::new(StaticExternalAuthRefresher::new( + None, + ExternalAuthTokens::access_token_only("refreshed-provider-token"), + ))); + let err_manager = base_manager.with_external_bearer_refresher(Arc::new( + StaticExternalAuthRefresher::new_failing( + "boom", + ExternalAuthTokens::access_token_only("refreshed-provider-token"), + ), + )); + + assert_eq!(none_manager.auth().await, None); + assert_eq!(err_manager.auth().await, None); +} + +#[tokio::test] +async fn unauthorized_recovery_uses_external_refresh_for_bearer_manager() { + let base_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("base-token")); + let refresher = Arc::new(StaticExternalAuthRefresher::new( + Some(ExternalAuthTokens::access_token_only("provider-token")), + ExternalAuthTokens::access_token_only("refreshed-provider-token"), + )); + let derived_manager = base_manager.with_external_bearer_refresher(refresher.clone()); + let mut recovery = derived_manager.unauthorized_recovery(); + + assert!(recovery.has_next()); + assert_eq!(recovery.mode_name(), "external"); + assert_eq!(recovery.step_name(), "external_refresh"); + + let result = recovery + .next() + .await + .expect("external refresh should succeed"); + + assert_eq!(result.auth_state_changed(), Some(true)); + assert_eq!(*refresher.refresh_calls.lock().unwrap(), 1); +} + +#[derive(Debug)] +struct StaticExternalAuthRefresher { + resolved: Option, + resolve_error: Option, + refreshed: ExternalAuthTokens, + refresh_calls: Mutex, +} + +impl StaticExternalAuthRefresher { + fn new(resolved: Option, refreshed: ExternalAuthTokens) -> Self { + Self { + resolved, + resolve_error: None, + refreshed, + refresh_calls: Mutex::new(0), + } + } + + fn new_failing(error: impl Into, refreshed: ExternalAuthTokens) -> Self { + Self { + resolved: None, + resolve_error: Some(error.into()), + refreshed, + refresh_calls: Mutex::new(0), + } + } +} + +#[async_trait] +impl ExternalAuthRefresher for StaticExternalAuthRefresher { + async fn resolve(&self) -> std::io::Result> { + if let Some(error) = &self.resolve_error { + return Err(std::io::Error::other(error.clone())); + } + Ok(self.resolved.clone()) + } + + async fn refresh( + &self, + _context: ExternalAuthRefreshContext, + ) -> std::io::Result { + *self.refresh_calls.lock().unwrap() += 1; + Ok(self.refreshed.clone()) + } +} + struct AuthFileParams { openai_api_key: Option, chatgpt_plan_type: Option, diff --git a/codex-rs/login/src/auth/manager.rs b/codex-rs/login/src/auth/manager.rs index 9d1080c073..3f6d18c903 100644 --- a/codex-rs/login/src/auth/manager.rs +++ b/codex-rs/login/src/auth/manager.rs @@ -840,8 +840,6 @@ impl AuthDotJson { #[derive(Clone)] struct CachedAuth { auth: Option, - /// Callback used to refresh external auth by asking the parent app for new tokens. - external_refresher: Option>, /// Permanent refresh failure cached for the current auth snapshot so /// later refresh attempts for the same credentials fail fast without network. permanent_refresh_failure: Option, @@ -853,6 +851,27 @@ struct AuthScopedRefreshFailure { error: RefreshTokenFailedError, } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum ExternalAuthKind { + Bearer, + Chatgpt, +} + +#[derive(Clone)] +struct ExternalAuthHandle { + kind: ExternalAuthKind, + refresher: Arc, +} + +impl Debug for ExternalAuthHandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExternalAuthHandle") + .field("kind", &self.kind) + .field("refresher", &"present") + .finish() + } +} + impl Debug for CachedAuth { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("CachedAuth") @@ -860,10 +879,6 @@ impl Debug for CachedAuth { "auth_mode", &self.auth.as_ref().map(CodexAuth::api_auth_mode), ) - .field( - "external_refresher", - &self.external_refresher.as_ref().map(|_| "present"), - ) .field( "permanent_refresh_failure", &self @@ -907,9 +922,14 @@ enum UnauthorizedRecoveryMode { // 2. Attempt to refresh the token using OAuth token refresh flow. // If after both steps the server still responds with 401 we let the error bubble to the user. // -// For external ChatGPT auth tokens (chatgptAuthTokens), UnauthorizedRecovery does not touch disk or refresh -// tokens locally. Instead it calls the ExternalAuthRefresher (account/chatgptAuthTokens/refresh) to ask the -// parent app for new tokens, stores them in the ephemeral auth store, and retries once. +// For external auth sources, UnauthorizedRecovery delegates to the configured +// ExternalAuthRefresher and retries once. +// +// - External ChatGPT auth tokens (`chatgptAuthTokens`) are refreshed by asking +// the parent app for new tokens, persisting them in the ephemeral auth +// store, and reloading the cached auth snapshot. +// - External bearer auth sources resolve bearer-only tokens for custom model +// providers and refresh them without touching disk. pub struct UnauthorizedRecovery { manager: Arc, step: UnauthorizedRecoveryStep, @@ -932,9 +952,10 @@ impl UnauthorizedRecovery { fn new(manager: Arc) -> Self { let cached_auth = manager.auth_cached(); let expected_account_id = cached_auth.as_ref().and_then(CodexAuth::get_account_id); - let mode = if cached_auth - .as_ref() - .is_some_and(CodexAuth::is_external_chatgpt_tokens) + let mode = if manager.external_auth_kind() == Some(ExternalAuthKind::Bearer) + || cached_auth + .as_ref() + .is_some_and(CodexAuth::is_external_chatgpt_tokens) { UnauthorizedRecoveryMode::External } else { @@ -953,6 +974,10 @@ impl UnauthorizedRecovery { } pub fn has_next(&self) -> bool { + if self.manager.external_auth_kind() == Some(ExternalAuthKind::Bearer) { + return !matches!(self.step, UnauthorizedRecoveryStep::Done); + } + if !self .manager .auth_cached() @@ -972,6 +997,16 @@ impl UnauthorizedRecovery { } pub fn unavailable_reason(&self) -> &'static str { + if self.manager.external_auth_kind() == Some(ExternalAuthKind::Bearer) { + return if matches!(self.step, UnauthorizedRecoveryStep::Done) { + "recovery_exhausted" + } else if self.manager.has_external_auth_refresher() { + "ready" + } else { + "no_external_refresher" + }; + } + if !self .manager .auth_cached() @@ -1080,11 +1115,12 @@ impl UnauthorizedRecovery { #[derive(Debug)] pub struct AuthManager { codex_home: PathBuf, - inner: RwLock, + inner: Arc>, enable_codex_api_key_env: bool, auth_credentials_store_mode: AuthCredentialsStoreMode, - forced_chatgpt_workspace_id: RwLock>, - refresh_lock: AsyncMutex<()>, + forced_chatgpt_workspace_id: Arc>>, + refresh_lock: Arc>, + external_auth: RwLock>, } impl AuthManager { @@ -1106,15 +1142,15 @@ impl AuthManager { .flatten(); Self { codex_home, - inner: RwLock::new(CachedAuth { + inner: Arc::new(RwLock::new(CachedAuth { auth: managed_auth, - external_refresher: None, permanent_refresh_failure: None, - }), + })), enable_codex_api_key_env, auth_credentials_store_mode, - forced_chatgpt_workspace_id: RwLock::new(None), - refresh_lock: AsyncMutex::new(()), + forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)), + refresh_lock: Arc::new(AsyncMutex::new(())), + external_auth: RwLock::new(None), } } @@ -1122,17 +1158,17 @@ impl AuthManager { pub fn from_auth_for_testing(auth: CodexAuth) -> Arc { let cached = CachedAuth { auth: Some(auth), - external_refresher: None, permanent_refresh_failure: None, }; Arc::new(Self { codex_home: PathBuf::from("non-existent"), - inner: RwLock::new(cached), + inner: Arc::new(RwLock::new(cached)), enable_codex_api_key_env: false, auth_credentials_store_mode: AuthCredentialsStoreMode::File, - forced_chatgpt_workspace_id: RwLock::new(None), - refresh_lock: AsyncMutex::new(()), + forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)), + refresh_lock: Arc::new(AsyncMutex::new(())), + external_auth: RwLock::new(None), }) } @@ -1140,16 +1176,16 @@ impl AuthManager { pub fn from_auth_for_testing_with_home(auth: CodexAuth, codex_home: PathBuf) -> Arc { let cached = CachedAuth { auth: Some(auth), - external_refresher: None, permanent_refresh_failure: None, }; Arc::new(Self { codex_home, - inner: RwLock::new(cached), + inner: Arc::new(RwLock::new(cached)), enable_codex_api_key_env: false, auth_credentials_store_mode: AuthCredentialsStoreMode::File, - forced_chatgpt_workspace_id: RwLock::new(None), - refresh_lock: AsyncMutex::new(()), + forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)), + refresh_lock: Arc::new(AsyncMutex::new(())), + external_auth: RwLock::new(None), }) } @@ -1172,6 +1208,13 @@ impl AuthManager { /// For stale managed ChatGPT auth, first performs a guarded reload and then /// refreshes only if the on-disk auth is unchanged. pub async fn auth(&self) -> Option { + if self.external_auth_kind() == Some(ExternalAuthKind::Bearer) { + return self.resolve_external_api_key_auth().await; + } + if let Some(auth) = self.resolve_external_api_key_auth().await { + return Some(auth); + } + let auth = self.auth_cached()?; if Self::is_stale_for_proactive_refresh(&auth) && let Err(err) = self.refresh_token().await @@ -1292,17 +1335,38 @@ impl AuthManager { } pub fn set_external_auth_refresher(&self, refresher: Arc) { - if let Ok(mut guard) = self.inner.write() { - guard.external_refresher = Some(refresher); + if let Ok(mut guard) = self.external_auth.write() { + *guard = Some(ExternalAuthHandle { + kind: ExternalAuthKind::Chatgpt, + refresher, + }); } } pub fn clear_external_auth_refresher(&self) { - if let Ok(mut guard) = self.inner.write() { - guard.external_refresher = None; + if let Ok(mut guard) = self.external_auth.write() { + *guard = None; } } + pub fn with_external_bearer_refresher( + self: &Arc, + refresher: Arc, + ) -> Arc { + Arc::new(Self { + codex_home: self.codex_home.clone(), + inner: Arc::clone(&self.inner), + enable_codex_api_key_env: self.enable_codex_api_key_env, + auth_credentials_store_mode: self.auth_credentials_store_mode, + forced_chatgpt_workspace_id: Arc::clone(&self.forced_chatgpt_workspace_id), + refresh_lock: Arc::clone(&self.refresh_lock), + external_auth: RwLock::new(Some(ExternalAuthHandle { + kind: ExternalAuthKind::Bearer, + refresher, + })), + }) + } + pub fn set_forced_chatgpt_workspace_id(&self, workspace_id: Option) { if let Ok(mut guard) = self.forced_chatgpt_workspace_id.write() { *guard = workspace_id; @@ -1317,13 +1381,17 @@ impl AuthManager { } pub fn has_external_auth_refresher(&self) -> bool { - self.inner + self.external_auth .read() .ok() - .map(|guard| guard.external_refresher.is_some()) + .map(|guard| guard.is_some()) .unwrap_or(false) } + pub fn has_external_bearer_refresher(&self) -> bool { + self.external_auth_kind() == Some(ExternalAuthKind::Bearer) + } + pub fn is_external_auth_active(&self) -> bool { self.auth_cached() .as_ref() @@ -1351,6 +1419,33 @@ impl AuthManager { UnauthorizedRecovery::new(Arc::clone(self)) } + fn external_auth_handle(&self) -> Option { + self.external_auth + .read() + .ok() + .and_then(|guard| guard.clone()) + } + + fn external_auth_kind(&self) -> Option { + self.external_auth_handle().map(|handle| handle.kind) + } + + async fn resolve_external_api_key_auth(&self) -> Option { + let handle = self.external_auth_handle()?; + if handle.kind != ExternalAuthKind::Bearer { + return None; + } + + match handle.refresher.resolve().await { + Ok(Some(tokens)) => Some(CodexAuth::from_api_key(&tokens.access_token)), + Ok(None) => None, + Err(err) => { + tracing::error!("Failed to resolve external bearer auth: {err}"); + None + } + } + } + /// Attempt to refresh the token by first performing a guarded reload. Auth /// is reloaded from storage only when the account id matches the currently /// cached account id. If the persisted token differs from the cached token, we @@ -1473,16 +1568,7 @@ impl AuthManager { reason: ExternalAuthRefreshReason, ) -> Result<(), RefreshTokenError> { let forced_chatgpt_workspace_id = self.forced_chatgpt_workspace_id(); - let refresher = match self.inner.read() { - Ok(guard) => guard.external_refresher.clone(), - Err(_) => { - return Err(RefreshTokenError::Transient(std::io::Error::other( - "failed to read external auth state", - ))); - } - }; - - let Some(refresher) = refresher else { + let Some(handle) = self.external_auth_handle() else { return Err(RefreshTokenError::Transient(std::io::Error::other( "external auth refresher is not configured", ))); @@ -1497,7 +1583,10 @@ impl AuthManager { previous_account_id, }; - let refreshed = refresher.refresh(context).await?; + let refreshed = handle.refresher.refresh(context).await?; + if handle.kind == ExternalAuthKind::Bearer { + return Ok(()); + } let Some(chatgpt_metadata) = refreshed.chatgpt_metadata() else { return Err(RefreshTokenError::Transient(std::io::Error::other( "external auth refresh did not return ChatGPT metadata",