Address remote OAuth review feedback

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-04-27 10:38:51 +00:00
parent 276a97340c
commit 520ee70b73
3 changed files with 253 additions and 131 deletions

View File

@@ -20,17 +20,14 @@ use codex_exec_server::Environment;
use codex_exec_server::EnvironmentManager;
use codex_exec_server::EnvironmentManagerArgs;
use codex_exec_server::ExecServerRuntimePaths;
use codex_mcp::McpOAuthLoginOutcome;
use codex_mcp::McpOAuthLoginSupport;
use codex_mcp::McpRuntimeEnvironment;
use codex_mcp::ResolvedMcpOAuthScopes;
use codex_mcp::compute_auth_statuses;
use codex_mcp::discover_supported_scopes;
use codex_mcp::oauth_login_support;
use codex_mcp::resolve_oauth_scopes;
use codex_mcp::should_retry_without_scopes;
use codex_mcp::oauth_login_support_for_server;
use codex_mcp::perform_oauth_login_for_server;
use codex_protocol::protocol::McpAuthStatus;
use codex_rmcp_client::delete_oauth_tokens;
use codex_rmcp_client::perform_oauth_login;
use codex_utils_cli::CliConfigOverrides;
use codex_utils_cli::format_env_display;
@@ -194,54 +191,6 @@ impl McpCli {
}
}
/// Preserve compatibility with servers that still expect the legacy empty-scope
/// OAuth request. If a discovered-scope request is rejected by the provider,
/// retry the login flow once without scopes.
#[allow(clippy::too_many_arguments)]
async fn perform_oauth_login_retry_without_scopes(
name: &str,
url: &str,
store_mode: codex_config::types::OAuthCredentialsStoreMode,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
resolved_scopes: &ResolvedMcpOAuthScopes,
oauth_resource: Option<&str>,
callback_port: Option<u16>,
callback_url: Option<&str>,
) -> Result<()> {
match perform_oauth_login(
name,
url,
store_mode,
http_headers.clone(),
env_http_headers.clone(),
&resolved_scopes.scopes,
oauth_resource,
callback_port,
callback_url,
)
.await
{
Ok(()) => Ok(()),
Err(err) if should_retry_without_scopes(resolved_scopes, &err) => {
println!("OAuth provider rejected discovered scopes. Retrying without scopes…");
perform_oauth_login(
name,
url,
store_mode,
http_headers,
env_http_headers,
&[],
oauth_resource,
callback_port,
callback_url,
)
.await
}
Err(err) => Err(err),
}
}
async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
// Validate any provided overrides even though they are not currently applied.
let overrides = config_overrides
@@ -303,7 +252,7 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re
};
let new_entry = McpServerConfig {
transport: transport.clone(),
transport,
experimental_environment: None,
enabled: true,
required: false,
@@ -319,7 +268,7 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re
tools: HashMap::new(),
};
servers.insert(name.clone(), new_entry);
servers.insert(name.clone(), new_entry.clone());
ConfigEditsBuilder::new(&codex_home)
.replace_mcp_servers(&servers)
@@ -329,27 +278,24 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re
println!("Added global MCP server '{name}'.");
match oauth_login_support(&transport).await {
McpOAuthLoginSupport::Supported(oauth_config) => {
let runtime_environment = runtime_environment_from_config(&config);
match oauth_login_support_for_server(&new_entry, runtime_environment.clone()).await {
McpOAuthLoginSupport::Supported(_) => {
println!("Detected OAuth support. Starting OAuth flow…");
let resolved_scopes = resolve_oauth_scopes(
/*explicit_scopes*/ None,
/*configured_scopes*/ None,
oauth_config.discovered_scopes.clone(),
);
perform_oauth_login_retry_without_scopes(
match perform_oauth_login_for_server(
&name,
&oauth_config.url,
&new_entry,
config.mcp_oauth_credentials_store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&resolved_scopes,
/*oauth_resource*/ None,
/*explicit_scopes*/ None,
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.as_deref(),
runtime_environment,
)
.await?;
println!("Successfully logged in.");
.await?
{
McpOAuthLoginOutcome::Completed => println!("Successfully logged in."),
McpOAuthLoginOutcome::Unsupported => {}
}
}
McpOAuthLoginSupport::Unsupported => {}
McpOAuthLoginSupport::Unknown(_) => println!(
@@ -411,38 +357,32 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs)
bail!("No MCP server named '{name}' found.");
};
let (url, http_headers, env_http_headers) = match &server.transport {
McpServerTransportConfig::StreamableHttp {
url,
http_headers,
env_http_headers,
..
} => (url.clone(), http_headers.clone(), env_http_headers.clone()),
_ => bail!("OAuth login is only supported for streamable HTTP servers."),
if !matches!(
&server.transport,
McpServerTransportConfig::StreamableHttp { .. }
) {
bail!("OAuth login is only supported for streamable HTTP servers.");
};
let explicit_scopes = (!scopes.is_empty()).then_some(scopes);
let discovered_scopes = if explicit_scopes.is_none() && server.scopes.is_none() {
discover_supported_scopes(&server.transport).await
} else {
None
};
let resolved_scopes =
resolve_oauth_scopes(explicit_scopes, server.scopes.clone(), discovered_scopes);
perform_oauth_login_retry_without_scopes(
match perform_oauth_login_for_server(
&name,
&url,
server,
config.mcp_oauth_credentials_store_mode,
http_headers,
env_http_headers,
&resolved_scopes,
server.oauth_resource.as_deref(),
explicit_scopes,
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.as_deref(),
runtime_environment_from_config(&config),
)
.await?;
println!("Successfully logged in to MCP server '{name}'.");
.await?
{
McpOAuthLoginOutcome::Completed => {
println!("Successfully logged in to MCP server '{name}'.")
}
McpOAuthLoginOutcome::Unsupported => {
bail!("MCP server '{name}' does not advertise OAuth support.")
}
}
Ok(())
}
@@ -478,17 +418,7 @@ async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutAr
Ok(())
}
async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
let overrides = config_overrides
.parse_overrides()
.map_err(anyhow::Error::msg)?;
let config = Config::load_with_cli_overrides(overrides)
.await
.context("failed to load configuration")?;
let mcp_manager = McpManager::new(Arc::new(PluginsManager::new(
config.codex_home.to_path_buf(),
)));
let mcp_servers = mcp_manager.effective_servers(&config, /*auth*/ None).await;
fn runtime_environment_from_config(config: &Config) -> McpRuntimeEnvironment {
let local_runtime_paths = ExecServerRuntimePaths::from_optional_paths(
config.codex_self_exe.clone(),
config.codex_linux_sandbox_exe.clone(),
@@ -506,6 +436,21 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
.unwrap_or_else(|_| Environment::default_for_tests()),
),
};
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf())
}
async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
let overrides = config_overrides
.parse_overrides()
.map_err(anyhow::Error::msg)?;
let config = Config::load_with_cli_overrides(overrides)
.await
.context("failed to load configuration")?;
let mcp_manager = McpManager::new(Arc::new(PluginsManager::new(
config.codex_home.to_path_buf(),
)));
let mcp_servers = mcp_manager.effective_servers(&config, /*auth*/ None).await;
let runtime_environment = runtime_environment_from_config(&config);
let mut entries: Vec<_> = mcp_servers.iter().collect();
entries.sort_by(|(a, _), (b, _)| a.cmp(b));
@@ -513,7 +458,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
mcp_servers.iter(),
config.mcp_oauth_credentials_store_mode,
/*auth*/ None,
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf()),
runtime_environment,
)
.await;

