From 520ee70b737c9ac341f63fdcd780abca3d99a774 Mon Sep 17 00:00:00 2001 From: Ahmed Ibrahim Date: Mon, 27 Apr 2026 10:38:51 +0000 Subject: [PATCH] Address remote OAuth review feedback Co-authored-by: Codex --- codex-rs/cli/src/mcp_cmd.rs | 159 ++++++++---------------- codex-rs/rmcp-client/src/auth_status.rs | 96 +++++++++++++- codex-rs/rmcp-client/src/oauth.rs | 129 +++++++++++++++---- 3 files changed, 253 insertions(+), 131 deletions(-) diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index 5a8c95a9ff..f2b8c2a785 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -20,17 +20,14 @@ use codex_exec_server::Environment; use codex_exec_server::EnvironmentManager; use codex_exec_server::EnvironmentManagerArgs; use codex_exec_server::ExecServerRuntimePaths; +use codex_mcp::McpOAuthLoginOutcome; use codex_mcp::McpOAuthLoginSupport; use codex_mcp::McpRuntimeEnvironment; -use codex_mcp::ResolvedMcpOAuthScopes; use codex_mcp::compute_auth_statuses; -use codex_mcp::discover_supported_scopes; -use codex_mcp::oauth_login_support; -use codex_mcp::resolve_oauth_scopes; -use codex_mcp::should_retry_without_scopes; +use codex_mcp::oauth_login_support_for_server; +use codex_mcp::perform_oauth_login_for_server; use codex_protocol::protocol::McpAuthStatus; use codex_rmcp_client::delete_oauth_tokens; -use codex_rmcp_client::perform_oauth_login; use codex_utils_cli::CliConfigOverrides; use codex_utils_cli::format_env_display; @@ -194,54 +191,6 @@ impl McpCli { } } -/// Preserve compatibility with servers that still expect the legacy empty-scope -/// OAuth request. If a discovered-scope request is rejected by the provider, -/// retry the login flow once without scopes. -#[allow(clippy::too_many_arguments)] -async fn perform_oauth_login_retry_without_scopes( - name: &str, - url: &str, - store_mode: codex_config::types::OAuthCredentialsStoreMode, - http_headers: Option>, - env_http_headers: Option>, - resolved_scopes: &ResolvedMcpOAuthScopes, - oauth_resource: Option<&str>, - callback_port: Option, - callback_url: Option<&str>, -) -> Result<()> { - match perform_oauth_login( - name, - url, - store_mode, - http_headers.clone(), - env_http_headers.clone(), - &resolved_scopes.scopes, - oauth_resource, - callback_port, - callback_url, - ) - .await - { - Ok(()) => Ok(()), - Err(err) if should_retry_without_scopes(resolved_scopes, &err) => { - println!("OAuth provider rejected discovered scopes. Retrying without scopes…"); - perform_oauth_login( - name, - url, - store_mode, - http_headers, - env_http_headers, - &[], - oauth_resource, - callback_port, - callback_url, - ) - .await - } - Err(err) => Err(err), - } -} - async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> { // Validate any provided overrides even though they are not currently applied. let overrides = config_overrides @@ -303,7 +252,7 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re }; let new_entry = McpServerConfig { - transport: transport.clone(), + transport, experimental_environment: None, enabled: true, required: false, @@ -319,7 +268,7 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re tools: HashMap::new(), }; - servers.insert(name.clone(), new_entry); + servers.insert(name.clone(), new_entry.clone()); ConfigEditsBuilder::new(&codex_home) .replace_mcp_servers(&servers) @@ -329,27 +278,24 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re println!("Added global MCP server '{name}'."); - match oauth_login_support(&transport).await { - McpOAuthLoginSupport::Supported(oauth_config) => { + let runtime_environment = runtime_environment_from_config(&config); + match oauth_login_support_for_server(&new_entry, runtime_environment.clone()).await { + McpOAuthLoginSupport::Supported(_) => { println!("Detected OAuth support. Starting OAuth flow…"); - let resolved_scopes = resolve_oauth_scopes( - /*explicit_scopes*/ None, - /*configured_scopes*/ None, - oauth_config.discovered_scopes.clone(), - ); - perform_oauth_login_retry_without_scopes( + match perform_oauth_login_for_server( &name, - &oauth_config.url, + &new_entry, config.mcp_oauth_credentials_store_mode, - oauth_config.http_headers, - oauth_config.env_http_headers, - &resolved_scopes, - /*oauth_resource*/ None, + /*explicit_scopes*/ None, config.mcp_oauth_callback_port, config.mcp_oauth_callback_url.as_deref(), + runtime_environment, ) - .await?; - println!("Successfully logged in."); + .await? + { + McpOAuthLoginOutcome::Completed => println!("Successfully logged in."), + McpOAuthLoginOutcome::Unsupported => {} + } } McpOAuthLoginSupport::Unsupported => {} McpOAuthLoginSupport::Unknown(_) => println!( @@ -411,38 +357,32 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs) bail!("No MCP server named '{name}' found."); }; - let (url, http_headers, env_http_headers) = match &server.transport { - McpServerTransportConfig::StreamableHttp { - url, - http_headers, - env_http_headers, - .. - } => (url.clone(), http_headers.clone(), env_http_headers.clone()), - _ => bail!("OAuth login is only supported for streamable HTTP servers."), + if !matches!( + &server.transport, + McpServerTransportConfig::StreamableHttp { .. } + ) { + bail!("OAuth login is only supported for streamable HTTP servers."); }; let explicit_scopes = (!scopes.is_empty()).then_some(scopes); - let discovered_scopes = if explicit_scopes.is_none() && server.scopes.is_none() { - discover_supported_scopes(&server.transport).await - } else { - None - }; - let resolved_scopes = - resolve_oauth_scopes(explicit_scopes, server.scopes.clone(), discovered_scopes); - - perform_oauth_login_retry_without_scopes( + match perform_oauth_login_for_server( &name, - &url, + server, config.mcp_oauth_credentials_store_mode, - http_headers, - env_http_headers, - &resolved_scopes, - server.oauth_resource.as_deref(), + explicit_scopes, config.mcp_oauth_callback_port, config.mcp_oauth_callback_url.as_deref(), + runtime_environment_from_config(&config), ) - .await?; - println!("Successfully logged in to MCP server '{name}'."); + .await? + { + McpOAuthLoginOutcome::Completed => { + println!("Successfully logged in to MCP server '{name}'.") + } + McpOAuthLoginOutcome::Unsupported => { + bail!("MCP server '{name}' does not advertise OAuth support.") + } + } Ok(()) } @@ -478,17 +418,7 @@ async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutAr Ok(()) } -async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> { - let overrides = config_overrides - .parse_overrides() - .map_err(anyhow::Error::msg)?; - let config = Config::load_with_cli_overrides(overrides) - .await - .context("failed to load configuration")?; - let mcp_manager = McpManager::new(Arc::new(PluginsManager::new( - config.codex_home.to_path_buf(), - ))); - let mcp_servers = mcp_manager.effective_servers(&config, /*auth*/ None).await; +fn runtime_environment_from_config(config: &Config) -> McpRuntimeEnvironment { let local_runtime_paths = ExecServerRuntimePaths::from_optional_paths( config.codex_self_exe.clone(), config.codex_linux_sandbox_exe.clone(), @@ -506,6 +436,21 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> .unwrap_or_else(|_| Environment::default_for_tests()), ), }; + McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf()) +} + +async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> { + let overrides = config_overrides + .parse_overrides() + .map_err(anyhow::Error::msg)?; + let config = Config::load_with_cli_overrides(overrides) + .await + .context("failed to load configuration")?; + let mcp_manager = McpManager::new(Arc::new(PluginsManager::new( + config.codex_home.to_path_buf(), + ))); + let mcp_servers = mcp_manager.effective_servers(&config, /*auth*/ None).await; + let runtime_environment = runtime_environment_from_config(&config); let mut entries: Vec<_> = mcp_servers.iter().collect(); entries.sort_by(|(a, _), (b, _)| a.cmp(b)); @@ -513,7 +458,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> mcp_servers.iter(), config.mcp_oauth_credentials_store_mode, /*auth*/ None, - McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf()), + runtime_environment, ) .await; diff --git a/codex-rs/rmcp-client/src/auth_status.rs b/codex-rs/rmcp-client/src/auth_status.rs index cbb4c5dc5a..a36624f279 100644 --- a/codex-rs/rmcp-client/src/auth_status.rs +++ b/codex-rs/rmcp-client/src/auth_status.rs @@ -1,11 +1,21 @@ use std::collections::HashMap; use std::sync::Arc; +use std::time::Duration; use anyhow::Result; +use codex_exec_server::ExecServerError; use codex_exec_server::HttpClient; -use codex_exec_server::ReqwestHttpClient; +use codex_exec_server::HttpHeader; +use codex_exec_server::HttpRequestParams; +use codex_exec_server::HttpRequestResponse; +use codex_exec_server::HttpResponseBodyStream; use codex_protocol::protocol::McpAuthStatus; +use futures::FutureExt; +use futures::future::BoxFuture; +use reqwest::Method; use reqwest::header::AUTHORIZATION; +use reqwest::header::HeaderName; +use reqwest::header::HeaderValue; use tracing::debug; use crate::mcp_oauth_http::OAuthHttpClient; @@ -30,7 +40,7 @@ pub async fn determine_streamable_http_auth_status( http_headers, env_http_headers, store_mode, - Arc::new(ReqwestHttpClient), + Arc::new(NoProxyReqwestHttpClient), ) .await } @@ -91,11 +101,91 @@ pub async fn discover_streamable_http_oauth( url, http_headers, env_http_headers, - Arc::new(ReqwestHttpClient), + Arc::new(NoProxyReqwestHttpClient), ) .await } +#[derive(Debug, Clone, Default)] +struct NoProxyReqwestHttpClient; + +impl HttpClient for NoProxyReqwestHttpClient { + fn http_request( + &self, + params: HttpRequestParams, + ) -> BoxFuture<'_, std::result::Result> { + async move { no_proxy_http_request(params).await }.boxed() + } + + fn http_request_stream( + &self, + _params: HttpRequestParams, + ) -> BoxFuture< + '_, + std::result::Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>, + > { + async move { + Err(ExecServerError::Protocol( + "streaming is not supported for OAuth discovery".to_string(), + )) + } + .boxed() + } +} + +async fn no_proxy_http_request( + params: HttpRequestParams, +) -> std::result::Result { + let method = Method::from_bytes(params.method.as_bytes()) + .map_err(|error| ExecServerError::HttpRequest(error.to_string()))?; + let mut builder = reqwest::Client::builder().no_proxy(); + if let Some(timeout_ms) = params.timeout_ms { + builder = builder.timeout(Duration::from_millis(timeout_ms)); + } + let client = builder + .build() + .map_err(|error| ExecServerError::HttpRequest(error.to_string()))?; + + let mut request = client.request(method, ¶ms.url); + for header in params.headers { + let name = HeaderName::from_bytes(header.name.as_bytes()) + .map_err(|error| ExecServerError::HttpRequest(error.to_string()))?; + let value = HeaderValue::from_str(&header.value) + .map_err(|error| ExecServerError::HttpRequest(error.to_string()))?; + request = request.header(name, value); + } + if let Some(body) = params.body { + request = request.body(body.into_inner()); + } + + let response = request + .send() + .await + .map_err(|error| ExecServerError::HttpRequest(error.to_string()))?; + let status = response.status().as_u16(); + let headers = response + .headers() + .iter() + .filter_map(|(name, value)| { + value.to_str().ok().map(|value| HttpHeader { + name: name.to_string(), + value: value.to_string(), + }) + }) + .collect(); + let body = response + .bytes() + .await + .map_err(|error| ExecServerError::HttpRequest(error.to_string()))? + .to_vec(); + + Ok(HttpRequestResponse { + status, + headers, + body: body.into(), + }) +} + pub async fn discover_streamable_http_oauth_with_client( url: &str, http_headers: Option>, diff --git a/codex-rs/rmcp-client/src/oauth.rs b/codex-rs/rmcp-client/src/oauth.rs index a26c24d962..4318b38e23 100644 --- a/codex-rs/rmcp-client/src/oauth.rs +++ b/codex-rs/rmcp-client/src/oauth.rs @@ -260,10 +260,15 @@ struct OAuthPersistorInner { server_name: String, oauth_http: OAuthHttpClient, store_mode: OAuthCredentialsStoreMode, - last_credentials: Mutex>, + credentials_state: Mutex, refresh_gate: Semaphore, } +struct OAuthCredentialsState { + current: Option, + last_persisted: Option, +} + impl OAuthPersistor { pub(crate) fn new( server_name: String, @@ -276,7 +281,10 @@ impl OAuthPersistor { server_name, oauth_http, store_mode, - last_credentials: Mutex::new(initial_credentials), + credentials_state: Mutex::new(OAuthCredentialsState { + current: initial_credentials.clone(), + last_persisted: initial_credentials, + }), refresh_gate: Semaphore::new(1), }), } @@ -299,15 +307,28 @@ impl OAuthPersistor { /// Persists the latest stored credentials if they have changed. pub(crate) async fn persist_if_needed(&self) -> Result<()> { - let mut last_credentials = self.inner.last_credentials.lock().await; - if let Some(stored) = last_credentials.as_ref() { - if load_oauth_tokens(&self.inner.server_name, &stored.url, self.inner.store_mode)? - .is_none() - { - *last_credentials = None; - return Ok(()); + let mut credentials_state = self.inner.credentials_state.lock().await; + let Some(stored) = credentials_state.current.clone() else { + return Ok(()); + }; + match load_oauth_tokens(&self.inner.server_name, &stored.url, self.inner.store_mode)? { + None => { + credentials_state.current = None; + credentials_state.last_persisted = None; + } + Some(current_store) if current_store == stored => { + credentials_state.last_persisted = Some(current_store); + } + Some(current_store) + if Some(¤t_store) != credentials_state.last_persisted.as_ref() => + { + credentials_state.current = Some(current_store.clone()); + credentials_state.last_persisted = Some(current_store); + } + Some(_) => { + save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?; + credentials_state.last_persisted = Some(stored); } - save_oauth_tokens(&self.inner.server_name, stored, self.inner.store_mode)?; } Ok(()) } @@ -320,8 +341,8 @@ impl OAuthPersistor { .await .context("OAuth refresh gate was closed")?; let tokens = { - let guard = self.inner.last_credentials.lock().await; - let Some(tokens) = guard.as_ref() else { + let state = self.inner.credentials_state.lock().await; + let Some(tokens) = state.current.as_ref() else { return Ok(()); }; if !token_needs_refresh(tokens.expires_at) { @@ -342,8 +363,8 @@ impl OAuthPersistor { ) })? { - let mut guard = self.inner.last_credentials.lock().await; - *guard = Some(refreshed); + let mut state = self.inner.credentials_state.lock().await; + state.current = Some(refreshed); } self.persist_if_needed().await @@ -351,8 +372,8 @@ impl OAuthPersistor { pub(crate) async fn access_token(&self) -> Result> { let cached = { - let guard = self.inner.last_credentials.lock().await; - guard.as_ref().map(|tokens| { + let state = self.inner.credentials_state.lock().await; + state.current.as_ref().map(|tokens| { ( tokens.token_response.0.access_token().secret().to_string(), tokens.expires_at, @@ -371,10 +392,16 @@ impl OAuthPersistor { } return Err(err); } - let guard = self.inner.last_credentials.lock().await; - Ok(guard - .as_ref() - .map(|tokens| tokens.token_response.0.access_token().secret().to_string())) + let state = self.inner.credentials_state.lock().await; + let Some(tokens) = state.current.as_ref() else { + return Ok(None); + }; + if expires_in_from_timestamp(tokens.expires_at.unwrap_or(u64::MAX)).is_none() { + return Ok(None); + } + Ok(Some( + tokens.token_response.0.access_token().secret().to_string(), + )) } } @@ -898,6 +925,38 @@ mod tests { Ok(()) } + #[tokio::test] + async fn persistor_does_not_clobber_newer_file_credentials() -> Result<()> { + let _env = TempCodexHome::new(); + let tokens = sample_tokens(); + super::save_oauth_tokens_to_file(&tokens)?; + let persistor = OAuthPersistor::new( + tokens.server_name.clone(), + OAuthHttpClient::new( + Arc::new(FailingHttpClient), + /*http_headers*/ None, + /*env_http_headers*/ None, + )?, + OAuthCredentialsStoreMode::File, + Some(tokens.clone()), + ); + + let newer_tokens = sample_tokens_with_access_token("newer-access-token"); + super::save_oauth_tokens_to_file(&newer_tokens)?; + + persistor.persist_if_needed().await?; + + assert_eq!( + super::load_oauth_tokens( + &tokens.server_name, + &tokens.url, + OAuthCredentialsStoreMode::File + )?, + Some(newer_tokens) + ); + Ok(()) + } + #[tokio::test] async fn access_token_uses_cached_token_when_refresh_fails_before_expiry() -> Result<()> { let mut tokens = sample_tokens(); @@ -924,6 +983,30 @@ mod tests { Ok(()) } + #[tokio::test] + async fn access_token_rejects_expired_token_without_refresh_token() -> Result<()> { + let mut tokens = sample_tokens(); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)); + tokens.expires_at = Some((now.as_millis() as u64).saturating_sub(1000)); + tokens.token_response.0.set_refresh_token(None); + + let persistor = OAuthPersistor::new( + tokens.server_name.clone(), + OAuthHttpClient::new( + Arc::new(FailingHttpClient), + /*http_headers*/ None, + /*env_http_headers*/ None, + )?, + OAuthCredentialsStoreMode::File, + Some(tokens), + ); + + assert_eq!(persistor.access_token().await?, None); + Ok(()) + } + struct FailingHttpClient; impl HttpClient for FailingHttpClient { @@ -995,8 +1078,12 @@ mod tests { } fn sample_tokens() -> StoredOAuthTokens { + sample_tokens_with_access_token("access-token") + } + + fn sample_tokens_with_access_token(access_token: &str) -> StoredOAuthTokens { let mut response = OAuthTokenResponse::new( - AccessToken::new("access-token".to_string()), + AccessToken::new(access_token.to_string()), BasicTokenType::Bearer, EmptyExtraTokenFields {}, );