mirror of
https://github.com/openai/codex.git
synced 2026-03-02 04:33:54 +00:00
Compare commits
1 Commits
codex/fast
...
dev/caseyc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
45e887e72f |
@@ -15,6 +15,7 @@ use crate::OAuthCredentialsStoreMode;
|
||||
use crate::oauth::has_oauth_tokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
use crate::utils::oauth_auth_url_candidates;
|
||||
|
||||
const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version";
|
||||
@@ -57,50 +58,69 @@ pub async fn supports_oauth_login(url: &str) -> Result<bool> {
|
||||
}
|
||||
|
||||
async fn supports_oauth_login_with_headers(url: &str, default_headers: &HeaderMap) -> Result<bool> {
|
||||
let base_url = Url::parse(url)?;
|
||||
let mut supports = false;
|
||||
|
||||
// Use no_proxy to avoid a bug in the system-configuration crate that
|
||||
// can result in a panic. See #8912.
|
||||
let builder = Client::builder().timeout(DISCOVERY_TIMEOUT).no_proxy();
|
||||
let client = apply_default_headers(builder, default_headers).build()?;
|
||||
|
||||
let mut last_error: Option<Error> = None;
|
||||
for candidate_path in discovery_paths(base_url.path()) {
|
||||
let mut discovery_url = base_url.clone();
|
||||
discovery_url.set_path(&candidate_path);
|
||||
|
||||
let response = match client
|
||||
.get(discovery_url.clone())
|
||||
.header(OAUTH_DISCOVERY_HEADER, OAUTH_DISCOVERY_VERSION)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
for candidate_url in oauth_auth_url_candidates(url) {
|
||||
let base_url = match Url::parse(&candidate_url) {
|
||||
Ok(base_url) => base_url,
|
||||
Err(err) => {
|
||||
last_error = Some(err.into());
|
||||
debug!("Skipping OAuth discovery candidate `{candidate_url}`: {err:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if response.status() != StatusCode::OK {
|
||||
continue;
|
||||
// Use no_proxy to avoid a bug in the system-configuration crate that
|
||||
// can result in a panic. See #8912.
|
||||
let builder = Client::builder().timeout(DISCOVERY_TIMEOUT).no_proxy();
|
||||
let client = apply_default_headers(builder, default_headers).build()?;
|
||||
|
||||
let mut last_error: Option<Error> = None;
|
||||
for candidate_path in discovery_paths(base_url.path()) {
|
||||
let mut discovery_url = base_url.clone();
|
||||
discovery_url.set_path(&candidate_path);
|
||||
|
||||
let response = match client
|
||||
.get(discovery_url.clone())
|
||||
.header(OAUTH_DISCOVERY_HEADER, OAUTH_DISCOVERY_VERSION)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
last_error = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if response.status() != StatusCode::OK {
|
||||
continue;
|
||||
}
|
||||
|
||||
let metadata = match response.json::<OAuthDiscoveryMetadata>().await {
|
||||
Ok(metadata) => metadata,
|
||||
Err(err) => {
|
||||
last_error = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
|
||||
supports = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let metadata = match response.json::<OAuthDiscoveryMetadata>().await {
|
||||
Ok(metadata) => metadata,
|
||||
Err(err) => {
|
||||
last_error = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if let Some(err) = last_error {
|
||||
debug!("OAuth discovery requests failed for {candidate_url}: {err:?}");
|
||||
}
|
||||
|
||||
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
|
||||
return Ok(true);
|
||||
if supports {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(err) = last_error {
|
||||
debug!("OAuth discovery requests failed for {url}: {err:?}");
|
||||
if supports {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
|
||||
@@ -13,6 +13,7 @@ use tiny_http::Response;
|
||||
use tiny_http::Server;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::timeout;
|
||||
use tracing::debug;
|
||||
use urlencoding::decode;
|
||||
|
||||
use crate::OAuthCredentialsStoreMode;
|
||||
@@ -22,6 +23,7 @@ use crate::oauth::compute_expires_at_millis;
|
||||
use crate::save_oauth_tokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
use crate::utils::oauth_auth_url_candidates;
|
||||
|
||||
struct OauthHeaders {
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
@@ -219,7 +221,7 @@ struct OauthLoginFlow {
|
||||
rx: oneshot::Receiver<(String, String)>,
|
||||
guard: CallbackServerGuard,
|
||||
server_name: String,
|
||||
server_url: String,
|
||||
oauth_server_url: String,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
launch_browser: bool,
|
||||
timeout: Duration,
|
||||
@@ -288,12 +290,61 @@ impl OauthLoginFlow {
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
|
||||
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
|
||||
oauth_state
|
||||
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
|
||||
.await?;
|
||||
let auth_url = oauth_state.get_authorization_url().await?;
|
||||
let mut last_error: Option<anyhow::Error> = None;
|
||||
let mut auth_url = None;
|
||||
let mut oauth_state = None;
|
||||
let mut oauth_server_url = None;
|
||||
|
||||
for candidate_url in oauth_auth_url_candidates(server_url) {
|
||||
let mut state = match OAuthState::new(&candidate_url, Some(http_client.clone())).await {
|
||||
Ok(state) => state,
|
||||
Err(err) => {
|
||||
last_error = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
|
||||
if let Err(err) = state
|
||||
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
|
||||
.await
|
||||
{
|
||||
last_error = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
|
||||
auth_url = match state.get_authorization_url().await {
|
||||
Ok(candidate_auth_url) => {
|
||||
oauth_state = Some(state);
|
||||
oauth_server_url = Some(candidate_url);
|
||||
Some(candidate_auth_url)
|
||||
}
|
||||
Err(err) => {
|
||||
let err_msg = format!("{err:?}");
|
||||
last_error = Some(anyhow!(err_msg.clone()));
|
||||
debug!("OAuth state did not provide authorization URL: {err_msg}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
break;
|
||||
}
|
||||
|
||||
let oauth_state = match oauth_state {
|
||||
Some(state) => state,
|
||||
None => {
|
||||
return Err(last_error.unwrap_or_else(|| {
|
||||
anyhow!("No usable OAuth auth endpoint found for MCP server")
|
||||
}));
|
||||
}
|
||||
};
|
||||
let auth_url = match auth_url {
|
||||
Some(url) => url,
|
||||
None => {
|
||||
return Err(anyhow!(
|
||||
"Internal OAuth login error: authorization URL not initialized after state setup"
|
||||
));
|
||||
}
|
||||
};
|
||||
let oauth_server_url = oauth_server_url.unwrap_or_else(|| server_url.to_string());
|
||||
let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1);
|
||||
let timeout = Duration::from_secs(timeout_secs as u64);
|
||||
|
||||
@@ -303,7 +354,7 @@ impl OauthLoginFlow {
|
||||
rx,
|
||||
guard,
|
||||
server_name: server_name.to_string(),
|
||||
server_url: server_url.to_string(),
|
||||
oauth_server_url,
|
||||
store_mode,
|
||||
launch_browser,
|
||||
timeout,
|
||||
@@ -349,7 +400,7 @@ impl OauthLoginFlow {
|
||||
let expires_at = compute_expires_at_millis(&credentials);
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: self.server_name.clone(),
|
||||
url: self.server_url.clone(),
|
||||
url: self.oauth_server_url.clone(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
expires_at,
|
||||
|
||||
@@ -59,6 +59,7 @@ use crate::program_resolver;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
use crate::utils::create_env_for_mcp_server;
|
||||
use crate::utils::oauth_auth_url_candidates;
|
||||
use crate::utils::run_with_timeout;
|
||||
|
||||
enum PendingTransport {
|
||||
@@ -244,21 +245,36 @@ impl RmcpClient {
|
||||
) -> Result<Self> {
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
|
||||
let initial_oauth_tokens = match bearer_token {
|
||||
Some(_) => None,
|
||||
None => match load_oauth_tokens(server_name, url, store_mode) {
|
||||
Ok(tokens) => tokens,
|
||||
Err(err) => {
|
||||
warn!("failed to read tokens for server `{server_name}`: {err}");
|
||||
None
|
||||
let (initial_oauth_tokens, oauth_server_url) = match bearer_token {
|
||||
Some(_) => (None, None),
|
||||
None => {
|
||||
let mut token_match = None;
|
||||
for candidate_url in oauth_auth_url_candidates(url) {
|
||||
match load_oauth_tokens(server_name, &candidate_url, store_mode) {
|
||||
Ok(Some(tokens)) => {
|
||||
token_match = Some((candidate_url, tokens));
|
||||
break;
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"failed to read tokens for server `{server_name}` at `{candidate_url}`: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
token_match.map_or((None, None), |(candidate_url, tokens)| {
|
||||
(Some(tokens), Some(candidate_url))
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() {
|
||||
let oauth_url = oauth_server_url.as_deref().unwrap_or(url);
|
||||
match create_oauth_transport_and_runtime(
|
||||
server_name,
|
||||
url,
|
||||
oauth_url,
|
||||
initial_tokens.clone(),
|
||||
store_mode,
|
||||
default_headers.clone(),
|
||||
@@ -583,6 +599,7 @@ impl RmcpClient {
|
||||
async fn create_oauth_transport_and_runtime(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
oauth_url: &str,
|
||||
initial_tokens: StoredOAuthTokens,
|
||||
credentials_store: OAuthCredentialsStoreMode,
|
||||
default_headers: HeaderMap,
|
||||
@@ -592,7 +609,7 @@ async fn create_oauth_transport_and_runtime(
|
||||
)> {
|
||||
let http_client =
|
||||
apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
|
||||
let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;
|
||||
let mut oauth_state = OAuthState::new(oauth_url.to_string(), Some(http_client.clone())).await?;
|
||||
|
||||
oauth_state
|
||||
.set_credentials(
|
||||
@@ -619,7 +636,7 @@ async fn create_oauth_transport_and_runtime(
|
||||
|
||||
let runtime = OAuthPersistor::new(
|
||||
server_name.to_string(),
|
||||
url.to_string(),
|
||||
oauth_url.to_string(),
|
||||
auth_manager,
|
||||
credentials_store,
|
||||
Some(initial_tokens),
|
||||
|
||||
@@ -6,6 +6,7 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use reqwest::ClientBuilder;
|
||||
use reqwest::Url;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderName;
|
||||
use reqwest::header::HeaderValue;
|
||||
@@ -112,6 +113,33 @@ pub(crate) fn apply_default_headers(
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn oauth_auth_url_candidates(server_url: &str) -> Vec<String> {
|
||||
let mut candidates = vec![server_url.to_string()];
|
||||
|
||||
let Ok(mut base_url) = Url::parse(server_url) else {
|
||||
return candidates;
|
||||
};
|
||||
|
||||
let mut trimmed_path = base_url.path().trim_end_matches('/').to_string();
|
||||
if let Some(truncated) = trimmed_path.strip_suffix("/mcp") {
|
||||
trimmed_path = truncated.to_string();
|
||||
if trimmed_path.is_empty() {
|
||||
base_url.set_path("/");
|
||||
} else {
|
||||
base_url.set_path(&trimmed_path);
|
||||
}
|
||||
base_url.set_query(None);
|
||||
base_url.set_fragment(None);
|
||||
|
||||
let fallback = base_url.to_string().trim_end_matches('/').to_string();
|
||||
if !candidates.contains(&fallback) {
|
||||
candidates.push(fallback);
|
||||
}
|
||||
}
|
||||
|
||||
candidates
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
|
||||
"HOME",
|
||||
@@ -215,4 +243,22 @@ mod tests {
|
||||
let env = create_env_for_mcp_server(None, &[custom_var.to_string()]);
|
||||
assert_eq!(env.get(custom_var), Some(&value.to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_auth_url_candidates_returns_input_when_not_mcp_path() {
|
||||
let candidates = oauth_auth_url_candidates("https://example.com/api");
|
||||
assert_eq!(candidates, vec!["https://example.com/api".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_auth_url_candidates_strips_trailing_mcp_segment() {
|
||||
let candidates = oauth_auth_url_candidates("https://example.com/mcp");
|
||||
assert_eq!(
|
||||
candidates,
|
||||
vec![
|
||||
"https://example.com/mcp".to_string(),
|
||||
"https://example.com".to_string()
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user