auth: let AuthManager own external bearer auth

This commit is contained in:
Michael Bolin
2026-03-30 16:27:55 -07:00
parent 4e7d4648ef
commit 0c6a595d1d
2 changed files with 254 additions and 45 deletions

View File

@@ -4,6 +4,7 @@ use crate::auth::storage::get_auth_file;
use crate::token_data::IdTokenInfo;
use crate::token_data::KnownPlan as InternalKnownPlan;
use crate::token_data::PlanType as InternalPlanType;
use async_trait::async_trait;
use codex_protocol::account::PlanType as AccountPlanType;
use base64::Engine;
@@ -12,6 +13,7 @@ use pretty_assertions::assert_eq;
use serde::Serialize;
use serde_json::json;
use std::sync::Arc;
use std::sync::Mutex;
use tempfile::tempdir;
#[tokio::test]
@@ -265,6 +267,122 @@ fn external_auth_tokens_without_chatgpt_metadata_cannot_seed_chatgpt_auth() {
);
}
#[tokio::test]
async fn auth_manager_with_external_bearer_refresher_returns_provider_token_only_for_derived_manager()
{
let base_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("base-token"));
let derived_manager =
base_manager.with_external_bearer_refresher(Arc::new(StaticExternalAuthRefresher::new(
Some(ExternalAuthTokens::access_token_only("provider-token")),
ExternalAuthTokens::access_token_only("refreshed-provider-token"),
)));
assert_eq!(
base_manager
.auth()
.await
.and_then(|auth| auth.api_key().map(str::to_string)),
Some("base-token".to_string())
);
assert_eq!(
derived_manager
.auth()
.await
.and_then(|auth| auth.api_key().map(str::to_string)),
Some("provider-token".to_string())
);
}
#[tokio::test]
async fn auth_manager_with_external_bearer_refresher_does_not_fallback_to_base_auth_when_resolve_fails()
{
let base_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("base-token"));
let none_manager =
base_manager.with_external_bearer_refresher(Arc::new(StaticExternalAuthRefresher::new(
None,
ExternalAuthTokens::access_token_only("refreshed-provider-token"),
)));
let err_manager = base_manager.with_external_bearer_refresher(Arc::new(
StaticExternalAuthRefresher::new_failing(
"boom",
ExternalAuthTokens::access_token_only("refreshed-provider-token"),
),
));
assert_eq!(none_manager.auth().await, None);
assert_eq!(err_manager.auth().await, None);
}
#[tokio::test]
async fn unauthorized_recovery_uses_external_refresh_for_bearer_manager() {
let base_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("base-token"));
let refresher = Arc::new(StaticExternalAuthRefresher::new(
Some(ExternalAuthTokens::access_token_only("provider-token")),
ExternalAuthTokens::access_token_only("refreshed-provider-token"),
));
let derived_manager = base_manager.with_external_bearer_refresher(refresher.clone());
let mut recovery = derived_manager.unauthorized_recovery();
assert!(recovery.has_next());
assert_eq!(recovery.mode_name(), "external");
assert_eq!(recovery.step_name(), "external_refresh");
let result = recovery
.next()
.await
.expect("external refresh should succeed");
assert_eq!(result.auth_state_changed(), Some(true));
assert_eq!(*refresher.refresh_calls.lock().unwrap(), 1);
}
#[derive(Debug)]
struct StaticExternalAuthRefresher {
resolved: Option<ExternalAuthTokens>,
resolve_error: Option<String>,
refreshed: ExternalAuthTokens,
refresh_calls: Mutex<usize>,
}
impl StaticExternalAuthRefresher {
fn new(resolved: Option<ExternalAuthTokens>, refreshed: ExternalAuthTokens) -> Self {
Self {
resolved,
resolve_error: None,
refreshed,
refresh_calls: Mutex::new(0),
}
}
fn new_failing(error: impl Into<String>, refreshed: ExternalAuthTokens) -> Self {
Self {
resolved: None,
resolve_error: Some(error.into()),
refreshed,
refresh_calls: Mutex::new(0),
}
}
}
#[async_trait]
impl ExternalAuthRefresher for StaticExternalAuthRefresher {
async fn resolve(&self) -> std::io::Result<Option<ExternalAuthTokens>> {
if let Some(error) = &self.resolve_error {
return Err(std::io::Error::other(error.clone()));
}
Ok(self.resolved.clone())
}
async fn refresh(
&self,
_context: ExternalAuthRefreshContext,
) -> std::io::Result<ExternalAuthTokens> {
*self.refresh_calls.lock().unwrap() += 1;
Ok(self.refreshed.clone())
}
}
struct AuthFileParams {
openai_api_key: Option<String>,
chatgpt_plan_type: Option<String>,

View File

@@ -840,8 +840,6 @@ impl AuthDotJson {
#[derive(Clone)]
struct CachedAuth {
auth: Option<CodexAuth>,
/// Callback used to refresh external auth by asking the parent app for new tokens.
external_refresher: Option<Arc<dyn ExternalAuthRefresher>>,
/// Permanent refresh failure cached for the current auth snapshot so
/// later refresh attempts for the same credentials fail fast without network.
permanent_refresh_failure: Option<AuthScopedRefreshFailure>,
@@ -853,6 +851,27 @@ struct AuthScopedRefreshFailure {
error: RefreshTokenFailedError,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ExternalAuthKind {
Bearer,
Chatgpt,
}
#[derive(Clone)]
struct ExternalAuthHandle {
kind: ExternalAuthKind,
refresher: Arc<dyn ExternalAuthRefresher>,
}
impl Debug for ExternalAuthHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ExternalAuthHandle")
.field("kind", &self.kind)
.field("refresher", &"present")
.finish()
}
}
impl Debug for CachedAuth {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CachedAuth")
@@ -860,10 +879,6 @@ impl Debug for CachedAuth {
"auth_mode",
&self.auth.as_ref().map(CodexAuth::api_auth_mode),
)
.field(
"external_refresher",
&self.external_refresher.as_ref().map(|_| "present"),
)
.field(
"permanent_refresh_failure",
&self
@@ -907,9 +922,14 @@ enum UnauthorizedRecoveryMode {
// 2. Attempt to refresh the token using OAuth token refresh flow.
// If after both steps the server still responds with 401 we let the error bubble to the user.
//
// For external ChatGPT auth tokens (chatgptAuthTokens), UnauthorizedRecovery does not touch disk or refresh
// tokens locally. Instead it calls the ExternalAuthRefresher (account/chatgptAuthTokens/refresh) to ask the
// parent app for new tokens, stores them in the ephemeral auth store, and retries once.
// For external auth sources, UnauthorizedRecovery delegates to the configured
// ExternalAuthRefresher and retries once.
//
// - External ChatGPT auth tokens (`chatgptAuthTokens`) are refreshed by asking
// the parent app for new tokens, persisting them in the ephemeral auth
// store, and reloading the cached auth snapshot.
// - External bearer auth sources resolve bearer-only tokens for custom model
// providers and refresh them without touching disk.
pub struct UnauthorizedRecovery {
manager: Arc<AuthManager>,
step: UnauthorizedRecoveryStep,
@@ -932,9 +952,10 @@ impl UnauthorizedRecovery {
fn new(manager: Arc<AuthManager>) -> Self {
let cached_auth = manager.auth_cached();
let expected_account_id = cached_auth.as_ref().and_then(CodexAuth::get_account_id);
let mode = if cached_auth
.as_ref()
.is_some_and(CodexAuth::is_external_chatgpt_tokens)
let mode = if manager.external_auth_kind() == Some(ExternalAuthKind::Bearer)
|| cached_auth
.as_ref()
.is_some_and(CodexAuth::is_external_chatgpt_tokens)
{
UnauthorizedRecoveryMode::External
} else {
@@ -953,6 +974,10 @@ impl UnauthorizedRecovery {
}
pub fn has_next(&self) -> bool {
if self.manager.external_auth_kind() == Some(ExternalAuthKind::Bearer) {
return !matches!(self.step, UnauthorizedRecoveryStep::Done);
}
if !self
.manager
.auth_cached()
@@ -972,6 +997,16 @@ impl UnauthorizedRecovery {
}
pub fn unavailable_reason(&self) -> &'static str {
if self.manager.external_auth_kind() == Some(ExternalAuthKind::Bearer) {
return if matches!(self.step, UnauthorizedRecoveryStep::Done) {
"recovery_exhausted"
} else if self.manager.has_external_auth_refresher() {
"ready"
} else {
"no_external_refresher"
};
}
if !self
.manager
.auth_cached()
@@ -1080,11 +1115,12 @@ impl UnauthorizedRecovery {
#[derive(Debug)]
pub struct AuthManager {
codex_home: PathBuf,
inner: RwLock<CachedAuth>,
inner: Arc<RwLock<CachedAuth>>,
enable_codex_api_key_env: bool,
auth_credentials_store_mode: AuthCredentialsStoreMode,
forced_chatgpt_workspace_id: RwLock<Option<String>>,
refresh_lock: AsyncMutex<()>,
forced_chatgpt_workspace_id: Arc<RwLock<Option<String>>>,
refresh_lock: Arc<AsyncMutex<()>>,
external_auth: RwLock<Option<ExternalAuthHandle>>,
}
impl AuthManager {
@@ -1106,15 +1142,15 @@ impl AuthManager {
.flatten();
Self {
codex_home,
inner: RwLock::new(CachedAuth {
inner: Arc::new(RwLock::new(CachedAuth {
auth: managed_auth,
external_refresher: None,
permanent_refresh_failure: None,
}),
})),
enable_codex_api_key_env,
auth_credentials_store_mode,
forced_chatgpt_workspace_id: RwLock::new(None),
refresh_lock: AsyncMutex::new(()),
forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)),
refresh_lock: Arc::new(AsyncMutex::new(())),
external_auth: RwLock::new(None),
}
}
@@ -1122,17 +1158,17 @@ impl AuthManager {
pub fn from_auth_for_testing(auth: CodexAuth) -> Arc<Self> {
let cached = CachedAuth {
auth: Some(auth),
external_refresher: None,
permanent_refresh_failure: None,
};
Arc::new(Self {
codex_home: PathBuf::from("non-existent"),
inner: RwLock::new(cached),
inner: Arc::new(RwLock::new(cached)),
enable_codex_api_key_env: false,
auth_credentials_store_mode: AuthCredentialsStoreMode::File,
forced_chatgpt_workspace_id: RwLock::new(None),
refresh_lock: AsyncMutex::new(()),
forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)),
refresh_lock: Arc::new(AsyncMutex::new(())),
external_auth: RwLock::new(None),
})
}
@@ -1140,16 +1176,16 @@ impl AuthManager {
pub fn from_auth_for_testing_with_home(auth: CodexAuth, codex_home: PathBuf) -> Arc<Self> {
let cached = CachedAuth {
auth: Some(auth),
external_refresher: None,
permanent_refresh_failure: None,
};
Arc::new(Self {
codex_home,
inner: RwLock::new(cached),
inner: Arc::new(RwLock::new(cached)),
enable_codex_api_key_env: false,
auth_credentials_store_mode: AuthCredentialsStoreMode::File,
forced_chatgpt_workspace_id: RwLock::new(None),
refresh_lock: AsyncMutex::new(()),
forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)),
refresh_lock: Arc::new(AsyncMutex::new(())),
external_auth: RwLock::new(None),
})
}
@@ -1172,6 +1208,13 @@ impl AuthManager {
/// For stale managed ChatGPT auth, first performs a guarded reload and then
/// refreshes only if the on-disk auth is unchanged.
pub async fn auth(&self) -> Option<CodexAuth> {
if self.external_auth_kind() == Some(ExternalAuthKind::Bearer) {
return self.resolve_external_api_key_auth().await;
}
if let Some(auth) = self.resolve_external_api_key_auth().await {
return Some(auth);
}
let auth = self.auth_cached()?;
if Self::is_stale_for_proactive_refresh(&auth)
&& let Err(err) = self.refresh_token().await
@@ -1292,17 +1335,38 @@ impl AuthManager {
}
pub fn set_external_auth_refresher(&self, refresher: Arc<dyn ExternalAuthRefresher>) {
if let Ok(mut guard) = self.inner.write() {
guard.external_refresher = Some(refresher);
if let Ok(mut guard) = self.external_auth.write() {
*guard = Some(ExternalAuthHandle {
kind: ExternalAuthKind::Chatgpt,
refresher,
});
}
}
pub fn clear_external_auth_refresher(&self) {
if let Ok(mut guard) = self.inner.write() {
guard.external_refresher = None;
if let Ok(mut guard) = self.external_auth.write() {
*guard = None;
}
}
pub fn with_external_bearer_refresher(
self: &Arc<Self>,
refresher: Arc<dyn ExternalAuthRefresher>,
) -> Arc<Self> {
Arc::new(Self {
codex_home: self.codex_home.clone(),
inner: Arc::clone(&self.inner),
enable_codex_api_key_env: self.enable_codex_api_key_env,
auth_credentials_store_mode: self.auth_credentials_store_mode,
forced_chatgpt_workspace_id: Arc::clone(&self.forced_chatgpt_workspace_id),
refresh_lock: Arc::clone(&self.refresh_lock),
external_auth: RwLock::new(Some(ExternalAuthHandle {
kind: ExternalAuthKind::Bearer,
refresher,
})),
})
}
pub fn set_forced_chatgpt_workspace_id(&self, workspace_id: Option<String>) {
if let Ok(mut guard) = self.forced_chatgpt_workspace_id.write() {
*guard = workspace_id;
@@ -1317,13 +1381,17 @@ impl AuthManager {
}
pub fn has_external_auth_refresher(&self) -> bool {
self.inner
self.external_auth
.read()
.ok()
.map(|guard| guard.external_refresher.is_some())
.map(|guard| guard.is_some())
.unwrap_or(false)
}
pub fn has_external_bearer_refresher(&self) -> bool {
self.external_auth_kind() == Some(ExternalAuthKind::Bearer)
}
pub fn is_external_auth_active(&self) -> bool {
self.auth_cached()
.as_ref()
@@ -1351,6 +1419,35 @@ impl AuthManager {
UnauthorizedRecovery::new(Arc::clone(self))
}
fn external_auth_handle(&self) -> Option<ExternalAuthHandle> {
self.external_auth
.read()
.ok()
.and_then(|guard| guard.clone())
}
fn external_auth_kind(&self) -> Option<ExternalAuthKind> {
self.external_auth_handle().map(|handle| handle.kind)
}
async fn resolve_external_api_key_auth(&self) -> Option<CodexAuth> {
let Some(handle) = self.external_auth_handle() else {
return None;
};
if handle.kind != ExternalAuthKind::Bearer {
return None;
}
match handle.refresher.resolve().await {
Ok(Some(tokens)) => Some(CodexAuth::from_api_key(&tokens.access_token)),
Ok(None) => None,
Err(err) => {
tracing::error!("Failed to resolve external bearer auth: {err}");
None
}
}
}
/// Attempt to refresh the token by first performing a guarded reload. Auth
/// is reloaded from storage only when the account id matches the currently
/// cached account id. If the persisted token differs from the cached token, we
@@ -1473,16 +1570,7 @@ impl AuthManager {
reason: ExternalAuthRefreshReason,
) -> Result<(), RefreshTokenError> {
let forced_chatgpt_workspace_id = self.forced_chatgpt_workspace_id();
let refresher = match self.inner.read() {
Ok(guard) => guard.external_refresher.clone(),
Err(_) => {
return Err(RefreshTokenError::Transient(std::io::Error::other(
"failed to read external auth state",
)));
}
};
let Some(refresher) = refresher else {
let Some(handle) = self.external_auth_handle() else {
return Err(RefreshTokenError::Transient(std::io::Error::other(
"external auth refresher is not configured",
)));
@@ -1497,7 +1585,10 @@ impl AuthManager {
previous_account_id,
};
let refreshed = refresher.refresh(context).await?;
let refreshed = handle.refresher.refresh(context).await?;
if handle.kind == ExternalAuthKind::Bearer {
return Ok(());
}
let Some(chatgpt_metadata) = refreshed.chatgpt_metadata() else {
return Err(RefreshTokenError::Transient(std::io::Error::other(
"external auth refresh did not return ChatGPT metadata",