mirror of
https://github.com/openai/codex.git
synced 2026-05-24 13:04:29 +00:00
[codex-login] avoid redundant proactive token refreshes [ci changed_files]
This commit is contained in:
@@ -1688,12 +1688,22 @@ impl AuthManager {
|
||||
/// can assume that some other instance already refreshed it. If the persisted
|
||||
/// token is the same as the cached, then ask the token authority to refresh.
|
||||
pub async fn refresh_token(&self) -> Result<(), RefreshTokenError> {
|
||||
let _refresh_guard = self.refresh_lock.acquire().await.map_err(|_| {
|
||||
let _refresh_guard = self.acquire_refresh_guard().await?;
|
||||
self.refresh_token_after_guarded_reload().await
|
||||
}
|
||||
|
||||
async fn acquire_refresh_guard(
|
||||
&self,
|
||||
) -> Result<tokio::sync::SemaphorePermit<'_>, RefreshTokenError> {
|
||||
self.refresh_lock.acquire().await.map_err(|_| {
|
||||
RefreshTokenError::Permanent(RefreshTokenFailedError::new(
|
||||
RefreshTokenFailedReason::Other,
|
||||
REFRESH_TOKEN_UNKNOWN_MESSAGE.to_string(),
|
||||
))
|
||||
})?;
|
||||
})
|
||||
}
|
||||
|
||||
async fn refresh_token_after_guarded_reload(&self) -> Result<(), RefreshTokenError> {
|
||||
let auth_before_reload = self.auth_cached();
|
||||
if auth_before_reload
|
||||
.as_ref()
|
||||
@@ -1725,7 +1735,15 @@ impl AuthManager {
|
||||
|
||||
async fn proactively_refresh_token(&self) -> Result<(), RefreshTokenError> {
|
||||
let _refresh_lock = self.acquire_chatgpt_proactive_refresh_lock().await?;
|
||||
self.refresh_token().await
|
||||
let _refresh_guard = self.acquire_refresh_guard().await?;
|
||||
if !self
|
||||
.auth_cached()
|
||||
.as_ref()
|
||||
.is_some_and(Self::should_refresh_proactively)
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
self.refresh_token_after_guarded_reload().await
|
||||
}
|
||||
|
||||
async fn acquire_chatgpt_proactive_refresh_lock(&self) -> Result<File, RefreshTokenError> {
|
||||
@@ -1779,12 +1797,7 @@ impl AuthManager {
|
||||
/// observe refreshed token. If the token refresh fails, returns the error to
|
||||
/// the caller.
|
||||
pub async fn refresh_token_from_authority(&self) -> Result<(), RefreshTokenError> {
|
||||
let _refresh_guard = self.refresh_lock.acquire().await.map_err(|_| {
|
||||
RefreshTokenError::Permanent(RefreshTokenFailedError::new(
|
||||
RefreshTokenFailedReason::Other,
|
||||
REFRESH_TOKEN_UNKNOWN_MESSAGE.to_string(),
|
||||
))
|
||||
})?;
|
||||
let _refresh_guard = self.acquire_refresh_guard().await?;
|
||||
self.refresh_token_from_authority_impl().await
|
||||
}
|
||||
|
||||
|
||||
@@ -321,6 +321,69 @@ async fn auth_waits_while_proactive_refresh_lock_is_held() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[serial_test::serial(auth_refresh)]
|
||||
#[tokio::test]
|
||||
async fn concurrent_auth_requests_share_one_proactive_refresh() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/oauth/token"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"access_token": "new-access-token",
|
||||
"refresh_token": "new-refresh-token"
|
||||
})))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let ctx = RefreshTokenTestContext::new(&server).await?;
|
||||
let near_expiry_access_token = access_token_with_expiration(Utc::now() + Duration::minutes(4));
|
||||
let initial_tokens = build_tokens(&near_expiry_access_token, INITIAL_REFRESH_TOKEN);
|
||||
ctx.write_auth(&AuthDotJson {
|
||||
auth_mode: Some(AuthMode::Chatgpt),
|
||||
openai_api_key: None,
|
||||
tokens: Some(initial_tokens.clone()),
|
||||
last_refresh: Some(Utc::now()),
|
||||
agent_identity: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let lock_file = ctx.hold_proactive_refresh_lock()?;
|
||||
let first_manager = Arc::clone(&ctx.auth_manager);
|
||||
let first = tokio::spawn(async move { first_manager.auth().await });
|
||||
let second_manager = Arc::clone(&ctx.auth_manager);
|
||||
let second = tokio::spawn(async move { second_manager.auth().await });
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!first.is_finished());
|
||||
assert!(!second.is_finished());
|
||||
drop(lock_file);
|
||||
|
||||
let (first_auth, second_auth) = tokio::join!(first, second);
|
||||
let refreshed_tokens = TokenData {
|
||||
access_token: "new-access-token".to_string(),
|
||||
refresh_token: "new-refresh-token".to_string(),
|
||||
..initial_tokens
|
||||
};
|
||||
for auth in [first_auth, second_auth] {
|
||||
let auth = auth
|
||||
.context("proactive refresh task should join")?
|
||||
.context("auth should stay cached after the lock is released")?;
|
||||
assert_eq!(
|
||||
auth.get_token_data().context("token data should refresh")?,
|
||||
refreshed_tokens
|
||||
);
|
||||
}
|
||||
assert_eq!(
|
||||
ctx.load_auth()?.tokens.context("tokens should exist")?,
|
||||
refreshed_tokens
|
||||
);
|
||||
server.verify().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[serial_test::serial(auth_refresh)]
|
||||
#[tokio::test]
|
||||
async fn refresh_token_does_not_wait_while_proactive_refresh_lock_is_held() -> Result<()> {
|
||||
|
||||
Reference in New Issue
Block a user