Compare commits

...

1 Commits

Author SHA1 Message Date
Casey Chow
45e887e72f fix: persist and reuse OAuth discovery base URL 2026-02-11 15:45:00 -05:00
4 changed files with 185 additions and 51 deletions

View File

@@ -15,6 +15,7 @@ use crate::OAuthCredentialsStoreMode;
use crate::oauth::has_oauth_tokens;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use crate::utils::oauth_auth_url_candidates;
const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);
const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version";
@@ -57,50 +58,69 @@ pub async fn supports_oauth_login(url: &str) -> Result<bool> {
}
async fn supports_oauth_login_with_headers(url: &str, default_headers: &HeaderMap) -> Result<bool> {
let base_url = Url::parse(url)?;
let mut supports = false;
// Use no_proxy to avoid a bug in the system-configuration crate that
// can result in a panic. See #8912.
let builder = Client::builder().timeout(DISCOVERY_TIMEOUT).no_proxy();
let client = apply_default_headers(builder, default_headers).build()?;
let mut last_error: Option<Error> = None;
for candidate_path in discovery_paths(base_url.path()) {
let mut discovery_url = base_url.clone();
discovery_url.set_path(&candidate_path);
let response = match client
.get(discovery_url.clone())
.header(OAUTH_DISCOVERY_HEADER, OAUTH_DISCOVERY_VERSION)
.send()
.await
{
Ok(response) => response,
for candidate_url in oauth_auth_url_candidates(url) {
let base_url = match Url::parse(&candidate_url) {
Ok(base_url) => base_url,
Err(err) => {
last_error = Some(err.into());
debug!("Skipping OAuth discovery candidate `{candidate_url}`: {err:?}");
continue;
}
};
if response.status() != StatusCode::OK {
continue;
// Use no_proxy to avoid a bug in the system-configuration crate that
// can result in a panic. See #8912.
let builder = Client::builder().timeout(DISCOVERY_TIMEOUT).no_proxy();
let client = apply_default_headers(builder, default_headers).build()?;
let mut last_error: Option<Error> = None;
for candidate_path in discovery_paths(base_url.path()) {
let mut discovery_url = base_url.clone();
discovery_url.set_path(&candidate_path);
let response = match client
.get(discovery_url.clone())
.header(OAUTH_DISCOVERY_HEADER, OAUTH_DISCOVERY_VERSION)
.send()
.await
{
Ok(response) => response,
Err(err) => {
last_error = Some(err.into());
continue;
}
};
if response.status() != StatusCode::OK {
continue;
}
let metadata = match response.json::<OAuthDiscoveryMetadata>().await {
Ok(metadata) => metadata,
Err(err) => {
last_error = Some(err.into());
continue;
}
};
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
supports = true;
break;
}
}
let metadata = match response.json::<OAuthDiscoveryMetadata>().await {
Ok(metadata) => metadata,
Err(err) => {
last_error = Some(err.into());
continue;
}
};
if let Some(err) = last_error {
debug!("OAuth discovery requests failed for {candidate_url}: {err:?}");
}
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
return Ok(true);
if supports {
break;
}
}
if let Some(err) = last_error {
debug!("OAuth discovery requests failed for {url}: {err:?}");
if supports {
return Ok(true);
}
Ok(false)

View File

@@ -13,6 +13,7 @@ use tiny_http::Response;
use tiny_http::Server;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tracing::debug;
use urlencoding::decode;
use crate::OAuthCredentialsStoreMode;
@@ -22,6 +23,7 @@ use crate::oauth::compute_expires_at_millis;
use crate::save_oauth_tokens;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use crate::utils::oauth_auth_url_candidates;
struct OauthHeaders {
http_headers: Option<HashMap<String, String>>,
@@ -219,7 +221,7 @@ struct OauthLoginFlow {
rx: oneshot::Receiver<(String, String)>,
guard: CallbackServerGuard,
server_name: String,
server_url: String,
oauth_server_url: String,
store_mode: OAuthCredentialsStoreMode,
launch_browser: bool,
timeout: Duration,
@@ -288,12 +290,61 @@ impl OauthLoginFlow {
let default_headers = build_default_headers(http_headers, env_http_headers)?;
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
oauth_state
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
.await?;
let auth_url = oauth_state.get_authorization_url().await?;
let mut last_error: Option<anyhow::Error> = None;
let mut auth_url = None;
let mut oauth_state = None;
let mut oauth_server_url = None;
for candidate_url in oauth_auth_url_candidates(server_url) {
let mut state = match OAuthState::new(&candidate_url, Some(http_client.clone())).await {
Ok(state) => state,
Err(err) => {
last_error = Some(err.into());
continue;
}
};
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
if let Err(err) = state
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
.await
{
last_error = Some(err.into());
continue;
}
auth_url = match state.get_authorization_url().await {
Ok(candidate_auth_url) => {
oauth_state = Some(state);
oauth_server_url = Some(candidate_url);
Some(candidate_auth_url)
}
Err(err) => {
let err_msg = format!("{err:?}");
last_error = Some(anyhow!(err_msg.clone()));
debug!("OAuth state did not provide authorization URL: {err_msg}");
continue;
}
};
break;
}
let oauth_state = match oauth_state {
Some(state) => state,
None => {
return Err(last_error.unwrap_or_else(|| {
anyhow!("No usable OAuth auth endpoint found for MCP server")
}));
}
};
let auth_url = match auth_url {
Some(url) => url,
None => {
return Err(anyhow!(
"Internal OAuth login error: authorization URL not initialized after state setup"
));
}
};
let oauth_server_url = oauth_server_url.unwrap_or_else(|| server_url.to_string());
let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1);
let timeout = Duration::from_secs(timeout_secs as u64);
@@ -303,7 +354,7 @@ impl OauthLoginFlow {
rx,
guard,
server_name: server_name.to_string(),
server_url: server_url.to_string(),
oauth_server_url,
store_mode,
launch_browser,
timeout,
@@ -349,7 +400,7 @@ impl OauthLoginFlow {
let expires_at = compute_expires_at_millis(&credentials);
let stored = StoredOAuthTokens {
server_name: self.server_name.clone(),
url: self.server_url.clone(),
url: self.oauth_server_url.clone(),
client_id,
token_response: WrappedOAuthTokenResponse(credentials),
expires_at,

View File

@@ -59,6 +59,7 @@ use crate::program_resolver;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use crate::utils::create_env_for_mcp_server;
use crate::utils::oauth_auth_url_candidates;
use crate::utils::run_with_timeout;
enum PendingTransport {
@@ -244,21 +245,36 @@ impl RmcpClient {
) -> Result<Self> {
let default_headers = build_default_headers(http_headers, env_http_headers)?;
let initial_oauth_tokens = match bearer_token {
Some(_) => None,
None => match load_oauth_tokens(server_name, url, store_mode) {
Ok(tokens) => tokens,
Err(err) => {
warn!("failed to read tokens for server `{server_name}`: {err}");
None
let (initial_oauth_tokens, oauth_server_url) = match bearer_token {
Some(_) => (None, None),
None => {
let mut token_match = None;
for candidate_url in oauth_auth_url_candidates(url) {
match load_oauth_tokens(server_name, &candidate_url, store_mode) {
Ok(Some(tokens)) => {
token_match = Some((candidate_url, tokens));
break;
}
Ok(None) => {}
Err(err) => {
warn!(
"failed to read tokens for server `{server_name}` at `{candidate_url}`: {err}"
);
}
}
}
},
token_match.map_or((None, None), |(candidate_url, tokens)| {
(Some(tokens), Some(candidate_url))
})
}
};
let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() {
let oauth_url = oauth_server_url.as_deref().unwrap_or(url);
match create_oauth_transport_and_runtime(
server_name,
url,
oauth_url,
initial_tokens.clone(),
store_mode,
default_headers.clone(),
@@ -583,6 +599,7 @@ impl RmcpClient {
async fn create_oauth_transport_and_runtime(
server_name: &str,
url: &str,
oauth_url: &str,
initial_tokens: StoredOAuthTokens,
credentials_store: OAuthCredentialsStoreMode,
default_headers: HeaderMap,
@@ -592,7 +609,7 @@ async fn create_oauth_transport_and_runtime(
)> {
let http_client =
apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;
let mut oauth_state = OAuthState::new(oauth_url.to_string(), Some(http_client.clone())).await?;
oauth_state
.set_credentials(
@@ -619,7 +636,7 @@ async fn create_oauth_transport_and_runtime(
let runtime = OAuthPersistor::new(
server_name.to_string(),
url.to_string(),
oauth_url.to_string(),
auth_manager,
credentials_store,
Some(initial_tokens),

View File

@@ -6,6 +6,7 @@ use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use reqwest::ClientBuilder;
use reqwest::Url;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderName;
use reqwest::header::HeaderValue;
@@ -112,6 +113,33 @@ pub(crate) fn apply_default_headers(
}
}
pub(crate) fn oauth_auth_url_candidates(server_url: &str) -> Vec<String> {
let mut candidates = vec![server_url.to_string()];
let Ok(mut base_url) = Url::parse(server_url) else {
return candidates;
};
let mut trimmed_path = base_url.path().trim_end_matches('/').to_string();
if let Some(truncated) = trimmed_path.strip_suffix("/mcp") {
trimmed_path = truncated.to_string();
if trimmed_path.is_empty() {
base_url.set_path("/");
} else {
base_url.set_path(&trimmed_path);
}
base_url.set_query(None);
base_url.set_fragment(None);
let fallback = base_url.to_string().trim_end_matches('/').to_string();
if !candidates.contains(&fallback) {
candidates.push(fallback);
}
}
candidates
}
#[cfg(unix)]
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
"HOME",
@@ -215,4 +243,22 @@ mod tests {
let env = create_env_for_mcp_server(None, &[custom_var.to_string()]);
assert_eq!(env.get(custom_var), Some(&value.to_string()));
}
#[test]
fn oauth_auth_url_candidates_returns_input_when_not_mcp_path() {
let candidates = oauth_auth_url_candidates("https://example.com/api");
assert_eq!(candidates, vec!["https://example.com/api".to_string()]);
}
#[test]
fn oauth_auth_url_candidates_strips_trailing_mcp_segment() {
let candidates = oauth_auth_url_candidates("https://example.com/mcp");
assert_eq!(
candidates,
vec![
"https://example.com/mcp".to_string(),
"https://example.com".to_string()
]
);
}
}