From f8e2beeafae92c91e5630efd4a79dec03fcc0606 Mon Sep 17 00:00:00 2001 From: Ahmed Ibrahim Date: Mon, 27 Apr 2026 11:58:06 +0000 Subject: [PATCH] Address OAuth persistence review feedback Keep local OAuth login flows on the no-proxy client while remote flows continue to use the runtime-selected HTTP client. Harden OAuth persistence so expires_in drift does not look like an external token update and transient keyring read failures do not clear in-memory credentials. Co-authored-by: Codex --- codex-rs/codex-mcp/src/mcp/auth.rs | 293 +++++++++++++----- codex-rs/rmcp-client/src/auth_status.rs | 8 +- codex-rs/rmcp-client/src/oauth.rs | 185 ++++++++++- .../rmcp-client/src/perform_oauth_login.rs | 8 +- 4 files changed, 392 insertions(+), 102 deletions(-) diff --git a/codex-rs/codex-mcp/src/mcp/auth.rs b/codex-rs/codex-mcp/src/mcp/auth.rs index bcd9e77d30..f2c7eeb73a 100644 --- a/codex-rs/codex-mcp/src/mcp/auth.rs +++ b/codex-rs/codex-mcp/src/mcp/auth.rs @@ -14,7 +14,10 @@ use codex_rmcp_client::determine_streamable_http_auth_status; use codex_rmcp_client::determine_streamable_http_auth_status_with_client; use codex_rmcp_client::discover_streamable_http_oauth; use codex_rmcp_client::discover_streamable_http_oauth_with_client; +use codex_rmcp_client::perform_oauth_login; +use codex_rmcp_client::perform_oauth_login_return_url; use codex_rmcp_client::perform_oauth_login_return_url_with_client; +use codex_rmcp_client::perform_oauth_login_silent; use codex_rmcp_client::perform_oauth_login_silent_with_client; use codex_rmcp_client::perform_oauth_login_with_client; use futures::future::join_all; @@ -66,6 +69,33 @@ pub enum McpOAuthLoginOutcome { Unsupported, } +#[derive(Clone)] +enum McpOAuthHttpClient { + LocalNoProxy, + Runtime(Arc), +} + +impl McpOAuthHttpClient { + async fn login_support(&self, transport: &McpServerTransportConfig) -> McpOAuthLoginSupport { + match self { + Self::LocalNoProxy => oauth_login_support(transport).await, + Self::Runtime(http_client) => { + oauth_login_support_with_client(transport, http_client.clone()).await + } + } + } + + async fn discover_supported_scopes( + &self, + transport: &McpServerTransportConfig, + ) -> Option> { + match self.login_support(transport).await { + McpOAuthLoginSupport::Supported(config) => config.discovered_scopes, + McpOAuthLoginSupport::Unsupported | McpOAuthLoginSupport::Unknown(_) => None, + } + } +} + pub async fn oauth_login_support(transport: &McpServerTransportConfig) -> McpOAuthLoginSupport { oauth_login_support_with_discovery(transport).await } @@ -175,16 +205,6 @@ pub async fn discover_supported_scopes_for_server( } } -async fn discover_supported_scopes_with_client( - transport: &McpServerTransportConfig, - http_client: Arc, -) -> Option> { - match oauth_login_support_with_client(transport, http_client).await { - McpOAuthLoginSupport::Supported(config) => config.discovered_scopes, - McpOAuthLoginSupport::Unsupported | McpOAuthLoginSupport::Unknown(_) => None, - } -} - pub fn resolve_oauth_scopes( explicit_scopes: Option>, configured_scopes: Option>, @@ -245,29 +265,50 @@ pub async fn perform_oauth_login_return_url_for_server( anyhow::bail!("OAuth login is only supported for streamable HTTP servers."); }; - let http_client = http_client_for_server(config, runtime_environment)?; + let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?; let discovered_scopes = if explicit_scopes.is_none() && config.scopes.is_none() { - discover_supported_scopes_with_client(&config.transport, http_client.clone()).await + oauth_http_client + .discover_supported_scopes(&config.transport) + .await } else { None }; let resolved_scopes = resolve_oauth_scopes(explicit_scopes, config.scopes.clone(), discovered_scopes); - perform_oauth_login_return_url_with_client( - server_name, - url, - store_mode, - http_headers.clone(), - env_http_headers.clone(), - &resolved_scopes.scopes, - config.oauth_resource.as_deref(), - timeout_secs, - callback_port, - callback_url, - http_client, - ) - .await + match oauth_http_client { + McpOAuthHttpClient::LocalNoProxy => { + perform_oauth_login_return_url( + server_name, + url, + store_mode, + http_headers.clone(), + env_http_headers.clone(), + &resolved_scopes.scopes, + config.oauth_resource.as_deref(), + timeout_secs, + callback_port, + callback_url, + ) + .await + } + McpOAuthHttpClient::Runtime(http_client) => { + perform_oauth_login_return_url_with_client( + server_name, + url, + store_mode, + http_headers.clone(), + env_http_headers.clone(), + &resolved_scopes.scopes, + config.oauth_resource.as_deref(), + timeout_secs, + callback_port, + callback_url, + http_client, + ) + .await + } + } } #[allow(clippy::too_many_arguments)] @@ -280,13 +321,12 @@ pub async fn perform_oauth_login_silent_for_server( callback_url: Option<&str>, runtime_environment: McpRuntimeEnvironment, ) -> Result { - let http_client = http_client_for_server(config, runtime_environment)?; - let oauth_config = - match oauth_login_support_with_client(&config.transport, http_client.clone()).await { - McpOAuthLoginSupport::Supported(config) => config, - McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported), - McpOAuthLoginSupport::Unknown(err) => return Err(err), - }; + let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?; + let oauth_config = match oauth_http_client.login_support(&config.transport).await { + McpOAuthLoginSupport::Supported(config) => config, + McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported), + McpOAuthLoginSupport::Unknown(err) => return Err(err), + }; let resolved_scopes = resolve_oauth_scopes( explicit_scopes, @@ -294,37 +334,71 @@ pub async fn perform_oauth_login_silent_for_server( oauth_config.discovered_scopes.clone(), ); - let first_attempt = perform_oauth_login_silent_with_client( - server_name, - &oauth_config.url, - store_mode, - oauth_config.http_headers.clone(), - oauth_config.env_http_headers.clone(), - &resolved_scopes.scopes, - config.oauth_resource.as_deref(), - callback_port, - callback_url, - http_client.clone(), - ) - .await; - - let final_result = match first_attempt { - Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => { - perform_oauth_login_silent_with_client( + let final_result = match oauth_http_client { + McpOAuthHttpClient::LocalNoProxy => { + let first_attempt = perform_oauth_login_silent( server_name, &oauth_config.url, store_mode, - oauth_config.http_headers, - oauth_config.env_http_headers, - &[], + oauth_config.http_headers.clone(), + oauth_config.env_http_headers.clone(), + &resolved_scopes.scopes, config.oauth_resource.as_deref(), callback_port, callback_url, - http_client, ) - .await + .await; + match first_attempt { + Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => { + perform_oauth_login_silent( + server_name, + &oauth_config.url, + store_mode, + oauth_config.http_headers, + oauth_config.env_http_headers, + &[], + config.oauth_resource.as_deref(), + callback_port, + callback_url, + ) + .await + } + result => result, + } + } + McpOAuthHttpClient::Runtime(http_client) => { + let first_attempt = perform_oauth_login_silent_with_client( + server_name, + &oauth_config.url, + store_mode, + oauth_config.http_headers.clone(), + oauth_config.env_http_headers.clone(), + &resolved_scopes.scopes, + config.oauth_resource.as_deref(), + callback_port, + callback_url, + http_client.clone(), + ) + .await; + match first_attempt { + Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => { + perform_oauth_login_silent_with_client( + server_name, + &oauth_config.url, + store_mode, + oauth_config.http_headers, + oauth_config.env_http_headers, + &[], + config.oauth_resource.as_deref(), + callback_port, + callback_url, + http_client, + ) + .await + } + result => result, + } } - result => result, }; final_result.map(|()| McpOAuthLoginOutcome::Completed) @@ -340,13 +414,12 @@ pub async fn perform_oauth_login_for_server( callback_url: Option<&str>, runtime_environment: McpRuntimeEnvironment, ) -> Result { - let http_client = http_client_for_server(config, runtime_environment)?; - let oauth_config = - match oauth_login_support_with_client(&config.transport, http_client.clone()).await { - McpOAuthLoginSupport::Supported(config) => config, - McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported), - McpOAuthLoginSupport::Unknown(err) => return Err(err), - }; + let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?; + let oauth_config = match oauth_http_client.login_support(&config.transport).await { + McpOAuthLoginSupport::Supported(config) => config, + McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported), + McpOAuthLoginSupport::Unknown(err) => return Err(err), + }; let resolved_scopes = resolve_oauth_scopes( explicit_scopes, @@ -354,37 +427,71 @@ pub async fn perform_oauth_login_for_server( oauth_config.discovered_scopes.clone(), ); - let first_attempt = perform_oauth_login_with_client( - server_name, - &oauth_config.url, - store_mode, - oauth_config.http_headers.clone(), - oauth_config.env_http_headers.clone(), - &resolved_scopes.scopes, - config.oauth_resource.as_deref(), - callback_port, - callback_url, - http_client.clone(), - ) - .await; - - let final_result = match first_attempt { - Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => { - perform_oauth_login_with_client( + let final_result = match oauth_http_client { + McpOAuthHttpClient::LocalNoProxy => { + let first_attempt = perform_oauth_login( server_name, &oauth_config.url, store_mode, - oauth_config.http_headers, - oauth_config.env_http_headers, - &[], + oauth_config.http_headers.clone(), + oauth_config.env_http_headers.clone(), + &resolved_scopes.scopes, config.oauth_resource.as_deref(), callback_port, callback_url, - http_client, ) - .await + .await; + match first_attempt { + Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => { + perform_oauth_login( + server_name, + &oauth_config.url, + store_mode, + oauth_config.http_headers, + oauth_config.env_http_headers, + &[], + config.oauth_resource.as_deref(), + callback_port, + callback_url, + ) + .await + } + result => result, + } + } + McpOAuthHttpClient::Runtime(http_client) => { + let first_attempt = perform_oauth_login_with_client( + server_name, + &oauth_config.url, + store_mode, + oauth_config.http_headers.clone(), + oauth_config.env_http_headers.clone(), + &resolved_scopes.scopes, + config.oauth_resource.as_deref(), + callback_port, + callback_url, + http_client.clone(), + ) + .await; + match first_attempt { + Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => { + perform_oauth_login_with_client( + server_name, + &oauth_config.url, + store_mode, + oauth_config.http_headers, + oauth_config.env_http_headers, + &[], + config.oauth_resource.as_deref(), + callback_port, + callback_url, + http_client, + ) + .await + } + result => result, + } } - result => result, }; final_result.map(|()| McpOAuthLoginOutcome::Completed) @@ -510,6 +617,20 @@ pub fn http_client_for_server( } } +fn oauth_http_client_for_server( + config: &McpServerConfig, + runtime_environment: McpRuntimeEnvironment, +) -> Result { + match config.experimental_environment.as_deref() { + None | Some("local") => Ok(McpOAuthHttpClient::LocalNoProxy), + Some("remote") => Ok(McpOAuthHttpClient::Runtime(http_client_for_server( + config, + runtime_environment, + )?)), + Some(environment) => anyhow::bail!("unsupported experimental_environment `{environment}`"), + } +} + #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/codex-rs/rmcp-client/src/auth_status.rs b/codex-rs/rmcp-client/src/auth_status.rs index a36624f279..ac227159df 100644 --- a/codex-rs/rmcp-client/src/auth_status.rs +++ b/codex-rs/rmcp-client/src/auth_status.rs @@ -40,7 +40,7 @@ pub async fn determine_streamable_http_auth_status( http_headers, env_http_headers, store_mode, - Arc::new(NoProxyReqwestHttpClient), + no_proxy_http_client(), ) .await } @@ -101,11 +101,15 @@ pub async fn discover_streamable_http_oauth( url, http_headers, env_http_headers, - Arc::new(NoProxyReqwestHttpClient), + no_proxy_http_client(), ) .await } +pub(crate) fn no_proxy_http_client() -> Arc { + Arc::new(NoProxyReqwestHttpClient) +} + #[derive(Debug, Clone, Default)] struct NoProxyReqwestHttpClient; diff --git a/codex-rs/rmcp-client/src/oauth.rs b/codex-rs/rmcp-client/src/oauth.rs index 22322faf14..39c6f50edb 100644 --- a/codex-rs/rmcp-client/src/oauth.rs +++ b/codex-rs/rmcp-client/src/oauth.rs @@ -95,6 +95,43 @@ pub(crate) fn load_oauth_tokens( } } +enum OAuthTokensStorageState { + Found(StoredOAuthTokens), + Missing, + Unavailable, +} + +fn load_oauth_tokens_for_persist( + server_name: &str, + url: &str, + store_mode: OAuthCredentialsStoreMode, +) -> Result { + let keyring_store = DefaultKeyringStore; + match store_mode { + OAuthCredentialsStoreMode::Auto => { + load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file( + &keyring_store, + server_name, + url, + ) + } + OAuthCredentialsStoreMode::File => { + Ok(match load_oauth_tokens_from_file(server_name, url)? { + Some(tokens) => OAuthTokensStorageState::Found(tokens), + None => OAuthTokensStorageState::Missing, + }) + } + OAuthCredentialsStoreMode::Keyring => Ok( + match load_oauth_tokens_from_keyring(&keyring_store, server_name, url) + .with_context(|| "failed to read OAuth tokens from keyring".to_string())? + { + Some(tokens) => OAuthTokensStorageState::Found(tokens), + None => OAuthTokensStorageState::Missing, + }, + ), + } +} + pub(crate) fn has_oauth_tokens( server_name: &str, url: &str, @@ -135,6 +172,48 @@ fn load_oauth_tokens_from_keyring_with_fallback_to_file( } } +fn stored_tokens_match_without_expires_in( + actual: &StoredOAuthTokens, + expected: &StoredOAuthTokens, +) -> bool { + let actual_response = &actual.token_response.0; + let expected_response = &expected.token_response.0; + + actual.server_name == expected.server_name + && actual.url == expected.url + && actual.client_id == expected.client_id + && actual.expires_at == expected.expires_at + && actual_response.access_token().secret() == expected_response.access_token().secret() + && actual_response.token_type() == expected_response.token_type() + && actual_response.refresh_token().map(RefreshToken::secret) + == expected_response.refresh_token().map(RefreshToken::secret) + && actual_response.scopes() == expected_response.scopes() + && actual_response.extra_fields() == expected_response.extra_fields() +} + +fn load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file( + keyring_store: &K, + server_name: &str, + url: &str, +) -> Result { + match load_oauth_tokens_from_keyring(keyring_store, server_name, url) { + Ok(Some(tokens)) => Ok(OAuthTokensStorageState::Found(tokens)), + Ok(None) => Ok(match load_oauth_tokens_from_file(server_name, url)? { + Some(tokens) => OAuthTokensStorageState::Found(tokens), + None => OAuthTokensStorageState::Missing, + }), + Err(error) => { + warn!("failed to read OAuth tokens from keyring: {error}"); + match load_oauth_tokens_from_file(server_name, url) + .with_context(|| format!("failed to read OAuth tokens from keyring: {error}"))? + { + Some(tokens) => Ok(OAuthTokensStorageState::Found(tokens)), + None => Ok(OAuthTokensStorageState::Unavailable), + } + } + } +} + fn load_oauth_tokens_from_keyring( keyring_store: &K, server_name: &str, @@ -295,23 +374,46 @@ impl OAuthPersistor { let Some(stored) = credentials_state.current.clone() else { return Ok(()); }; - match load_oauth_tokens(&self.inner.server_name, &stored.url, self.inner.store_mode)? { - None => { + match load_oauth_tokens_for_persist( + &self.inner.server_name, + &stored.url, + self.inner.store_mode, + )? { + OAuthTokensStorageState::Missing => { credentials_state.current = None; credentials_state.last_persisted = None; } - Some(current_store) if current_store == stored => { - credentials_state.last_persisted = Some(current_store); + OAuthTokensStorageState::Unavailable => { + let should_save = match credentials_state.last_persisted.as_ref() { + Some(last_persisted) => { + !stored_tokens_match_without_expires_in(&stored, last_persisted) + } + None => true, + }; + if should_save { + save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?; + credentials_state.last_persisted = Some(stored); + } } - Some(current_store) - if Some(¤t_store) != credentials_state.last_persisted.as_ref() => + OAuthTokensStorageState::Found(current_store) + if stored_tokens_match_without_expires_in(¤t_store, &stored) => { - credentials_state.current = Some(current_store.clone()); credentials_state.last_persisted = Some(current_store); } - Some(_) => { - save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?; - credentials_state.last_persisted = Some(stored); + OAuthTokensStorageState::Found(current_store) => { + let store_matches_last_persisted = credentials_state + .last_persisted + .as_ref() + .is_some_and(|last_persisted| { + stored_tokens_match_without_expires_in(¤t_store, last_persisted) + }); + if store_matches_last_persisted { + save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?; + credentials_state.last_persisted = Some(stored); + } else { + credentials_state.current = Some(current_store.clone()); + credentials_state.last_persisted = Some(current_store); + } } } Ok(()) @@ -722,6 +824,24 @@ mod tests { Ok(()) } + #[test] + fn load_oauth_tokens_for_persist_keeps_keyring_errors_distinct() -> Result<()> { + let _env = TempCodexHome::new(); + let store = MockKeyringStore::default(); + let tokens = sample_tokens(); + let key = super::compute_store_key(&tokens.server_name, &tokens.url)?; + store.set_error(&key, KeyringError::Invalid("error".into(), "load".into())); + + let state = super::load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file( + &store, + &tokens.server_name, + &tokens.url, + )?; + + assert!(matches!(state, super::OAuthTokensStorageState::Unavailable)); + Ok(()) + } + #[test] fn save_oauth_tokens_prefers_keyring_when_available() -> Result<()> { let _env = TempCodexHome::new(); @@ -942,6 +1062,51 @@ mod tests { Ok(()) } + #[tokio::test] + async fn persistor_saves_newer_memory_credentials_despite_expires_in_drift() -> Result<()> { + let _env = TempCodexHome::new(); + let mut stored_tokens = sample_tokens(); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)); + stored_tokens.expires_at = Some(now.as_millis() as u64 + 10_000); + let expires_in = Duration::from_secs(3600); + stored_tokens + .token_response + .0 + .set_expires_in(Some(&expires_in)); + super::save_oauth_tokens_to_file(&stored_tokens)?; + + let persistor = OAuthPersistor::new( + stored_tokens.server_name.clone(), + OAuthHttpClient::new( + Arc::new(FailingHttpClient), + /*http_headers*/ None, + /*env_http_headers*/ None, + )?, + OAuthCredentialsStoreMode::File, + Some(stored_tokens.clone()), + ); + + let newer_tokens = sample_tokens_with_access_token("newer-access-token"); + { + let mut state = persistor.inner.credentials_state.lock().await; + state.current = Some(newer_tokens.clone()); + state.last_persisted = Some(stored_tokens); + } + + persistor.persist_if_needed().await?; + + let loaded = super::load_oauth_tokens( + &newer_tokens.server_name, + &newer_tokens.url, + OAuthCredentialsStoreMode::File, + )? + .expect("newer tokens should be stored"); + assert_tokens_match_without_expiry(&loaded, &newer_tokens); + Ok(()) + } + #[tokio::test] async fn access_token_uses_cached_token_when_refresh_fails_before_expiry() -> Result<()> { let mut tokens = sample_tokens(); diff --git a/codex-rs/rmcp-client/src/perform_oauth_login.rs b/codex-rs/rmcp-client/src/perform_oauth_login.rs index f4408a0fc2..fa29f6cdd0 100644 --- a/codex-rs/rmcp-client/src/perform_oauth_login.rs +++ b/codex-rs/rmcp-client/src/perform_oauth_login.rs @@ -8,7 +8,6 @@ use anyhow::Result; use anyhow::anyhow; use anyhow::bail; use codex_exec_server::HttpClient; -use codex_exec_server::ReqwestHttpClient; use reqwest::Url; use tiny_http::Response; use tiny_http::Server; @@ -18,6 +17,7 @@ use urlencoding::decode; use crate::StoredOAuthTokens; use crate::WrappedOAuthTokenResponse; +use crate::auth_status::no_proxy_http_client; use crate::mcp_oauth_http::OAuthAuthorizationSession; use crate::mcp_oauth_http::OAuthHttpClient; use crate::oauth::compute_expires_at_millis; @@ -91,7 +91,7 @@ pub async fn perform_oauth_login( oauth_resource, callback_port, callback_url, - Arc::new(ReqwestHttpClient), + no_proxy_http_client(), ) .await } @@ -147,7 +147,7 @@ pub async fn perform_oauth_login_silent( oauth_resource, callback_port, callback_url, - Arc::new(ReqwestHttpClient), + no_proxy_http_client(), ) .await } @@ -241,7 +241,7 @@ pub async fn perform_oauth_login_return_url( timeout_secs, callback_port, callback_url, - Arc::new(ReqwestHttpClient), + no_proxy_http_client(), ) .await }