View File

@@ -1,11 +1,21 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use codex_exec_server::ExecServerError;
use codex_exec_server::HttpClient;
use codex_exec_server::ReqwestHttpClient;
use codex_exec_server::HttpHeader;
use codex_exec_server::HttpRequestParams;
use codex_exec_server::HttpRequestResponse;
use codex_exec_server::HttpResponseBodyStream;
use codex_protocol::protocol::McpAuthStatus;
use futures::FutureExt;
use futures::future::BoxFuture;
use reqwest::Method;
use reqwest::header::AUTHORIZATION;
use reqwest::header::HeaderName;
use reqwest::header::HeaderValue;
use tracing::debug;
use crate::mcp_oauth_http::OAuthHttpClient;
@@ -30,7 +40,7 @@ pub async fn determine_streamable_http_auth_status(
http_headers,
env_http_headers,
store_mode,
Arc::new(ReqwestHttpClient),
Arc::new(NoProxyReqwestHttpClient),
)
.await
}
@@ -91,11 +101,91 @@ pub async fn discover_streamable_http_oauth(
url,
http_headers,
env_http_headers,
Arc::new(ReqwestHttpClient),
Arc::new(NoProxyReqwestHttpClient),
)
.await
}
#[derive(Debug, Clone, Default)]
struct NoProxyReqwestHttpClient;
impl HttpClient for NoProxyReqwestHttpClient {
fn http_request(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, std::result::Result<HttpRequestResponse, ExecServerError>> {
async move { no_proxy_http_request(params).await }.boxed()
}
fn http_request_stream(
&self,
_params: HttpRequestParams,
) -> BoxFuture<
'_,
std::result::Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>,
> {
async move {
Err(ExecServerError::Protocol(
"streaming is not supported for OAuth discovery".to_string(),
))
}
.boxed()
}
}
async fn no_proxy_http_request(
params: HttpRequestParams,
) -> std::result::Result<HttpRequestResponse, ExecServerError> {
let method = Method::from_bytes(params.method.as_bytes())
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
let mut builder = reqwest::Client::builder().no_proxy();
if let Some(timeout_ms) = params.timeout_ms {
builder = builder.timeout(Duration::from_millis(timeout_ms));
}
let client = builder
.build()
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
let mut request = client.request(method, &params.url);
for header in params.headers {
let name = HeaderName::from_bytes(header.name.as_bytes())
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
let value = HeaderValue::from_str(&header.value)
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
request = request.header(name, value);
}
if let Some(body) = params.body {
request = request.body(body.into_inner());
}
let response = request
.send()
.await
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
let status = response.status().as_u16();
let headers = response
.headers()
.iter()
.filter_map(|(name, value)| {
value.to_str().ok().map(|value| HttpHeader {
name: name.to_string(),
value: value.to_string(),
})
})
.collect();
let body = response
.bytes()
.await
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?
.to_vec();
Ok(HttpRequestResponse {
status,
headers,
body: body.into(),
})
}
pub async fn discover_streamable_http_oauth_with_client(
url: &str,
http_headers: Option<HashMap<String, String>>,

View File

@@ -260,10 +260,15 @@ struct OAuthPersistorInner {
server_name: String,
oauth_http: OAuthHttpClient,
store_mode: OAuthCredentialsStoreMode,
last_credentials: Mutex<Option<StoredOAuthTokens>>,
credentials_state: Mutex<OAuthCredentialsState>,
refresh_gate: Semaphore,
}
struct OAuthCredentialsState {
current: Option<StoredOAuthTokens>,
last_persisted: Option<StoredOAuthTokens>,
}
impl OAuthPersistor {
pub(crate) fn new(
server_name: String,
@@ -276,7 +281,10 @@ impl OAuthPersistor {
server_name,
oauth_http,
store_mode,
last_credentials: Mutex::new(initial_credentials),
credentials_state: Mutex::new(OAuthCredentialsState {
current: initial_credentials.clone(),
last_persisted: initial_credentials,
}),
refresh_gate: Semaphore::new(1),
}),
}
@@ -299,15 +307,28 @@ impl OAuthPersistor {
/// Persists the latest stored credentials if they have changed.
pub(crate) async fn persist_if_needed(&self) -> Result<()> {
let mut last_credentials = self.inner.last_credentials.lock().await;
if let Some(stored) = last_credentials.as_ref() {
if load_oauth_tokens(&self.inner.server_name, &stored.url, self.inner.store_mode)?
.is_none()
{
*last_credentials = None;
return Ok(());
let mut credentials_state = self.inner.credentials_state.lock().await;
let Some(stored) = credentials_state.current.clone() else {
return Ok(());
};
match load_oauth_tokens(&self.inner.server_name, &stored.url, self.inner.store_mode)? {
None => {
credentials_state.current = None;
credentials_state.last_persisted = None;
}
Some(current_store) if current_store == stored => {
credentials_state.last_persisted = Some(current_store);
}
Some(current_store)
if Some(&current_store) != credentials_state.last_persisted.as_ref() =>
{
credentials_state.current = Some(current_store.clone());
credentials_state.last_persisted = Some(current_store);
}
Some(_) => {
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
credentials_state.last_persisted = Some(stored);
}
save_oauth_tokens(&self.inner.server_name, stored, self.inner.store_mode)?;
}
Ok(())
}
@@ -320,8 +341,8 @@ impl OAuthPersistor {
.await
.context("OAuth refresh gate was closed")?;
let tokens = {
let guard = self.inner.last_credentials.lock().await;
let Some(tokens) = guard.as_ref() else {
let state = self.inner.credentials_state.lock().await;
let Some(tokens) = state.current.as_ref() else {
return Ok(());
};
if !token_needs_refresh(tokens.expires_at) {
@@ -342,8 +363,8 @@ impl OAuthPersistor {
)
})?
{
let mut guard = self.inner.last_credentials.lock().await;
*guard = Some(refreshed);
let mut state = self.inner.credentials_state.lock().await;
state.current = Some(refreshed);
}
self.persist_if_needed().await
@@ -351,8 +372,8 @@ impl OAuthPersistor {
pub(crate) async fn access_token(&self) -> Result<Option<String>> {
let cached = {
let guard = self.inner.last_credentials.lock().await;
guard.as_ref().map(|tokens| {
let state = self.inner.credentials_state.lock().await;
state.current.as_ref().map(|tokens| {
(
tokens.token_response.0.access_token().secret().to_string(),
tokens.expires_at,
@@ -371,10 +392,16 @@ impl OAuthPersistor {
}
return Err(err);
}
let guard = self.inner.last_credentials.lock().await;
Ok(guard
.as_ref()
.map(|tokens| tokens.token_response.0.access_token().secret().to_string()))
let state = self.inner.credentials_state.lock().await;
let Some(tokens) = state.current.as_ref() else {
return Ok(None);
};
if expires_in_from_timestamp(tokens.expires_at.unwrap_or(u64::MAX)).is_none() {
return Ok(None);
}
Ok(Some(
tokens.token_response.0.access_token().secret().to_string(),
))
}
}
@@ -898,6 +925,38 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn persistor_does_not_clobber_newer_file_credentials() -> Result<()> {
let _env = TempCodexHome::new();
let tokens = sample_tokens();
super::save_oauth_tokens_to_file(&tokens)?;
let persistor = OAuthPersistor::new(
tokens.server_name.clone(),
OAuthHttpClient::new(
Arc::new(FailingHttpClient),
/*http_headers*/ None,
/*env_http_headers*/ None,
)?,
OAuthCredentialsStoreMode::File,
Some(tokens.clone()),
);
let newer_tokens = sample_tokens_with_access_token("newer-access-token");
super::save_oauth_tokens_to_file(&newer_tokens)?;
persistor.persist_if_needed().await?;
assert_eq!(
super::load_oauth_tokens(
&tokens.server_name,
&tokens.url,
OAuthCredentialsStoreMode::File
)?,
Some(newer_tokens)
);
Ok(())
}
#[tokio::test]
async fn access_token_uses_cached_token_when_refresh_fails_before_expiry() -> Result<()> {
let mut tokens = sample_tokens();
@@ -924,6 +983,30 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn access_token_rejects_expired_token_without_refresh_token() -> Result<()> {
let mut tokens = sample_tokens();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0));
tokens.expires_at = Some((now.as_millis() as u64).saturating_sub(1000));
tokens.token_response.0.set_refresh_token(None);
let persistor = OAuthPersistor::new(
tokens.server_name.clone(),
OAuthHttpClient::new(
Arc::new(FailingHttpClient),
/*http_headers*/ None,
/*env_http_headers*/ None,
)?,
OAuthCredentialsStoreMode::File,
Some(tokens),
);
assert_eq!(persistor.access_token().await?, None);
Ok(())
}
struct FailingHttpClient;
impl HttpClient for FailingHttpClient {
@@ -995,8 +1078,12 @@ mod tests {
}
fn sample_tokens() -> StoredOAuthTokens {
sample_tokens_with_access_token("access-token")
}
fn sample_tokens_with_access_token(access_token: &str) -> StoredOAuthTokens {
let mut response = OAuthTokenResponse::new(
AccessToken::new("access-token".to_string()),
AccessToken::new(access_token.to_string()),
BasicTokenType::Bearer,
EmptyExtraTokenFields {},
);