Address OAuth persistence review feedback

Keep local OAuth login flows on the no-proxy client while remote flows continue to use the runtime-selected HTTP client. Harden OAuth persistence so expires_in drift does not look like an external token update and transient keyring read failures do not clear in-memory credentials.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-04-27 11:58:06 +00:00
parent 35080d5cd0
commit f8e2beeafa
4 changed files with 392 additions and 102 deletions

View File

@@ -14,7 +14,10 @@ use codex_rmcp_client::determine_streamable_http_auth_status;
use codex_rmcp_client::determine_streamable_http_auth_status_with_client;
use codex_rmcp_client::discover_streamable_http_oauth;
use codex_rmcp_client::discover_streamable_http_oauth_with_client;
use codex_rmcp_client::perform_oauth_login;
use codex_rmcp_client::perform_oauth_login_return_url;
use codex_rmcp_client::perform_oauth_login_return_url_with_client;
use codex_rmcp_client::perform_oauth_login_silent;
use codex_rmcp_client::perform_oauth_login_silent_with_client;
use codex_rmcp_client::perform_oauth_login_with_client;
use futures::future::join_all;
@@ -66,6 +69,33 @@ pub enum McpOAuthLoginOutcome {
Unsupported,
}
#[derive(Clone)]
enum McpOAuthHttpClient {
LocalNoProxy,
Runtime(Arc<dyn HttpClient>),
}
impl McpOAuthHttpClient {
async fn login_support(&self, transport: &McpServerTransportConfig) -> McpOAuthLoginSupport {
match self {
Self::LocalNoProxy => oauth_login_support(transport).await,
Self::Runtime(http_client) => {
oauth_login_support_with_client(transport, http_client.clone()).await
}
}
}
async fn discover_supported_scopes(
&self,
transport: &McpServerTransportConfig,
) -> Option<Vec<String>> {
match self.login_support(transport).await {
McpOAuthLoginSupport::Supported(config) => config.discovered_scopes,
McpOAuthLoginSupport::Unsupported | McpOAuthLoginSupport::Unknown(_) => None,
}
}
}
pub async fn oauth_login_support(transport: &McpServerTransportConfig) -> McpOAuthLoginSupport {
oauth_login_support_with_discovery(transport).await
}
@@ -175,16 +205,6 @@ pub async fn discover_supported_scopes_for_server(
}
}
async fn discover_supported_scopes_with_client(
transport: &McpServerTransportConfig,
http_client: Arc<dyn HttpClient>,
) -> Option<Vec<String>> {
match oauth_login_support_with_client(transport, http_client).await {
McpOAuthLoginSupport::Supported(config) => config.discovered_scopes,
McpOAuthLoginSupport::Unsupported | McpOAuthLoginSupport::Unknown(_) => None,
}
}
pub fn resolve_oauth_scopes(
explicit_scopes: Option<Vec<String>>,
configured_scopes: Option<Vec<String>>,
@@ -245,29 +265,50 @@ pub async fn perform_oauth_login_return_url_for_server(
anyhow::bail!("OAuth login is only supported for streamable HTTP servers.");
};
let http_client = http_client_for_server(config, runtime_environment)?;
let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?;
let discovered_scopes = if explicit_scopes.is_none() && config.scopes.is_none() {
discover_supported_scopes_with_client(&config.transport, http_client.clone()).await
oauth_http_client
.discover_supported_scopes(&config.transport)
.await
} else {
None
};
let resolved_scopes =
resolve_oauth_scopes(explicit_scopes, config.scopes.clone(), discovered_scopes);
perform_oauth_login_return_url_with_client(
server_name,
url,
store_mode,
http_headers.clone(),
env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
timeout_secs,
callback_port,
callback_url,
http_client,
)
.await
match oauth_http_client {
McpOAuthHttpClient::LocalNoProxy => {
perform_oauth_login_return_url(
server_name,
url,
store_mode,
http_headers.clone(),
env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
timeout_secs,
callback_port,
callback_url,
)
.await
}
McpOAuthHttpClient::Runtime(http_client) => {
perform_oauth_login_return_url_with_client(
server_name,
url,
store_mode,
http_headers.clone(),
env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
timeout_secs,
callback_port,
callback_url,
http_client,
)
.await
}
}
}
#[allow(clippy::too_many_arguments)]
@@ -280,13 +321,12 @@ pub async fn perform_oauth_login_silent_for_server(
callback_url: Option<&str>,
runtime_environment: McpRuntimeEnvironment,
) -> Result<McpOAuthLoginOutcome> {
let http_client = http_client_for_server(config, runtime_environment)?;
let oauth_config =
match oauth_login_support_with_client(&config.transport, http_client.clone()).await {
McpOAuthLoginSupport::Supported(config) => config,
McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported),
McpOAuthLoginSupport::Unknown(err) => return Err(err),
};
let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?;
let oauth_config = match oauth_http_client.login_support(&config.transport).await {
McpOAuthLoginSupport::Supported(config) => config,
McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported),
McpOAuthLoginSupport::Unknown(err) => return Err(err),
};
let resolved_scopes = resolve_oauth_scopes(
explicit_scopes,
@@ -294,37 +334,71 @@ pub async fn perform_oauth_login_silent_for_server(
oauth_config.discovered_scopes.clone(),
);
let first_attempt = perform_oauth_login_silent_with_client(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client.clone(),
)
.await;
let final_result = match first_attempt {
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
perform_oauth_login_silent_with_client(
let final_result = match oauth_http_client {
McpOAuthHttpClient::LocalNoProxy => {
let first_attempt = perform_oauth_login_silent(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client,
)
.await
.await;
match first_attempt {
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
perform_oauth_login_silent(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
config.oauth_resource.as_deref(),
callback_port,
callback_url,
)
.await
}
result => result,
}
}
McpOAuthHttpClient::Runtime(http_client) => {
let first_attempt = perform_oauth_login_silent_with_client(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client.clone(),
)
.await;
match first_attempt {
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
perform_oauth_login_silent_with_client(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client,
)
.await
}
result => result,
}
}
result => result,
};
final_result.map(|()| McpOAuthLoginOutcome::Completed)
@@ -340,13 +414,12 @@ pub async fn perform_oauth_login_for_server(
callback_url: Option<&str>,
runtime_environment: McpRuntimeEnvironment,
) -> Result<McpOAuthLoginOutcome> {
let http_client = http_client_for_server(config, runtime_environment)?;
let oauth_config =
match oauth_login_support_with_client(&config.transport, http_client.clone()).await {
McpOAuthLoginSupport::Supported(config) => config,
McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported),
McpOAuthLoginSupport::Unknown(err) => return Err(err),
};
let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?;
let oauth_config = match oauth_http_client.login_support(&config.transport).await {
McpOAuthLoginSupport::Supported(config) => config,
McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported),
McpOAuthLoginSupport::Unknown(err) => return Err(err),
};
let resolved_scopes = resolve_oauth_scopes(
explicit_scopes,
@@ -354,37 +427,71 @@ pub async fn perform_oauth_login_for_server(
oauth_config.discovered_scopes.clone(),
);
let first_attempt = perform_oauth_login_with_client(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client.clone(),
)
.await;
let final_result = match first_attempt {
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
perform_oauth_login_with_client(
let final_result = match oauth_http_client {
McpOAuthHttpClient::LocalNoProxy => {
let first_attempt = perform_oauth_login(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client,
)
.await
.await;
match first_attempt {
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
perform_oauth_login(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
config.oauth_resource.as_deref(),
callback_port,
callback_url,
)
.await
}
result => result,
}
}
McpOAuthHttpClient::Runtime(http_client) => {
let first_attempt = perform_oauth_login_with_client(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client.clone(),
)
.await;
match first_attempt {
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
perform_oauth_login_with_client(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client,
)
.await
}
result => result,
}
}
result => result,
};
final_result.map(|()| McpOAuthLoginOutcome::Completed)
@@ -510,6 +617,20 @@ pub fn http_client_for_server(
}
}
fn oauth_http_client_for_server(
config: &McpServerConfig,
runtime_environment: McpRuntimeEnvironment,
) -> Result<McpOAuthHttpClient> {
match config.experimental_environment.as_deref() {
None | Some("local") => Ok(McpOAuthHttpClient::LocalNoProxy),
Some("remote") => Ok(McpOAuthHttpClient::Runtime(http_client_for_server(
config,
runtime_environment,
)?)),
Some(environment) => anyhow::bail!("unsupported experimental_environment `{environment}`"),
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;

View File

@@ -40,7 +40,7 @@ pub async fn determine_streamable_http_auth_status(
http_headers,
env_http_headers,
store_mode,
Arc::new(NoProxyReqwestHttpClient),
no_proxy_http_client(),
)
.await
}
@@ -101,11 +101,15 @@ pub async fn discover_streamable_http_oauth(
url,
http_headers,
env_http_headers,
Arc::new(NoProxyReqwestHttpClient),
no_proxy_http_client(),
)
.await
}
pub(crate) fn no_proxy_http_client() -> Arc<dyn HttpClient> {
Arc::new(NoProxyReqwestHttpClient)
}
#[derive(Debug, Clone, Default)]
struct NoProxyReqwestHttpClient;

