Attempt to reload auth as a step in 401 recovery (#8880)

When authentication fails, first attempt to reload the auth from file
and then attempt to refresh it.
This commit is contained in:
pakrym-oai
2026-01-08 15:06:44 -08:00
committed by GitHub
parent be4364bb80
commit 62a73b6d58
3 changed files with 377 additions and 57 deletions

View File

@@ -16,8 +16,10 @@ use codex_core::token_data::TokenData;
use core_test_support::skip_if_no_network;
use pretty_assertions::assert_eq;
use serde::Serialize;
use serde_json::Value;
use serde_json::json;
use std::ffi::OsString;
use std::sync::Arc;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
@@ -54,12 +56,10 @@ async fn refresh_token_succeeds_updates_storage() -> Result<()> {
};
ctx.write_auth(&initial_auth)?;
let access = ctx
.auth_manager
ctx.auth_manager
.refresh_token()
.await
.context("refresh should succeed")?;
assert_eq!(access, Some("new-access-token".to_string()));
let refreshed_tokens = TokenData {
access_token: "new-access-token".to_string(),
@@ -294,9 +294,218 @@ async fn refresh_token_returns_transient_error_on_server_failure() -> Result<()>
Ok(())
}
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn unauthorized_recovery_reloads_then_refreshes_tokens() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/oauth/token"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"access_token": "recovered-access-token",
"refresh_token": "recovered-refresh-token"
})))
.expect(1)
.mount(&server)
.await;
let ctx = RefreshTokenTestContext::new(&server)?;
let initial_last_refresh = Utc::now() - Duration::days(1);
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(initial_last_refresh),
};
ctx.write_auth(&initial_auth)?;
let disk_tokens = build_tokens("disk-access-token", "disk-refresh-token");
let disk_auth = AuthDotJson {
openai_api_key: None,
tokens: Some(disk_tokens.clone()),
last_refresh: Some(initial_last_refresh),
};
save_auth(
ctx.codex_home.path(),
&disk_auth,
AuthCredentialsStoreMode::File,
)?;
let cached_before = ctx
.auth_manager
.auth_cached()
.expect("auth should be cached");
let cached_before_tokens = cached_before
.get_token_data()
.context("token data should be cached")?;
assert_eq!(cached_before_tokens, initial_tokens);
let mut recovery = ctx.auth_manager.unauthorized_recovery();
assert!(recovery.has_next());
recovery.next().await?;
let cached_after = ctx
.auth_manager
.auth_cached()
.expect("auth should be cached after reload");
let cached_after_tokens = cached_after
.get_token_data()
.context("token data should reload")?;
assert_eq!(cached_after_tokens, disk_tokens);
let requests = server.received_requests().await.unwrap_or_default();
assert!(requests.is_empty(), "expected no refresh token requests");
recovery.next().await?;
let refreshed_tokens = TokenData {
access_token: "recovered-access-token".to_string(),
refresh_token: "recovered-refresh-token".to_string(),
..disk_tokens.clone()
};
let stored = ctx.load_auth()?;
let tokens = stored.tokens.as_ref().context("tokens should exist")?;
assert_eq!(tokens, &refreshed_tokens);
let cached_auth = ctx
.auth_manager
.auth()
.await
.expect("auth should be cached");
let cached_tokens = cached_auth
.get_token_data()
.context("token data should be cached")?;
assert_eq!(cached_tokens, refreshed_tokens);
assert!(!recovery.has_next());
server.verify().await;
Ok(())
}
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn unauthorized_recovery_skips_reload_on_account_mismatch() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/oauth/token"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"access_token": "recovered-access-token",
"refresh_token": "recovered-refresh-token"
})))
.expect(1)
.mount(&server)
.await;
let ctx = RefreshTokenTestContext::new(&server)?;
let initial_last_refresh = Utc::now() - Duration::days(1);
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(initial_last_refresh),
};
ctx.write_auth(&initial_auth)?;
let mut disk_tokens = build_tokens("disk-access-token", "disk-refresh-token");
disk_tokens.account_id = Some("other-account".to_string());
let expected_tokens = TokenData {
access_token: "recovered-access-token".to_string(),
refresh_token: "recovered-refresh-token".to_string(),
..disk_tokens.clone()
};
let disk_auth = AuthDotJson {
openai_api_key: None,
tokens: Some(disk_tokens),
last_refresh: Some(initial_last_refresh),
};
save_auth(
ctx.codex_home.path(),
&disk_auth,
AuthCredentialsStoreMode::File,
)?;
let cached_before = ctx
.auth_manager
.auth_cached()
.expect("auth should be cached");
let cached_before_tokens = cached_before
.get_token_data()
.context("token data should be cached")?;
assert_eq!(cached_before_tokens, initial_tokens);
let mut recovery = ctx.auth_manager.unauthorized_recovery();
assert!(recovery.has_next());
recovery.next().await?;
let stored = ctx.load_auth()?;
let tokens = stored.tokens.as_ref().context("tokens should exist")?;
assert_eq!(tokens, &expected_tokens);
let requests = server.received_requests().await.unwrap_or_default();
let request = requests
.first()
.context("expected a refresh token request")?;
let body: Value =
serde_json::from_slice(&request.body).context("refresh request body should be json")?;
let refresh_token = body
.get("refresh_token")
.and_then(Value::as_str)
.context("refresh_token should be set")?;
assert_eq!(refresh_token, INITIAL_REFRESH_TOKEN);
let cached_after = ctx
.auth_manager
.auth()
.await
.context("auth should remain cached after refresh")?;
let cached_after_tokens = cached_after
.get_token_data()
.context("token data should reflect refreshed tokens")?;
assert_eq!(cached_after_tokens, expected_tokens);
assert!(!recovery.has_next());
server.verify().await;
Ok(())
}
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn unauthorized_recovery_requires_chatgpt_auth() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
let ctx = RefreshTokenTestContext::new(&server)?;
let auth = AuthDotJson {
openai_api_key: Some("sk-test".to_string()),
tokens: None,
last_refresh: None,
};
ctx.write_auth(&auth)?;
let mut recovery = ctx.auth_manager.unauthorized_recovery();
assert!(!recovery.has_next());
let err = recovery
.next()
.await
.err()
.context("recovery should fail")?;
assert_eq!(err.failed_reason(), Some(RefreshTokenFailedReason::Other));
let requests = server.received_requests().await.unwrap_or_default();
assert!(requests.is_empty(), "expected no refresh token requests");
Ok(())
}
struct RefreshTokenTestContext {
codex_home: TempDir,
auth_manager: AuthManager,
auth_manager: Arc<AuthManager>,
_env_guard: EnvGuard,
}
@@ -307,7 +516,7 @@ impl RefreshTokenTestContext {
let endpoint = format!("{}/oauth/token", server.uri());
let env_guard = EnvGuard::set(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, endpoint);
let auth_manager = AuthManager::new(
let auth_manager = AuthManager::shared(
codex_home.path().to_path_buf(),
false,
AuthCredentialsStoreMode::File,