diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index d6d79b3daa..5a8c95a9ff 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -15,7 +15,11 @@ use codex_core::config::edit::ConfigEditsBuilder; use codex_core::config::find_codex_home; use codex_core::config::load_global_mcp_servers; use codex_core::plugins::PluginsManager; +use codex_exec_server::CODEX_EXEC_SERVER_URL_ENV_VAR; use codex_exec_server::Environment; +use codex_exec_server::EnvironmentManager; +use codex_exec_server::EnvironmentManagerArgs; +use codex_exec_server::ExecServerRuntimePaths; use codex_mcp::McpOAuthLoginSupport; use codex_mcp::McpRuntimeEnvironment; use codex_mcp::ResolvedMcpOAuthScopes; @@ -485,6 +489,23 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> config.codex_home.to_path_buf(), ))); let mcp_servers = mcp_manager.effective_servers(&config, /*auth*/ None).await; + let local_runtime_paths = ExecServerRuntimePaths::from_optional_paths( + config.codex_self_exe.clone(), + config.codex_linux_sandbox_exe.clone(), + ); + let environment = match local_runtime_paths { + Ok(local_runtime_paths) => { + let environment_manager = + EnvironmentManager::new(EnvironmentManagerArgs::from_env(local_runtime_paths)); + environment_manager + .default_environment() + .unwrap_or_else(|| environment_manager.local_environment()) + } + Err(_) => Arc::new( + Environment::create_for_tests(std::env::var(CODEX_EXEC_SERVER_URL_ENV_VAR).ok()) + .unwrap_or_else(|_| Environment::default_for_tests()), + ), + }; let mut entries: Vec<_> = mcp_servers.iter().collect(); entries.sort_by(|(a, _), (b, _)| a.cmp(b)); @@ -492,10 +513,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> mcp_servers.iter(), config.mcp_oauth_credentials_store_mode, /*auth*/ None, - McpRuntimeEnvironment::new( - Arc::new(Environment::default_for_tests()), - config.cwd.to_path_buf(), - ), + McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf()), ) .await; diff --git a/codex-rs/rmcp-client/src/oauth.rs b/codex-rs/rmcp-client/src/oauth.rs index 48e6c0526c..35d5be0c59 100644 --- a/codex-rs/rmcp-client/src/oauth.rs +++ b/codex-rs/rmcp-client/src/oauth.rs @@ -47,6 +47,7 @@ use tracing::warn; use codex_keyring_store::DefaultKeyringStore; use codex_keyring_store::KeyringStore; use tokio::sync::Mutex; +use tokio::sync::Semaphore; use crate::mcp_oauth_http::OAuthHttpClient; use codex_utils_home_dir::find_codex_home; @@ -260,6 +261,7 @@ struct OAuthPersistorInner { oauth_http: OAuthHttpClient, store_mode: OAuthCredentialsStoreMode, last_credentials: Mutex>, + refresh_gate: Semaphore, } impl OAuthPersistor { @@ -275,6 +277,7 @@ impl OAuthPersistor { oauth_http, store_mode, last_credentials: Mutex::new(initial_credentials), + refresh_gate: Semaphore::new(1), }), } } @@ -296,14 +299,26 @@ impl OAuthPersistor { /// Persists the latest stored credentials if they have changed. pub(crate) async fn persist_if_needed(&self) -> Result<()> { - let last_credentials = self.inner.last_credentials.lock().await; + let mut last_credentials = self.inner.last_credentials.lock().await; if let Some(stored) = last_credentials.as_ref() { + if load_oauth_tokens(&self.inner.server_name, &stored.url, self.inner.store_mode)? + .is_none() + { + *last_credentials = None; + return Ok(()); + } save_oauth_tokens(&self.inner.server_name, stored, self.inner.store_mode)?; } Ok(()) } pub(crate) async fn refresh_if_needed(&self) -> Result<()> { + let _permit = self + .inner + .refresh_gate + .acquire() + .await + .context("OAuth refresh gate was closed")?; let tokens = { let guard = self.inner.last_credentials.lock().await; let Some(tokens) = guard.as_ref() else { @@ -335,7 +350,27 @@ impl OAuthPersistor { } pub(crate) async fn access_token(&self) -> Result> { - self.refresh_if_needed().await?; + let cached = { + let guard = self.inner.last_credentials.lock().await; + guard.as_ref().map(|tokens| { + ( + tokens.token_response.0.access_token().secret().to_string(), + tokens.expires_at, + ) + }) + }; + if let Err(err) = self.refresh_if_needed().await { + if let Some((access_token, expires_at)) = cached + && expires_in_from_timestamp(expires_at.unwrap_or(u64::MAX)).is_some() + { + warn!( + "failed to refresh OAuth tokens for server {}; using cached access token: {err:#}", + self.inner.server_name + ); + return Ok(Some(access_token)); + } + return Err(err); + } let guard = self.inner.last_credentials.lock().await; Ok(guard .as_ref() @@ -570,6 +605,12 @@ fn sha_256_prefix(value: &Value) -> Result { mod tests { use super::*; use anyhow::Result; + use codex_exec_server::ExecServerError; + use codex_exec_server::HttpClient; + use codex_exec_server::HttpRequestParams; + use codex_exec_server::HttpRequestResponse; + use codex_exec_server::HttpResponseBodyStream; + use futures::future::BoxFuture; use keyring::Error as KeyringError; use pretty_assertions::assert_eq; use std::sync::Mutex; @@ -821,6 +862,89 @@ mod tests { assert!(tokens.token_response.0.expires_in().is_none()); } + #[tokio::test] + async fn persistor_does_not_recreate_deleted_file_credentials() -> Result<()> { + let _env = TempCodexHome::new(); + let tokens = sample_tokens(); + super::save_oauth_tokens_to_file(&tokens)?; + assert!(super::fallback_file_path()?.exists()); + super::delete_oauth_tokens( + &tokens.server_name, + &tokens.url, + OAuthCredentialsStoreMode::File, + )?; + + let persistor = OAuthPersistor::new( + tokens.server_name.clone(), + OAuthHttpClient::new(Arc::new(FailingHttpClient), None, None)?, + OAuthCredentialsStoreMode::File, + Some(tokens.clone()), + ); + + persistor.persist_if_needed().await?; + + assert!( + super::load_oauth_tokens( + &tokens.server_name, + &tokens.url, + OAuthCredentialsStoreMode::File + )? + .is_none() + ); + Ok(()) + } + + #[tokio::test] + async fn access_token_uses_cached_token_when_refresh_fails_before_expiry() -> Result<()> { + let mut tokens = sample_tokens(); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)); + tokens.expires_at = Some(now.as_millis() as u64 + 10_000); + + let persistor = OAuthPersistor::new( + tokens.server_name.clone(), + OAuthHttpClient::new(Arc::new(FailingHttpClient), None, None)?, + OAuthCredentialsStoreMode::File, + Some(tokens), + ); + + assert_eq!( + persistor.access_token().await?, + Some("access-token".to_string()) + ); + Ok(()) + } + + struct FailingHttpClient; + + impl HttpClient for FailingHttpClient { + fn http_request( + &self, + _params: HttpRequestParams, + ) -> BoxFuture<'_, std::result::Result> { + Box::pin(async { + Err(ExecServerError::Disconnected( + "forced refresh failure".to_string(), + )) + }) + } + + fn http_request_stream( + &self, + _params: HttpRequestParams, + ) -> BoxFuture< + '_, + std::result::Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>, + > { + Box::pin(async { + Err(ExecServerError::Disconnected( + "forced refresh failure".to_string(), + )) + }) + } + } + fn assert_tokens_match_without_expiry( actual: &StoredOAuthTokens, expected: &StoredOAuthTokens,