Compare commits

...

2 Commits

2 changed files with 418 additions and 18 deletions

View File

@@ -7,6 +7,10 @@ use serde::Serialize;
use serial_test::serial;
use std::env;
use std::fmt::Debug;
use std::fs::File;
use std::fs::OpenOptions;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
@@ -84,6 +88,9 @@ struct ChatgptAuthState {
}
const TOKEN_REFRESH_INTERVAL: i64 = 8;
const CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES: i64 = 5;
const CHATGPT_TOKEN_REFRESH_LOCK_FILENAME: &str = "chatgpt-token-refresh.lock";
const CHATGPT_TOKEN_REFRESH_LOCK_POLL_INTERVAL_MS: u64 = 50;
const REFRESH_TOKEN_EXPIRED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token has expired. Please log out and sign in again.";
const REFRESH_TOKEN_REUSED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token was already used. Please log out and sign in again.";
@@ -1421,16 +1428,16 @@ impl AuthManager {
}
/// Current cached auth (clone). May be `None` if not logged in or load failed.
/// For stale managed ChatGPT auth, first performs a guarded reload and then
/// refreshes only if the on-disk auth is unchanged.
/// For managed ChatGPT auth that needs a proactive refresh, first performs
/// a guarded reload and then refreshes only if the on-disk auth is unchanged.
pub async fn auth(&self) -> Option<CodexAuth> {
if let Some(auth) = self.resolve_external_api_key_auth().await {
return Some(auth);
}
let auth = self.auth_cached()?;
if Self::is_stale_for_proactive_refresh(&auth)
&& let Err(err) = self.refresh_token().await
if Self::should_refresh_proactively(&auth)
&& let Err(err) = self.proactively_refresh_token().await
{
tracing::error!("Failed to refresh token: {}", err);
return Some(auth);
@@ -1680,12 +1687,37 @@ impl AuthManager {
/// can assume that some other instance already refreshed it. If the persisted
/// token is the same as the cached, then ask the token authority to refresh.
pub async fn refresh_token(&self) -> Result<(), RefreshTokenError> {
let _refresh_guard = self.refresh_lock.acquire().await.map_err(|_| {
let auth_before_wait = self.auth_cached();
let is_managed_chatgpt = matches!(auth_before_wait.as_ref(), Some(CodexAuth::Chatgpt(_)));
let _token_refresh_lock = if is_managed_chatgpt {
Some(self.acquire_chatgpt_token_refresh_lock().await?)
} else {
None
};
let _refresh_guard = self.acquire_refresh_guard().await?;
if is_managed_chatgpt
&& !Self::auths_equal_for_refresh(
auth_before_wait.as_ref(),
self.auth_cached().as_ref(),
)
{
return Ok(());
}
self.refresh_token_after_guarded_reload().await
}
async fn acquire_refresh_guard(
&self,
) -> Result<tokio::sync::SemaphorePermit<'_>, RefreshTokenError> {
self.refresh_lock.acquire().await.map_err(|_| {
RefreshTokenError::Permanent(RefreshTokenFailedError::new(
RefreshTokenFailedReason::Other,
REFRESH_TOKEN_UNKNOWN_MESSAGE.to_string(),
))
})?;
})
}
async fn refresh_token_after_guarded_reload(&self) -> Result<(), RefreshTokenError> {
let auth_before_reload = self.auth_cached();
if auth_before_reload
.as_ref()
@@ -1715,17 +1747,71 @@ impl AuthManager {
}
}
/// Attempt to refresh the current auth token from the authority that issued
/// the token. On success, reloads the auth state from disk so other components
/// observe refreshed token. If the token refresh fails, returns the error to
/// the caller.
pub async fn refresh_token_from_authority(&self) -> Result<(), RefreshTokenError> {
let _refresh_guard = self.refresh_lock.acquire().await.map_err(|_| {
RefreshTokenError::Permanent(RefreshTokenFailedError::new(
RefreshTokenFailedReason::Other,
REFRESH_TOKEN_UNKNOWN_MESSAGE.to_string(),
async fn proactively_refresh_token(&self) -> Result<(), RefreshTokenError> {
let _refresh_lock = self.acquire_chatgpt_token_refresh_lock().await?;
let _refresh_guard = self.acquire_refresh_guard().await?;
if !self
.auth_cached()
.as_ref()
.is_some_and(Self::should_refresh_proactively)
{
return Ok(());
}
self.refresh_token_after_guarded_reload().await
}
async fn acquire_chatgpt_token_refresh_lock(&self) -> Result<File, RefreshTokenError> {
let mut logged_wait = false;
loop {
if let Some(lock_file) = self.try_acquire_chatgpt_token_refresh_lock()? {
return Ok(lock_file);
}
if !logged_wait {
tracing::info!(
"Waiting to refresh managed ChatGPT tokens because another process is already refreshing them."
);
logged_wait = true;
}
tokio::time::sleep(std::time::Duration::from_millis(
CHATGPT_TOKEN_REFRESH_LOCK_POLL_INTERVAL_MS,
))
})?;
.await;
}
}
fn try_acquire_chatgpt_token_refresh_lock(&self) -> Result<Option<File>, RefreshTokenError> {
let lock_path = self.codex_home.join(CHATGPT_TOKEN_REFRESH_LOCK_FILENAME);
if let Some(parent) = lock_path.parent() {
std::fs::create_dir_all(parent).map_err(RefreshTokenError::Transient)?;
}
let mut options = OpenOptions::new();
options.read(true).write(true).create(true).truncate(false);
#[cfg(unix)]
{
options.mode(0o600);
}
let lock_file = options
.open(lock_path)
.map_err(RefreshTokenError::Transient)?;
match lock_file.try_lock() {
Ok(()) => Ok(Some(lock_file)),
Err(std::fs::TryLockError::WouldBlock) => Ok(None),
Err(err) => Err(RefreshTokenError::Transient(err.into())),
}
}
/// Attempt to refresh the current auth token from the authority that issued
/// the token. Managed ChatGPT auth reuses tokens refreshed by another process
/// while waiting to serialize refresh-token rotation. On success, reloads the
/// auth state from disk so other components observe refreshed token. If the
/// token refresh fails, returns the error to the caller.
pub async fn refresh_token_from_authority(&self) -> Result<(), RefreshTokenError> {
if matches!(self.auth_cached(), Some(CodexAuth::Chatgpt(_))) {
return self.refresh_token().await;
}
let _refresh_guard = self.acquire_refresh_guard().await?;
self.refresh_token_from_authority_impl().await
}
@@ -1808,7 +1894,7 @@ impl AuthManager {
)
}
fn is_stale_for_proactive_refresh(auth: &CodexAuth) -> bool {
fn should_refresh_proactively(auth: &CodexAuth) -> bool {
let chatgpt_auth = match auth {
CodexAuth::Chatgpt(chatgpt_auth) => chatgpt_auth,
_ => return false,
@@ -1821,7 +1907,9 @@ impl AuthManager {
if let Some(tokens) = auth_dot_json.tokens.as_ref()
&& let Ok(Some(expires_at)) = parse_jwt_expiration(&tokens.access_token)
{
return expires_at <= Utc::now();
return expires_at
<= Utc::now()
+ chrono::Duration::minutes(CHATGPT_ACCESS_TOKEN_REFRESH_WINDOW_MINUTES);
}
let last_refresh = match auth_dot_json.last_refresh {
Some(last_refresh) => last_refresh,

View File

@@ -19,6 +19,7 @@ use pretty_assertions::assert_eq;
use serde::Serialize;
use serde_json::json;
use std::ffi::OsString;
use std::fs::File;
use std::sync::Arc;
use tempfile::TempDir;
use wiremock::Mock;
@@ -158,6 +159,305 @@ async fn refresh_token_refreshes_when_auth_is_unchanged() -> Result<()> {
Ok(())
}
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn auth_refreshes_when_access_token_is_near_expiry() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/oauth/token"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"access_token": "new-access-token",
"refresh_token": "new-refresh-token"
})))
.expect(1)
.mount(&server)
.await;
let ctx = RefreshTokenTestContext::new(&server).await?;
let initial_last_refresh = Utc::now();
let near_expiry_access_token = access_token_with_expiration(Utc::now() + Duration::minutes(4));
let initial_tokens = build_tokens(&near_expiry_access_token, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(initial_last_refresh),
agent_identity: None,
};
ctx.write_auth(&initial_auth).await?;
let cached_auth = ctx
.auth_manager
.auth()
.await
.context("auth should be cached")?;
let refreshed_tokens = TokenData {
access_token: "new-access-token".to_string(),
refresh_token: "new-refresh-token".to_string(),
..initial_tokens.clone()
};
let cached = cached_auth
.get_token_data()
.context("token data should refresh")?;
assert_eq!(cached, refreshed_tokens);
let stored = ctx.load_auth()?;
let tokens = stored.tokens.as_ref().context("tokens should exist")?;
assert_eq!(tokens, &refreshed_tokens);
let refreshed_at = stored
.last_refresh
.as_ref()
.context("last_refresh should be recorded")?;
assert!(
*refreshed_at >= initial_last_refresh,
"last_refresh should advance"
);
server.verify().await;
Ok(())
}
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn auth_skips_access_token_outside_refresh_window() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
let ctx = RefreshTokenTestContext::new(&server).await?;
let initial_last_refresh = Utc::now();
let fresh_access_token = access_token_with_expiration(Utc::now() + Duration::minutes(6));
let initial_tokens = build_tokens(&fresh_access_token, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(initial_last_refresh),
agent_identity: None,
};
ctx.write_auth(&initial_auth).await?;
let cached_auth = ctx
.auth_manager
.auth()
.await
.context("auth should be cached")?;
let cached = cached_auth
.get_token_data()
.context("token data should remain cached")?;
assert_eq!(cached, initial_tokens);
assert_eq!(ctx.load_auth()?, initial_auth);
let requests = server.received_requests().await.unwrap_or_default();
assert!(requests.is_empty(), "expected no refresh token requests");
Ok(())
}
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn auth_waits_while_chatgpt_token_refresh_lock_is_held() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/oauth/token"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"access_token": "new-access-token",
"refresh_token": "new-refresh-token"
})))
.expect(1)
.mount(&server)
.await;
let ctx = RefreshTokenTestContext::new(&server).await?;
let initial_last_refresh = Utc::now();
let expired_access_token = access_token_with_expiration(Utc::now() - Duration::minutes(1));
let initial_tokens = build_tokens(&expired_access_token, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(initial_last_refresh),
agent_identity: None,
};
ctx.write_auth(&initial_auth).await?;
let lock_file = ctx.hold_chatgpt_token_refresh_lock()?;
let auth_manager = Arc::clone(&ctx.auth_manager);
let refresh_task = tokio::spawn(async move { auth_manager.auth().await });
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(
!refresh_task.is_finished(),
"auth should wait while another process holds the proactive refresh lock"
);
assert_eq!(ctx.load_auth()?, initial_auth);
let requests = server.received_requests().await.unwrap_or_default();
assert!(
requests.is_empty(),
"expected no refresh token requests before the proactive refresh lock is released"
);
drop(lock_file);
let cached_auth = refresh_task
.await
.context("proactive refresh task should join")?
.context("auth should stay cached after the lock is released")?;
let cached = cached_auth
.get_token_data()
.context("token data should refresh")?;
assert_eq!(cached.access_token, "new-access-token");
assert_eq!(cached.refresh_token, "new-refresh-token");
let stored = ctx.load_auth()?;
let tokens = stored.tokens.as_ref().context("tokens should exist")?;
assert_eq!(tokens.access_token, "new-access-token");
assert_eq!(tokens.refresh_token, "new-refresh-token");
server.verify().await;
Ok(())
}
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn concurrent_auth_requests_share_one_proactive_refresh() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/oauth/token"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"access_token": "new-access-token",
"refresh_token": "new-refresh-token"
})))
.expect(1)
.mount(&server)
.await;
let ctx = RefreshTokenTestContext::new(&server).await?;
let near_expiry_access_token = access_token_with_expiration(Utc::now() + Duration::minutes(4));
let initial_tokens = build_tokens(&near_expiry_access_token, INITIAL_REFRESH_TOKEN);
ctx.write_auth(&AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(Utc::now()),
agent_identity: None,
})
.await?;
let lock_file = ctx.hold_chatgpt_token_refresh_lock()?;
let first_manager = Arc::clone(&ctx.auth_manager);
let first = tokio::spawn(async move { first_manager.auth().await });
let second_manager = Arc::clone(&ctx.auth_manager);
let second = tokio::spawn(async move { second_manager.auth().await });
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!first.is_finished());
assert!(!second.is_finished());
drop(lock_file);
let (first_auth, second_auth) = tokio::join!(first, second);
let refreshed_tokens = TokenData {
access_token: "new-access-token".to_string(),
refresh_token: "new-refresh-token".to_string(),
..initial_tokens
};
for auth in [first_auth, second_auth] {
let auth = auth
.context("proactive refresh task should join")?
.context("auth should stay cached after the lock is released")?;
assert_eq!(
auth.get_token_data().context("token data should refresh")?,
refreshed_tokens
);
}
assert_eq!(
ctx.load_auth()?.tokens.context("tokens should exist")?,
refreshed_tokens
);
server.verify().await;
Ok(())
}
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn refresh_token_reloads_managed_auth_after_waiting_for_token_refresh_lock() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/oauth/token"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"access_token": "unexpected-access-token",
"refresh_token": "unexpected-refresh-token"
})))
.expect(0)
.mount(&server)
.await;
let ctx = RefreshTokenTestContext::new(&server).await?;
let initial_last_refresh = Utc::now();
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
ctx.write_auth(&AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(initial_tokens),
last_refresh: Some(initial_last_refresh),
agent_identity: None,
})
.await?;
let lock_file = ctx.hold_chatgpt_token_refresh_lock()?;
let auth_manager = Arc::clone(&ctx.auth_manager);
let refresh_task =
tokio::spawn(async move { auth_manager.refresh_token_from_authority().await });
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(
!refresh_task.is_finished(),
"managed refresh should wait while another process holds the token refresh lock"
);
let disk_tokens = build_tokens("disk-access-token", "disk-refresh-token");
let disk_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(disk_tokens.clone()),
last_refresh: Some(initial_last_refresh),
agent_identity: None,
};
save_auth(
ctx.codex_home.path(),
&disk_auth,
AuthCredentialsStoreMode::File,
)?;
drop(lock_file);
refresh_task
.await
.context("managed refresh task should join")?
.context("managed refresh should use newly persisted auth")?;
let stored = ctx.load_auth()?;
assert_eq!(stored, disk_auth);
let cached = ctx
.auth_manager
.auth_cached()
.context("auth should be cached")?
.get_token_data()
.context("token data should reload")?;
assert_eq!(cached, disk_tokens);
server.verify().await;
Ok(())
}
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn refresh_token_skips_refresh_when_auth_changed() -> Result<()> {
@@ -1010,6 +1310,18 @@ impl RefreshTokenTestContext {
self.auth_manager.reload().await;
Ok(())
}
fn hold_chatgpt_token_refresh_lock(&self) -> Result<File> {
let lock_path = self.codex_home.path().join("chatgpt-token-refresh.lock");
let lock_file = File::options()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(lock_path)?;
lock_file.try_lock()?;
Ok(lock_file)
}
}
struct EnvGuard {