View File

@@ -95,6 +95,43 @@ pub(crate) fn load_oauth_tokens(
}
}
enum OAuthTokensStorageState {
Found(StoredOAuthTokens),
Missing,
Unavailable,
}
fn load_oauth_tokens_for_persist(
server_name: &str,
url: &str,
store_mode: OAuthCredentialsStoreMode,
) -> Result<OAuthTokensStorageState> {
let keyring_store = DefaultKeyringStore;
match store_mode {
OAuthCredentialsStoreMode::Auto => {
load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file(
&keyring_store,
server_name,
url,
)
}
OAuthCredentialsStoreMode::File => {
Ok(match load_oauth_tokens_from_file(server_name, url)? {
Some(tokens) => OAuthTokensStorageState::Found(tokens),
None => OAuthTokensStorageState::Missing,
})
}
OAuthCredentialsStoreMode::Keyring => Ok(
match load_oauth_tokens_from_keyring(&keyring_store, server_name, url)
.with_context(|| "failed to read OAuth tokens from keyring".to_string())?
{
Some(tokens) => OAuthTokensStorageState::Found(tokens),
None => OAuthTokensStorageState::Missing,
},
),
}
}
pub(crate) fn has_oauth_tokens(
server_name: &str,
url: &str,
@@ -135,6 +172,48 @@ fn load_oauth_tokens_from_keyring_with_fallback_to_file<K: KeyringStore>(
}
}
fn stored_tokens_match_without_expires_in(
actual: &StoredOAuthTokens,
expected: &StoredOAuthTokens,
) -> bool {
let actual_response = &actual.token_response.0;
let expected_response = &expected.token_response.0;
actual.server_name == expected.server_name
&& actual.url == expected.url
&& actual.client_id == expected.client_id
&& actual.expires_at == expected.expires_at
&& actual_response.access_token().secret() == expected_response.access_token().secret()
&& actual_response.token_type() == expected_response.token_type()
&& actual_response.refresh_token().map(RefreshToken::secret)
== expected_response.refresh_token().map(RefreshToken::secret)
&& actual_response.scopes() == expected_response.scopes()
&& actual_response.extra_fields() == expected_response.extra_fields()
}
fn load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file<K: KeyringStore>(
keyring_store: &K,
server_name: &str,
url: &str,
) -> Result<OAuthTokensStorageState> {
match load_oauth_tokens_from_keyring(keyring_store, server_name, url) {
Ok(Some(tokens)) => Ok(OAuthTokensStorageState::Found(tokens)),
Ok(None) => Ok(match load_oauth_tokens_from_file(server_name, url)? {
Some(tokens) => OAuthTokensStorageState::Found(tokens),
None => OAuthTokensStorageState::Missing,
}),
Err(error) => {
warn!("failed to read OAuth tokens from keyring: {error}");
match load_oauth_tokens_from_file(server_name, url)
.with_context(|| format!("failed to read OAuth tokens from keyring: {error}"))?
{
Some(tokens) => Ok(OAuthTokensStorageState::Found(tokens)),
None => Ok(OAuthTokensStorageState::Unavailable),
}
}
}
}
fn load_oauth_tokens_from_keyring<K: KeyringStore>(
keyring_store: &K,
server_name: &str,
@@ -295,23 +374,46 @@ impl OAuthPersistor {
let Some(stored) = credentials_state.current.clone() else {
return Ok(());
};
match load_oauth_tokens(&self.inner.server_name, &stored.url, self.inner.store_mode)? {
None => {
match load_oauth_tokens_for_persist(
&self.inner.server_name,
&stored.url,
self.inner.store_mode,
)? {
OAuthTokensStorageState::Missing => {
credentials_state.current = None;
credentials_state.last_persisted = None;
}
Some(current_store) if current_store == stored => {
credentials_state.last_persisted = Some(current_store);
OAuthTokensStorageState::Unavailable => {
let should_save = match credentials_state.last_persisted.as_ref() {
Some(last_persisted) => {
!stored_tokens_match_without_expires_in(&stored, last_persisted)
}
None => true,
};
if should_save {
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
credentials_state.last_persisted = Some(stored);
}
}
Some(current_store)
if Some(&current_store) != credentials_state.last_persisted.as_ref() =>
OAuthTokensStorageState::Found(current_store)
if stored_tokens_match_without_expires_in(&current_store, &stored) =>
{
credentials_state.current = Some(current_store.clone());
credentials_state.last_persisted = Some(current_store);
}
Some(_) => {
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
credentials_state.last_persisted = Some(stored);
OAuthTokensStorageState::Found(current_store) => {
let store_matches_last_persisted = credentials_state
.last_persisted
.as_ref()
.is_some_and(|last_persisted| {
stored_tokens_match_without_expires_in(&current_store, last_persisted)
});
if store_matches_last_persisted {
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
credentials_state.last_persisted = Some(stored);
} else {
credentials_state.current = Some(current_store.clone());
credentials_state.last_persisted = Some(current_store);
}
}
}
Ok(())
@@ -722,6 +824,24 @@ mod tests {
Ok(())
}
#[test]
fn load_oauth_tokens_for_persist_keeps_keyring_errors_distinct() -> Result<()> {
let _env = TempCodexHome::new();
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
let state = super::load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file(
&store,
&tokens.server_name,
&tokens.url,
)?;
assert!(matches!(state, super::OAuthTokensStorageState::Unavailable));
Ok(())
}
#[test]
fn save_oauth_tokens_prefers_keyring_when_available() -> Result<()> {
let _env = TempCodexHome::new();
@@ -942,6 +1062,51 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn persistor_saves_newer_memory_credentials_despite_expires_in_drift() -> Result<()> {
let _env = TempCodexHome::new();
let mut stored_tokens = sample_tokens();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0));
stored_tokens.expires_at = Some(now.as_millis() as u64 + 10_000);
let expires_in = Duration::from_secs(3600);
stored_tokens
.token_response
.0
.set_expires_in(Some(&expires_in));
super::save_oauth_tokens_to_file(&stored_tokens)?;
let persistor = OAuthPersistor::new(
stored_tokens.server_name.clone(),
OAuthHttpClient::new(
Arc::new(FailingHttpClient),
/*http_headers*/ None,
/*env_http_headers*/ None,
)?,
OAuthCredentialsStoreMode::File,
Some(stored_tokens.clone()),
);
let newer_tokens = sample_tokens_with_access_token("newer-access-token");
{
let mut state = persistor.inner.credentials_state.lock().await;
state.current = Some(newer_tokens.clone());
state.last_persisted = Some(stored_tokens);
}
persistor.persist_if_needed().await?;
let loaded = super::load_oauth_tokens(
&newer_tokens.server_name,
&newer_tokens.url,
OAuthCredentialsStoreMode::File,
)?
.expect("newer tokens should be stored");
assert_tokens_match_without_expiry(&loaded, &newer_tokens);
Ok(())
}
#[tokio::test]
async fn access_token_uses_cached_token_when_refresh_fails_before_expiry() -> Result<()> {
let mut tokens = sample_tokens();

View File

@@ -8,7 +8,6 @@ use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use codex_exec_server::HttpClient;
use codex_exec_server::ReqwestHttpClient;
use reqwest::Url;
use tiny_http::Response;
use tiny_http::Server;
@@ -18,6 +17,7 @@ use urlencoding::decode;
use crate::StoredOAuthTokens;
use crate::WrappedOAuthTokenResponse;
use crate::auth_status::no_proxy_http_client;
use crate::mcp_oauth_http::OAuthAuthorizationSession;
use crate::mcp_oauth_http::OAuthHttpClient;
use crate::oauth::compute_expires_at_millis;
@@ -91,7 +91,7 @@ pub async fn perform_oauth_login(
oauth_resource,
callback_port,
callback_url,
Arc::new(ReqwestHttpClient),
no_proxy_http_client(),
)
.await
}
@@ -147,7 +147,7 @@ pub async fn perform_oauth_login_silent(
oauth_resource,
callback_port,
callback_url,
Arc::new(ReqwestHttpClient),
no_proxy_http_client(),
)
.await
}
@@ -241,7 +241,7 @@ pub async fn perform_oauth_login_return_url(
timeout_secs,
callback_port,
callback_url,
Arc::new(ReqwestHttpClient),
no_proxy_http_client(),
)
.await
}