Address MCP OAuth review feedback

Restore stored-token bearer fallback by probing OAuth metadata through the injected runtime HTTP client before creating the OAuth transport. Also broaden discovery to OpenID Connect and protected-resource metadata paths, and avoid writing stale refresh results over newer in-memory credentials.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-04-27 11:34:20 +00:00
parent 9df8e3d935
commit 35080d5cd0
3 changed files with 167 additions and 62 deletions

View File

@@ -83,6 +83,15 @@ struct OAuthClientConfig {
client_secret: Option<String>,
}
#[derive(Debug)]
enum DiscoveredOAuthMetadata {
AuthorizationServer(OAuthDiscoveryMetadata),
ProtectedResource {
metadata_url: Url,
authorization_servers: Vec<String>,
},
}
impl OAuthHttpClient {
pub(crate) fn new(
http_client: Arc<dyn HttpClient>,
@@ -118,48 +127,71 @@ impl OAuthHttpClient {
}
async fn discover_metadata(&self, url: &str) -> Result<Option<OAuthDiscoveryMetadata>> {
match self.discover_metadata_from_paths(url).await? {
Some(DiscoveredOAuthMetadata::AuthorizationServer(metadata)) => Ok(Some(metadata)),
Some(DiscoveredOAuthMetadata::ProtectedResource {
metadata_url,
authorization_servers,
}) => {
for authorization_server in authorization_servers {
let authorization_server = authorization_server.trim();
if authorization_server.is_empty() {
continue;
}
let authorization_server_url = match Url::parse(authorization_server) {
Ok(url) => url,
Err(_) => match metadata_url.join(authorization_server) {
Ok(url) => url,
Err(err) => {
debug!(
"failed to resolve OAuth authorization server URL `{authorization_server}`: {err}"
);
continue;
}
},
};
let discovered = if authorization_server_url.path().contains("/.well-known/") {
self.fetch_discovery_metadata(&authorization_server_url)
.await?
} else {
self.discover_metadata_from_paths(authorization_server_url.as_str())
.await?
};
if let Some(DiscoveredOAuthMetadata::AuthorizationServer(metadata)) = discovered
{
return Ok(Some(metadata));
}
}
Ok(None)
}
None => Ok(None),
}
}
async fn discover_metadata_from_paths(
&self,
url: &str,
) -> Result<Option<DiscoveredOAuthMetadata>> {
let base_url = Url::parse(url)?;
let mut last_error: Option<anyhow::Error> = None;
for candidate_path in discovery_paths(base_url.path()) {
let mut discovery_url = base_url.clone();
discovery_url.set_query(None);
discovery_url.set_fragment(None);
discovery_url.set_path(&candidate_path);
let response = match self
.request(
"GET",
discovery_url.as_str(),
vec![HttpHeader {
name: OAUTH_DISCOVERY_HEADER.to_string(),
value: OAUTH_DISCOVERY_VERSION.to_string(),
}],
/*body*/ None,
Some(DISCOVERY_TIMEOUT),
)
.await
{
Ok(response) => response,
match self.fetch_discovery_metadata(&discovery_url).await {
Ok(Some(metadata)) => return Ok(Some(metadata)),
Ok(None) => {}
Err(err) => {
last_error = Some(err);
continue;
}
};
if response.status != StatusCode::OK.as_u16() {
continue;
}
let metadata = match serde_json::from_slice::<OAuthDiscoveryMetadata>(&response.body.0)
{
Ok(metadata) => metadata,
Err(err) => {
last_error = Some(err.into());
continue;
}
};
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
return Ok(Some(metadata));
}
}
@@ -170,6 +202,67 @@ impl OAuthHttpClient {
Ok(None)
}
async fn fetch_discovery_metadata(
&self,
discovery_url: &Url,
) -> Result<Option<DiscoveredOAuthMetadata>> {
let response = self
.request(
"GET",
discovery_url.as_str(),
vec![HttpHeader {
name: OAUTH_DISCOVERY_HEADER.to_string(),
value: OAUTH_DISCOVERY_VERSION.to_string(),
}],
/*body*/ None,
Some(DISCOVERY_TIMEOUT),
)
.await?;
if response.status != StatusCode::OK.as_u16() {
return Ok(None);
}
let metadata: OAuthDiscoveryMetadata = serde_json::from_slice(&response.body.0)?;
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
return Ok(Some(DiscoveredOAuthMetadata::AuthorizationServer(metadata)));
}
let authorization_servers = {
let mut authorization_servers = Vec::new();
let mut push_unique = |authorization_server: &String| {
let authorization_server = authorization_server.trim();
if !authorization_server.is_empty()
&& !authorization_servers
.iter()
.any(|existing| existing == authorization_server)
{
authorization_servers.push(authorization_server.to_string());
}
};
if let Some(authorization_server) = metadata.authorization_server.as_ref() {
push_unique(authorization_server);
}
if let Some(list) = metadata.authorization_servers.as_ref() {
for authorization_server in list {
push_unique(authorization_server);
}
}
authorization_servers
};
if authorization_servers.is_empty() {
Ok(None)
} else {
Ok(Some(DiscoveredOAuthMetadata::ProtectedResource {
metadata_url: discovery_url.clone(),
authorization_servers,
}))
}
}
pub(crate) async fn start_authorization(
&self,
server_url: &str,
@@ -410,6 +503,10 @@ struct OAuthDiscoveryMetadata {
scopes_supported: Option<Vec<String>>,
#[serde(default)]
response_types_supported: Option<Vec<String>>,
#[serde(default)]
authorization_server: Option<String>,
#[serde(default)]
authorization_servers: Option<Vec<String>>,
}
#[derive(Debug, Serialize)]
@@ -514,15 +611,10 @@ pub(crate) fn normalize_scopes(scopes_supported: Option<Vec<String>>) -> Option<
}
}
/// Implements RFC 8414 section 3.1 for discovering well-known oauth endpoints.
/// Implements RFC 8414 section 3.1 and RFC 9728 section 3 for discovering
/// well-known OAuth endpoints.
pub(crate) fn discovery_paths(base_path: &str) -> Vec<String> {
let trimmed = base_path.trim_start_matches('/').trim_end_matches('/');
let canonical = "/.well-known/oauth-authorization-server".to_string();
if trimmed.is_empty() {
return vec![canonical];
}
let mut candidates = Vec::new();
let mut push_unique = |candidate: String| {
if !candidates.contains(&candidate) {
@@ -530,9 +622,20 @@ pub(crate) fn discovery_paths(base_path: &str) -> Vec<String> {
}
};
push_unique(format!("{canonical}/{trimmed}"));
push_unique(format!("/{trimmed}/.well-known/oauth-authorization-server"));
push_unique(canonical);
let mut push_well_known = |well_known: &str| {
let canonical = format!("/.well-known/{well_known}");
if trimmed.is_empty() {
push_unique(canonical);
} else {
push_unique(format!("{canonical}/{trimmed}"));
push_unique(format!("/{trimmed}/.well-known/{well_known}"));
push_unique(canonical);
}
};
push_well_known("oauth-authorization-server");
push_well_known("openid-configuration");
push_well_known("oauth-protected-resource");
candidates
}
@@ -550,6 +653,12 @@ mod tests {
"/.well-known/oauth-authorization-server/mcp".to_string(),
"/mcp/.well-known/oauth-authorization-server".to_string(),
"/.well-known/oauth-authorization-server".to_string(),
"/.well-known/openid-configuration/mcp".to_string(),
"/mcp/.well-known/openid-configuration".to_string(),
"/.well-known/openid-configuration".to_string(),
"/.well-known/oauth-protected-resource/mcp".to_string(),
"/mcp/.well-known/oauth-protected-resource".to_string(),
"/.well-known/oauth-protected-resource".to_string(),
]
);
}
@@ -558,7 +667,11 @@ mod tests {
fn discovery_paths_deduplicate_root_path() {
assert_eq!(
discovery_paths("/"),
vec!["/.well-known/oauth-authorization-server".to_string()]
vec![
"/.well-known/oauth-authorization-server".to_string(),
"/.well-known/openid-configuration".to_string(),
"/.well-known/oauth-protected-resource".to_string(),
]
);
}

View File

@@ -20,7 +20,6 @@ use anyhow::Context;
use anyhow::Error;
use anyhow::Result;
use codex_config::types::OAuthCredentialsStoreMode;
use codex_exec_server::HttpClient;
use oauth2::AccessToken;
use oauth2::EmptyExtraTokenFields;
use oauth2::RefreshToken;
@@ -290,21 +289,6 @@ impl OAuthPersistor {
}
}
pub(crate) fn with_http_client(
server_name: String,
http_client: Arc<dyn HttpClient>,
default_headers: reqwest::header::HeaderMap,
store_mode: OAuthCredentialsStoreMode,
initial_credentials: Option<StoredOAuthTokens>,
) -> Self {
Self::new(
server_name,
OAuthHttpClient::from_default_headers(http_client, default_headers),
store_mode,
initial_credentials,
)
}
/// Persists the latest stored credentials if they have changed.
pub(crate) async fn persist_if_needed(&self) -> Result<()> {
let mut credentials_state = self.inner.credentials_state.lock().await;
@@ -364,7 +348,9 @@ impl OAuthPersistor {
})?
{
let mut state = self.inner.credentials_state.lock().await;
state.current = Some(refreshed);
if state.current.as_ref() == Some(&tokens) {
state.current = Some(refreshed);
}
}
self.persist_if_needed().await

View File

@@ -59,6 +59,7 @@ use crate::elicitation_client_service::ElicitationClientService;
use crate::http_client_adapter::StreamableHttpClientAdapter;
use crate::http_client_adapter::StreamableHttpClientAdapterError;
use crate::load_oauth_tokens;
use crate::mcp_oauth_http::OAuthHttpClient;
use crate::oauth::OAuthPersistor;
use crate::oauth::StoredOAuthTokens;
use crate::stdio_server_launcher::StdioServerCommand;
@@ -939,10 +940,15 @@ async fn create_oauth_transport_and_runtime(
StreamableHttpClientTransport<StreamableHttpClientAdapter>,
OAuthPersistor,
)> {
let oauth_persistor = OAuthPersistor::with_http_client(
let oauth_http =
OAuthHttpClient::from_default_headers(Arc::clone(&http_client), default_headers.clone());
if oauth_http.discover(url).await?.is_none() {
anyhow::bail!("No authorization support detected");
}
let oauth_persistor = OAuthPersistor::new(
server_name.to_string(),
Arc::clone(&http_client),
default_headers.clone(),
oauth_http,
credentials_store,
Some(initial_tokens),
);