diff --git a/codex-rs/login/src/auth/manager.rs b/codex-rs/login/src/auth/manager.rs index 0aeaaac70c..e73e0e83b0 100644 --- a/codex-rs/login/src/auth/manager.rs +++ b/codex-rs/login/src/auth/manager.rs @@ -1688,12 +1688,22 @@ impl AuthManager { /// can assume that some other instance already refreshed it. If the persisted /// token is the same as the cached, then ask the token authority to refresh. pub async fn refresh_token(&self) -> Result<(), RefreshTokenError> { - let _refresh_guard = self.refresh_lock.acquire().await.map_err(|_| { + let _refresh_guard = self.acquire_refresh_guard().await?; + self.refresh_token_after_guarded_reload().await + } + + async fn acquire_refresh_guard( + &self, + ) -> Result, RefreshTokenError> { + self.refresh_lock.acquire().await.map_err(|_| { RefreshTokenError::Permanent(RefreshTokenFailedError::new( RefreshTokenFailedReason::Other, REFRESH_TOKEN_UNKNOWN_MESSAGE.to_string(), )) - })?; + }) + } + + async fn refresh_token_after_guarded_reload(&self) -> Result<(), RefreshTokenError> { let auth_before_reload = self.auth_cached(); if auth_before_reload .as_ref() @@ -1725,7 +1735,15 @@ impl AuthManager { async fn proactively_refresh_token(&self) -> Result<(), RefreshTokenError> { let _refresh_lock = self.acquire_chatgpt_proactive_refresh_lock().await?; - self.refresh_token().await + let _refresh_guard = self.acquire_refresh_guard().await?; + if !self + .auth_cached() + .as_ref() + .is_some_and(Self::should_refresh_proactively) + { + return Ok(()); + } + self.refresh_token_after_guarded_reload().await } async fn acquire_chatgpt_proactive_refresh_lock(&self) -> Result { @@ -1779,12 +1797,7 @@ impl AuthManager { /// observe refreshed token. If the token refresh fails, returns the error to /// the caller. pub async fn refresh_token_from_authority(&self) -> Result<(), RefreshTokenError> { - let _refresh_guard = self.refresh_lock.acquire().await.map_err(|_| { - RefreshTokenError::Permanent(RefreshTokenFailedError::new( - RefreshTokenFailedReason::Other, - REFRESH_TOKEN_UNKNOWN_MESSAGE.to_string(), - )) - })?; + let _refresh_guard = self.acquire_refresh_guard().await?; self.refresh_token_from_authority_impl().await } diff --git a/codex-rs/login/tests/suite/auth_refresh.rs b/codex-rs/login/tests/suite/auth_refresh.rs index 0408a2fa77..9f75e13c99 100644 --- a/codex-rs/login/tests/suite/auth_refresh.rs +++ b/codex-rs/login/tests/suite/auth_refresh.rs @@ -321,6 +321,69 @@ async fn auth_waits_while_proactive_refresh_lock_is_held() -> Result<()> { Ok(()) } +#[serial_test::serial(auth_refresh)] +#[tokio::test] +async fn concurrent_auth_requests_share_one_proactive_refresh() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/oauth/token")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "access_token": "new-access-token", + "refresh_token": "new-refresh-token" + }))) + .expect(1) + .mount(&server) + .await; + + let ctx = RefreshTokenTestContext::new(&server).await?; + let near_expiry_access_token = access_token_with_expiration(Utc::now() + Duration::minutes(4)); + let initial_tokens = build_tokens(&near_expiry_access_token, INITIAL_REFRESH_TOKEN); + ctx.write_auth(&AuthDotJson { + auth_mode: Some(AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(initial_tokens.clone()), + last_refresh: Some(Utc::now()), + agent_identity: None, + }) + .await?; + + let lock_file = ctx.hold_proactive_refresh_lock()?; + let first_manager = Arc::clone(&ctx.auth_manager); + let first = tokio::spawn(async move { first_manager.auth().await }); + let second_manager = Arc::clone(&ctx.auth_manager); + let second = tokio::spawn(async move { second_manager.auth().await }); + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + assert!(!first.is_finished()); + assert!(!second.is_finished()); + drop(lock_file); + + let (first_auth, second_auth) = tokio::join!(first, second); + let refreshed_tokens = TokenData { + access_token: "new-access-token".to_string(), + refresh_token: "new-refresh-token".to_string(), + ..initial_tokens + }; + for auth in [first_auth, second_auth] { + let auth = auth + .context("proactive refresh task should join")? + .context("auth should stay cached after the lock is released")?; + assert_eq!( + auth.get_token_data().context("token data should refresh")?, + refreshed_tokens + ); + } + assert_eq!( + ctx.load_auth()?.tokens.context("tokens should exist")?, + refreshed_tokens + ); + server.verify().await; + + Ok(()) +} + #[serial_test::serial(auth_refresh)] #[tokio::test] async fn refresh_token_does_not_wait_while_proactive_refresh_lock_is_held() -> Result<()> {