mirror of
https://github.com/openai/codex.git
synced 2026-05-11 23:02:39 +00:00
Compare commits
12 Commits
dev/bookho
...
dev/remote
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8e2beeafa | ||
|
|
35080d5cd0 | ||
|
|
9df8e3d935 | ||
|
|
da1000bbb4 | ||
|
|
ec6227c6cd | ||
|
|
520ee70b73 | ||
|
|
276a97340c | ||
|
|
2b68005d97 | ||
|
|
c2a2ffa512 | ||
|
|
a3abb8ba6e | ||
|
|
a4de53c661 | ||
|
|
177e39f4b7 |
2
codex-rs/Cargo.lock
generated
2
codex-rs/Cargo.lock
generated
@@ -1864,7 +1864,6 @@ dependencies = [
|
||||
"codex-models-manager",
|
||||
"codex-otel",
|
||||
"codex-protocol",
|
||||
"codex-rmcp-client",
|
||||
"codex-rollout",
|
||||
"codex-sandboxing",
|
||||
"codex-shell-command",
|
||||
@@ -3168,7 +3167,6 @@ dependencies = [
|
||||
"axum",
|
||||
"bytes",
|
||||
"codex-api",
|
||||
"codex-client",
|
||||
"codex-config",
|
||||
"codex-exec-server",
|
||||
"codex-keyring-store",
|
||||
|
||||
@@ -54,7 +54,6 @@ codex-models-manager = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-feedback = { workspace = true }
|
||||
codex-rmcp-client = { workspace = true }
|
||||
codex-rollout = { workspace = true }
|
||||
codex-sandboxing = { workspace = true }
|
||||
codex-state = { workspace = true }
|
||||
|
||||
@@ -310,10 +310,9 @@ use codex_mcp::McpRuntimeEnvironment;
|
||||
use codex_mcp::McpServerStatusSnapshot;
|
||||
use codex_mcp::McpSnapshotDetail;
|
||||
use codex_mcp::collect_mcp_server_status_snapshot_with_detail;
|
||||
use codex_mcp::discover_supported_scopes;
|
||||
use codex_mcp::effective_mcp_servers;
|
||||
use codex_mcp::perform_oauth_login_return_url_for_server;
|
||||
use codex_mcp::read_mcp_resource as read_mcp_resource_without_thread;
|
||||
use codex_mcp::resolve_oauth_scopes;
|
||||
use codex_model_provider::ProviderAccountError;
|
||||
use codex_model_provider::create_model_provider;
|
||||
use codex_models_manager::collaboration_mode_presets::CollaborationModesConfig;
|
||||
@@ -355,7 +354,6 @@ use codex_protocol::protocol::USER_MESSAGE_BEGIN;
|
||||
use codex_protocol::protocol::W3cTraceContext;
|
||||
use codex_protocol::user_input::MAX_USER_INPUT_TEXT_CHARS;
|
||||
use codex_protocol::user_input::UserInput as CoreInputItem;
|
||||
use codex_rmcp_client::perform_oauth_login_return_url;
|
||||
use codex_rollout::state_db::StateDbHandle;
|
||||
use codex_rollout::state_db::get_state_db;
|
||||
use codex_rollout::state_db::reconcile_rollout;
|
||||
@@ -5941,13 +5939,8 @@ impl CodexMessageProcessor {
|
||||
return;
|
||||
};
|
||||
|
||||
let (url, http_headers, env_http_headers) = match &server.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
..
|
||||
} => (url.clone(), http_headers.clone(), env_http_headers.clone()),
|
||||
match &server.transport {
|
||||
McpServerTransportConfig::StreamableHttp { .. } => {}
|
||||
_ => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
@@ -5960,25 +5953,24 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
};
|
||||
|
||||
let discovered_scopes = if scopes.is_none() && server.scopes.is_none() {
|
||||
discover_supported_scopes(&server.transport).await
|
||||
} else {
|
||||
None
|
||||
let environment_manager = self.thread_manager.environment_manager();
|
||||
let runtime_environment = match environment_manager.default_environment() {
|
||||
Some(environment) => McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf()),
|
||||
None => McpRuntimeEnvironment::new(
|
||||
environment_manager.local_environment(),
|
||||
config.cwd.to_path_buf(),
|
||||
),
|
||||
};
|
||||
let resolved_scopes =
|
||||
resolve_oauth_scopes(scopes, server.scopes.clone(), discovered_scopes);
|
||||
|
||||
match perform_oauth_login_return_url(
|
||||
match perform_oauth_login_return_url_for_server(
|
||||
&name,
|
||||
&url,
|
||||
server,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
&resolved_scopes.scopes,
|
||||
server.oauth_resource.as_deref(),
|
||||
scopes,
|
||||
timeout_secs,
|
||||
config.mcp_oauth_callback_port,
|
||||
config.mcp_oauth_callback_url.as_deref(),
|
||||
runtime_environment,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -5,12 +5,9 @@ use codex_app_server_protocol::McpServerOauthLoginCompletedNotification;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_config::types::McpServerConfig;
|
||||
use codex_core::config::Config;
|
||||
use codex_mcp::McpOAuthLoginSupport;
|
||||
use codex_mcp::oauth_login_support;
|
||||
use codex_mcp::resolve_oauth_scopes;
|
||||
use codex_mcp::should_retry_without_scopes;
|
||||
use codex_rmcp_client::perform_oauth_login_silent;
|
||||
use tracing::warn;
|
||||
use codex_mcp::McpOAuthLoginOutcome;
|
||||
use codex_mcp::McpRuntimeEnvironment;
|
||||
use codex_mcp::perform_oauth_login_silent_for_server;
|
||||
|
||||
use super::CodexMessageProcessor;
|
||||
|
||||
@@ -21,23 +18,17 @@ impl CodexMessageProcessor {
|
||||
plugin_mcp_servers: HashMap<String, McpServerConfig>,
|
||||
) {
|
||||
for (name, server) in plugin_mcp_servers {
|
||||
let oauth_config = match oauth_login_support(&server.transport).await {
|
||||
McpOAuthLoginSupport::Supported(config) => config,
|
||||
McpOAuthLoginSupport::Unsupported => continue,
|
||||
McpOAuthLoginSupport::Unknown(err) => {
|
||||
warn!(
|
||||
"MCP server may or may not require login for plugin install {name}: {err}"
|
||||
);
|
||||
continue;
|
||||
let environment_manager = self.thread_manager.environment_manager();
|
||||
let runtime_environment = match environment_manager.default_environment() {
|
||||
Some(environment) => {
|
||||
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf())
|
||||
}
|
||||
None => McpRuntimeEnvironment::new(
|
||||
environment_manager.local_environment(),
|
||||
config.cwd.to_path_buf(),
|
||||
),
|
||||
};
|
||||
|
||||
let resolved_scopes = resolve_oauth_scopes(
|
||||
/*explicit_scopes*/ None,
|
||||
server.scopes.clone(),
|
||||
oauth_config.discovered_scopes.clone(),
|
||||
);
|
||||
|
||||
let store_mode = config.mcp_oauth_credentials_store_mode;
|
||||
let callback_port = config.mcp_oauth_callback_port;
|
||||
let callback_url = config.mcp_oauth_callback_url.clone();
|
||||
@@ -45,39 +36,20 @@ impl CodexMessageProcessor {
|
||||
let notification_name = name.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let first_attempt = perform_oauth_login_silent(
|
||||
let final_result = perform_oauth_login_silent_for_server(
|
||||
&name,
|
||||
&oauth_config.url,
|
||||
&server,
|
||||
store_mode,
|
||||
oauth_config.http_headers.clone(),
|
||||
oauth_config.env_http_headers.clone(),
|
||||
&resolved_scopes.scopes,
|
||||
server.oauth_resource.as_deref(),
|
||||
/*explicit_scopes*/ None,
|
||||
callback_port,
|
||||
callback_url.as_deref(),
|
||||
runtime_environment,
|
||||
)
|
||||
.await;
|
||||
|
||||
let final_result = match first_attempt {
|
||||
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
|
||||
perform_oauth_login_silent(
|
||||
&name,
|
||||
&oauth_config.url,
|
||||
store_mode,
|
||||
oauth_config.http_headers,
|
||||
oauth_config.env_http_headers,
|
||||
&[],
|
||||
server.oauth_resource.as_deref(),
|
||||
callback_port,
|
||||
callback_url.as_deref(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
result => result,
|
||||
};
|
||||
|
||||
let (success, error) = match final_result {
|
||||
Ok(()) => (true, None),
|
||||
Ok(McpOAuthLoginOutcome::Completed) => (true, None),
|
||||
Ok(McpOAuthLoginOutcome::Unsupported) => return,
|
||||
Err(err) => (false, Some(err.to_string())),
|
||||
};
|
||||
|
||||
|
||||
@@ -15,16 +15,19 @@ use codex_core::config::edit::ConfigEditsBuilder;
|
||||
use codex_core::config::find_codex_home;
|
||||
use codex_core::config::load_global_mcp_servers;
|
||||
use codex_core::plugins::PluginsManager;
|
||||
use codex_exec_server::CODEX_EXEC_SERVER_URL_ENV_VAR;
|
||||
use codex_exec_server::Environment;
|
||||
use codex_exec_server::EnvironmentManager;
|
||||
use codex_exec_server::EnvironmentManagerArgs;
|
||||
use codex_exec_server::ExecServerRuntimePaths;
|
||||
use codex_mcp::McpOAuthLoginOutcome;
|
||||
use codex_mcp::McpOAuthLoginSupport;
|
||||
use codex_mcp::ResolvedMcpOAuthScopes;
|
||||
use codex_mcp::McpRuntimeEnvironment;
|
||||
use codex_mcp::compute_auth_statuses;
|
||||
use codex_mcp::discover_supported_scopes;
|
||||
use codex_mcp::oauth_login_support;
|
||||
use codex_mcp::resolve_oauth_scopes;
|
||||
use codex_mcp::should_retry_without_scopes;
|
||||
use codex_mcp::oauth_login_support_for_server;
|
||||
use codex_mcp::perform_oauth_login_for_server;
|
||||
use codex_protocol::protocol::McpAuthStatus;
|
||||
use codex_rmcp_client::delete_oauth_tokens;
|
||||
use codex_rmcp_client::perform_oauth_login;
|
||||
use codex_utils_cli::CliConfigOverrides;
|
||||
use codex_utils_cli::format_env_display;
|
||||
|
||||
@@ -188,54 +191,6 @@ impl McpCli {
|
||||
}
|
||||
}
|
||||
|
||||
/// Preserve compatibility with servers that still expect the legacy empty-scope
|
||||
/// OAuth request. If a discovered-scope request is rejected by the provider,
|
||||
/// retry the login flow once without scopes.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn perform_oauth_login_retry_without_scopes(
|
||||
name: &str,
|
||||
url: &str,
|
||||
store_mode: codex_config::types::OAuthCredentialsStoreMode,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
resolved_scopes: &ResolvedMcpOAuthScopes,
|
||||
oauth_resource: Option<&str>,
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
) -> Result<()> {
|
||||
match perform_oauth_login(
|
||||
name,
|
||||
url,
|
||||
store_mode,
|
||||
http_headers.clone(),
|
||||
env_http_headers.clone(),
|
||||
&resolved_scopes.scopes,
|
||||
oauth_resource,
|
||||
callback_port,
|
||||
callback_url,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => Ok(()),
|
||||
Err(err) if should_retry_without_scopes(resolved_scopes, &err) => {
|
||||
println!("OAuth provider rejected discovered scopes. Retrying without scopes…");
|
||||
perform_oauth_login(
|
||||
name,
|
||||
url,
|
||||
store_mode,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
&[],
|
||||
oauth_resource,
|
||||
callback_port,
|
||||
callback_url,
|
||||
)
|
||||
.await
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
|
||||
// Validate any provided overrides even though they are not currently applied.
|
||||
let overrides = config_overrides
|
||||
@@ -297,7 +252,7 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re
|
||||
};
|
||||
|
||||
let new_entry = McpServerConfig {
|
||||
transport: transport.clone(),
|
||||
transport,
|
||||
experimental_environment: None,
|
||||
enabled: true,
|
||||
required: false,
|
||||
@@ -313,7 +268,7 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re
|
||||
tools: HashMap::new(),
|
||||
};
|
||||
|
||||
servers.insert(name.clone(), new_entry);
|
||||
servers.insert(name.clone(), new_entry.clone());
|
||||
|
||||
ConfigEditsBuilder::new(&codex_home)
|
||||
.replace_mcp_servers(&servers)
|
||||
@@ -323,27 +278,24 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re
|
||||
|
||||
println!("Added global MCP server '{name}'.");
|
||||
|
||||
match oauth_login_support(&transport).await {
|
||||
McpOAuthLoginSupport::Supported(oauth_config) => {
|
||||
let runtime_environment = runtime_environment_from_config(&config);
|
||||
match oauth_login_support_for_server(&new_entry, runtime_environment.clone()).await {
|
||||
McpOAuthLoginSupport::Supported(_) => {
|
||||
println!("Detected OAuth support. Starting OAuth flow…");
|
||||
let resolved_scopes = resolve_oauth_scopes(
|
||||
/*explicit_scopes*/ None,
|
||||
/*configured_scopes*/ None,
|
||||
oauth_config.discovered_scopes.clone(),
|
||||
);
|
||||
perform_oauth_login_retry_without_scopes(
|
||||
match perform_oauth_login_for_server(
|
||||
&name,
|
||||
&oauth_config.url,
|
||||
&new_entry,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
oauth_config.http_headers,
|
||||
oauth_config.env_http_headers,
|
||||
&resolved_scopes,
|
||||
/*oauth_resource*/ None,
|
||||
/*explicit_scopes*/ None,
|
||||
config.mcp_oauth_callback_port,
|
||||
config.mcp_oauth_callback_url.as_deref(),
|
||||
runtime_environment,
|
||||
)
|
||||
.await?;
|
||||
println!("Successfully logged in.");
|
||||
.await?
|
||||
{
|
||||
McpOAuthLoginOutcome::Completed => println!("Successfully logged in."),
|
||||
McpOAuthLoginOutcome::Unsupported => {}
|
||||
}
|
||||
}
|
||||
McpOAuthLoginSupport::Unsupported => {}
|
||||
McpOAuthLoginSupport::Unknown(_) => println!(
|
||||
@@ -405,38 +357,32 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs)
|
||||
bail!("No MCP server named '{name}' found.");
|
||||
};
|
||||
|
||||
let (url, http_headers, env_http_headers) = match &server.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
..
|
||||
} => (url.clone(), http_headers.clone(), env_http_headers.clone()),
|
||||
_ => bail!("OAuth login is only supported for streamable HTTP servers."),
|
||||
if !matches!(
|
||||
&server.transport,
|
||||
McpServerTransportConfig::StreamableHttp { .. }
|
||||
) {
|
||||
bail!("OAuth login is only supported for streamable HTTP servers.");
|
||||
};
|
||||
|
||||
let explicit_scopes = (!scopes.is_empty()).then_some(scopes);
|
||||
let discovered_scopes = if explicit_scopes.is_none() && server.scopes.is_none() {
|
||||
discover_supported_scopes(&server.transport).await
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let resolved_scopes =
|
||||
resolve_oauth_scopes(explicit_scopes, server.scopes.clone(), discovered_scopes);
|
||||
|
||||
perform_oauth_login_retry_without_scopes(
|
||||
match perform_oauth_login_for_server(
|
||||
&name,
|
||||
&url,
|
||||
server,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
&resolved_scopes,
|
||||
server.oauth_resource.as_deref(),
|
||||
explicit_scopes,
|
||||
config.mcp_oauth_callback_port,
|
||||
config.mcp_oauth_callback_url.as_deref(),
|
||||
runtime_environment_from_config(&config),
|
||||
)
|
||||
.await?;
|
||||
println!("Successfully logged in to MCP server '{name}'.");
|
||||
.await?
|
||||
{
|
||||
McpOAuthLoginOutcome::Completed => {
|
||||
println!("Successfully logged in to MCP server '{name}'.")
|
||||
}
|
||||
McpOAuthLoginOutcome::Unsupported => {
|
||||
bail!("MCP server '{name}' does not advertise OAuth support.")
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -472,6 +418,27 @@ async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutAr
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn runtime_environment_from_config(config: &Config) -> McpRuntimeEnvironment {
|
||||
let local_runtime_paths = ExecServerRuntimePaths::from_optional_paths(
|
||||
config.codex_self_exe.clone(),
|
||||
config.codex_linux_sandbox_exe.clone(),
|
||||
);
|
||||
let environment = match local_runtime_paths {
|
||||
Ok(local_runtime_paths) => {
|
||||
let environment_manager =
|
||||
EnvironmentManager::new(EnvironmentManagerArgs::from_env(local_runtime_paths));
|
||||
environment_manager
|
||||
.default_environment()
|
||||
.unwrap_or_else(|| environment_manager.local_environment())
|
||||
}
|
||||
Err(_) => Arc::new(
|
||||
Environment::create_for_tests(std::env::var(CODEX_EXEC_SERVER_URL_ENV_VAR).ok())
|
||||
.unwrap_or_else(|_| Environment::default_for_tests()),
|
||||
),
|
||||
};
|
||||
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf())
|
||||
}
|
||||
|
||||
async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
|
||||
let overrides = config_overrides
|
||||
.parse_overrides()
|
||||
@@ -483,6 +450,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
config.codex_home.to_path_buf(),
|
||||
)));
|
||||
let mcp_servers = mcp_manager.effective_servers(&config, /*auth*/ None).await;
|
||||
let runtime_environment = runtime_environment_from_config(&config);
|
||||
|
||||
let mut entries: Vec<_> = mcp_servers.iter().collect();
|
||||
entries.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
@@ -490,6 +458,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
|
||||
mcp_servers.iter(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
/*auth*/ None,
|
||||
runtime_environment,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
||||
@@ -24,12 +24,19 @@ pub use mcp::read_mcp_resource;
|
||||
|
||||
pub use mcp::McpAuthStatusEntry;
|
||||
pub use mcp::McpOAuthLoginConfig;
|
||||
pub use mcp::McpOAuthLoginOutcome;
|
||||
pub use mcp::McpOAuthLoginSupport;
|
||||
pub use mcp::McpOAuthScopesSource;
|
||||
pub use mcp::ResolvedMcpOAuthScopes;
|
||||
pub use mcp::compute_auth_statuses;
|
||||
pub use mcp::discover_supported_scopes;
|
||||
pub use mcp::discover_supported_scopes_for_server;
|
||||
pub use mcp::http_client_for_server;
|
||||
pub use mcp::oauth_login_support;
|
||||
pub use mcp::oauth_login_support_for_server;
|
||||
pub use mcp::perform_oauth_login_for_server;
|
||||
pub use mcp::perform_oauth_login_return_url_for_server;
|
||||
pub use mcp::perform_oauth_login_silent_for_server;
|
||||
pub use mcp::resolve_oauth_scopes;
|
||||
pub use mcp::should_retry_without_scopes;
|
||||
|
||||
|
||||
@@ -4,14 +4,28 @@ use anyhow::Result;
|
||||
use codex_config::McpServerConfig;
|
||||
use codex_config::McpServerTransportConfig;
|
||||
use codex_config::types::OAuthCredentialsStoreMode;
|
||||
use codex_exec_server::HttpClient;
|
||||
use codex_exec_server::ReqwestHttpClient;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_protocol::protocol::McpAuthStatus;
|
||||
use codex_rmcp_client::OAuthProviderError;
|
||||
use codex_rmcp_client::OauthLoginHandle;
|
||||
use codex_rmcp_client::determine_streamable_http_auth_status;
|
||||
use codex_rmcp_client::determine_streamable_http_auth_status_with_client;
|
||||
use codex_rmcp_client::discover_streamable_http_oauth;
|
||||
use codex_rmcp_client::discover_streamable_http_oauth_with_client;
|
||||
use codex_rmcp_client::perform_oauth_login;
|
||||
use codex_rmcp_client::perform_oauth_login_return_url;
|
||||
use codex_rmcp_client::perform_oauth_login_return_url_with_client;
|
||||
use codex_rmcp_client::perform_oauth_login_silent;
|
||||
use codex_rmcp_client::perform_oauth_login_silent_with_client;
|
||||
use codex_rmcp_client::perform_oauth_login_with_client;
|
||||
use futures::future::join_all;
|
||||
use std::sync::Arc;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::runtime::McpRuntimeEnvironment;
|
||||
|
||||
use super::CODEX_APPS_MCP_SERVER_NAME;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -49,7 +63,65 @@ pub struct McpAuthStatusEntry {
|
||||
pub auth_status: McpAuthStatus,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum McpOAuthLoginOutcome {
|
||||
Completed,
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum McpOAuthHttpClient {
|
||||
LocalNoProxy,
|
||||
Runtime(Arc<dyn HttpClient>),
|
||||
}
|
||||
|
||||
impl McpOAuthHttpClient {
|
||||
async fn login_support(&self, transport: &McpServerTransportConfig) -> McpOAuthLoginSupport {
|
||||
match self {
|
||||
Self::LocalNoProxy => oauth_login_support(transport).await,
|
||||
Self::Runtime(http_client) => {
|
||||
oauth_login_support_with_client(transport, http_client.clone()).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn discover_supported_scopes(
|
||||
&self,
|
||||
transport: &McpServerTransportConfig,
|
||||
) -> Option<Vec<String>> {
|
||||
match self.login_support(transport).await {
|
||||
McpOAuthLoginSupport::Supported(config) => config.discovered_scopes,
|
||||
McpOAuthLoginSupport::Unsupported | McpOAuthLoginSupport::Unknown(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn oauth_login_support(transport: &McpServerTransportConfig) -> McpOAuthLoginSupport {
|
||||
oauth_login_support_with_discovery(transport).await
|
||||
}
|
||||
|
||||
pub async fn oauth_login_support_for_server(
|
||||
config: &McpServerConfig,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> McpOAuthLoginSupport {
|
||||
match config.experimental_environment.as_deref() {
|
||||
Some("remote") => {
|
||||
let http_client = match http_client_for_server(config, runtime_environment) {
|
||||
Ok(http_client) => http_client,
|
||||
Err(err) => return McpOAuthLoginSupport::Unknown(err),
|
||||
};
|
||||
oauth_login_support_with_client(&config.transport, http_client).await
|
||||
}
|
||||
None | Some("local") => oauth_login_support(&config.transport).await,
|
||||
Some(environment) => McpOAuthLoginSupport::Unknown(anyhow::anyhow!(
|
||||
"unsupported experimental_environment `{environment}`"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn oauth_login_support_with_discovery(
|
||||
transport: &McpServerTransportConfig,
|
||||
) -> McpOAuthLoginSupport {
|
||||
let McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
@@ -77,6 +149,43 @@ pub async fn oauth_login_support(transport: &McpServerTransportConfig) -> McpOAu
|
||||
}
|
||||
}
|
||||
|
||||
async fn oauth_login_support_with_client(
|
||||
transport: &McpServerTransportConfig,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> McpOAuthLoginSupport {
|
||||
let McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} = transport
|
||||
else {
|
||||
return McpOAuthLoginSupport::Unsupported;
|
||||
};
|
||||
|
||||
if bearer_token_env_var.is_some() {
|
||||
return McpOAuthLoginSupport::Unsupported;
|
||||
}
|
||||
|
||||
match discover_streamable_http_oauth_with_client(
|
||||
url,
|
||||
http_headers.clone(),
|
||||
env_http_headers.clone(),
|
||||
http_client,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Some(discovery)) => McpOAuthLoginSupport::Supported(McpOAuthLoginConfig {
|
||||
url: url.clone(),
|
||||
http_headers: http_headers.clone(),
|
||||
env_http_headers: env_http_headers.clone(),
|
||||
discovered_scopes: discovery.scopes_supported,
|
||||
}),
|
||||
Ok(None) => McpOAuthLoginSupport::Unsupported,
|
||||
Err(err) => McpOAuthLoginSupport::Unknown(err),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn discover_supported_scopes(
|
||||
transport: &McpServerTransportConfig,
|
||||
) -> Option<Vec<String>> {
|
||||
@@ -86,6 +195,16 @@ pub async fn discover_supported_scopes(
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn discover_supported_scopes_for_server(
|
||||
config: &McpServerConfig,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> Option<Vec<String>> {
|
||||
match oauth_login_support_for_server(config, runtime_environment).await {
|
||||
McpOAuthLoginSupport::Supported(config) => config.discovered_scopes,
|
||||
McpOAuthLoginSupport::Unsupported | McpOAuthLoginSupport::Unknown(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolve_oauth_scopes(
|
||||
explicit_scopes: Option<Vec<String>>,
|
||||
configured_scopes: Option<Vec<String>>,
|
||||
@@ -125,10 +244,264 @@ pub fn should_retry_without_scopes(scopes: &ResolvedMcpOAuthScopes, error: &anyh
|
||||
&& error.downcast_ref::<OAuthProviderError>().is_some()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn perform_oauth_login_return_url_for_server(
|
||||
server_name: &str,
|
||||
config: &McpServerConfig,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
explicit_scopes: Option<Vec<String>>,
|
||||
timeout_secs: Option<i64>,
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> Result<OauthLoginHandle> {
|
||||
let McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
..
|
||||
} = &config.transport
|
||||
else {
|
||||
anyhow::bail!("OAuth login is only supported for streamable HTTP servers.");
|
||||
};
|
||||
|
||||
let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?;
|
||||
let discovered_scopes = if explicit_scopes.is_none() && config.scopes.is_none() {
|
||||
oauth_http_client
|
||||
.discover_supported_scopes(&config.transport)
|
||||
.await
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let resolved_scopes =
|
||||
resolve_oauth_scopes(explicit_scopes, config.scopes.clone(), discovered_scopes);
|
||||
|
||||
match oauth_http_client {
|
||||
McpOAuthHttpClient::LocalNoProxy => {
|
||||
perform_oauth_login_return_url(
|
||||
server_name,
|
||||
url,
|
||||
store_mode,
|
||||
http_headers.clone(),
|
||||
env_http_headers.clone(),
|
||||
&resolved_scopes.scopes,
|
||||
config.oauth_resource.as_deref(),
|
||||
timeout_secs,
|
||||
callback_port,
|
||||
callback_url,
|
||||
)
|
||||
.await
|
||||
}
|
||||
McpOAuthHttpClient::Runtime(http_client) => {
|
||||
perform_oauth_login_return_url_with_client(
|
||||
server_name,
|
||||
url,
|
||||
store_mode,
|
||||
http_headers.clone(),
|
||||
env_http_headers.clone(),
|
||||
&resolved_scopes.scopes,
|
||||
config.oauth_resource.as_deref(),
|
||||
timeout_secs,
|
||||
callback_port,
|
||||
callback_url,
|
||||
http_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn perform_oauth_login_silent_for_server(
|
||||
server_name: &str,
|
||||
config: &McpServerConfig,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
explicit_scopes: Option<Vec<String>>,
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> Result<McpOAuthLoginOutcome> {
|
||||
let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?;
|
||||
let oauth_config = match oauth_http_client.login_support(&config.transport).await {
|
||||
McpOAuthLoginSupport::Supported(config) => config,
|
||||
McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported),
|
||||
McpOAuthLoginSupport::Unknown(err) => return Err(err),
|
||||
};
|
||||
|
||||
let resolved_scopes = resolve_oauth_scopes(
|
||||
explicit_scopes,
|
||||
config.scopes.clone(),
|
||||
oauth_config.discovered_scopes.clone(),
|
||||
);
|
||||
|
||||
let final_result = match oauth_http_client {
|
||||
McpOAuthHttpClient::LocalNoProxy => {
|
||||
let first_attempt = perform_oauth_login_silent(
|
||||
server_name,
|
||||
&oauth_config.url,
|
||||
store_mode,
|
||||
oauth_config.http_headers.clone(),
|
||||
oauth_config.env_http_headers.clone(),
|
||||
&resolved_scopes.scopes,
|
||||
config.oauth_resource.as_deref(),
|
||||
callback_port,
|
||||
callback_url,
|
||||
)
|
||||
.await;
|
||||
match first_attempt {
|
||||
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
|
||||
perform_oauth_login_silent(
|
||||
server_name,
|
||||
&oauth_config.url,
|
||||
store_mode,
|
||||
oauth_config.http_headers,
|
||||
oauth_config.env_http_headers,
|
||||
&[],
|
||||
config.oauth_resource.as_deref(),
|
||||
callback_port,
|
||||
callback_url,
|
||||
)
|
||||
.await
|
||||
}
|
||||
result => result,
|
||||
}
|
||||
}
|
||||
McpOAuthHttpClient::Runtime(http_client) => {
|
||||
let first_attempt = perform_oauth_login_silent_with_client(
|
||||
server_name,
|
||||
&oauth_config.url,
|
||||
store_mode,
|
||||
oauth_config.http_headers.clone(),
|
||||
oauth_config.env_http_headers.clone(),
|
||||
&resolved_scopes.scopes,
|
||||
config.oauth_resource.as_deref(),
|
||||
callback_port,
|
||||
callback_url,
|
||||
http_client.clone(),
|
||||
)
|
||||
.await;
|
||||
match first_attempt {
|
||||
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
|
||||
perform_oauth_login_silent_with_client(
|
||||
server_name,
|
||||
&oauth_config.url,
|
||||
store_mode,
|
||||
oauth_config.http_headers,
|
||||
oauth_config.env_http_headers,
|
||||
&[],
|
||||
config.oauth_resource.as_deref(),
|
||||
callback_port,
|
||||
callback_url,
|
||||
http_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
result => result,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
final_result.map(|()| McpOAuthLoginOutcome::Completed)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn perform_oauth_login_for_server(
|
||||
server_name: &str,
|
||||
config: &McpServerConfig,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
explicit_scopes: Option<Vec<String>>,
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> Result<McpOAuthLoginOutcome> {
|
||||
let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?;
|
||||
let oauth_config = match oauth_http_client.login_support(&config.transport).await {
|
||||
McpOAuthLoginSupport::Supported(config) => config,
|
||||
McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported),
|
||||
McpOAuthLoginSupport::Unknown(err) => return Err(err),
|
||||
};
|
||||
|
||||
let resolved_scopes = resolve_oauth_scopes(
|
||||
explicit_scopes,
|
||||
config.scopes.clone(),
|
||||
oauth_config.discovered_scopes.clone(),
|
||||
);
|
||||
|
||||
let final_result = match oauth_http_client {
|
||||
McpOAuthHttpClient::LocalNoProxy => {
|
||||
let first_attempt = perform_oauth_login(
|
||||
server_name,
|
||||
&oauth_config.url,
|
||||
store_mode,
|
||||
oauth_config.http_headers.clone(),
|
||||
oauth_config.env_http_headers.clone(),
|
||||
&resolved_scopes.scopes,
|
||||
config.oauth_resource.as_deref(),
|
||||
callback_port,
|
||||
callback_url,
|
||||
)
|
||||
.await;
|
||||
match first_attempt {
|
||||
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
|
||||
perform_oauth_login(
|
||||
server_name,
|
||||
&oauth_config.url,
|
||||
store_mode,
|
||||
oauth_config.http_headers,
|
||||
oauth_config.env_http_headers,
|
||||
&[],
|
||||
config.oauth_resource.as_deref(),
|
||||
callback_port,
|
||||
callback_url,
|
||||
)
|
||||
.await
|
||||
}
|
||||
result => result,
|
||||
}
|
||||
}
|
||||
McpOAuthHttpClient::Runtime(http_client) => {
|
||||
let first_attempt = perform_oauth_login_with_client(
|
||||
server_name,
|
||||
&oauth_config.url,
|
||||
store_mode,
|
||||
oauth_config.http_headers.clone(),
|
||||
oauth_config.env_http_headers.clone(),
|
||||
&resolved_scopes.scopes,
|
||||
config.oauth_resource.as_deref(),
|
||||
callback_port,
|
||||
callback_url,
|
||||
http_client.clone(),
|
||||
)
|
||||
.await;
|
||||
match first_attempt {
|
||||
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
|
||||
perform_oauth_login_with_client(
|
||||
server_name,
|
||||
&oauth_config.url,
|
||||
store_mode,
|
||||
oauth_config.http_headers,
|
||||
oauth_config.env_http_headers,
|
||||
&[],
|
||||
config.oauth_resource.as_deref(),
|
||||
callback_port,
|
||||
callback_url,
|
||||
http_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
result => result,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
final_result.map(|()| McpOAuthLoginOutcome::Completed)
|
||||
}
|
||||
|
||||
pub async fn compute_auth_statuses<'a, I>(
|
||||
servers: I,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
auth: Option<&CodexAuth>,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> HashMap<String, McpAuthStatusEntry>
|
||||
where
|
||||
I: IntoIterator<Item = (&'a String, &'a McpServerConfig)>,
|
||||
@@ -136,6 +509,7 @@ where
|
||||
let futures = servers.into_iter().map(|(name, config)| {
|
||||
let name = name.clone();
|
||||
let config = config.clone();
|
||||
let runtime_environment = runtime_environment.clone();
|
||||
let has_runtime_auth = name == CODEX_APPS_MCP_SERVER_NAME
|
||||
&& auth.is_some_and(CodexAuth::uses_codex_backend)
|
||||
&& matches!(
|
||||
@@ -146,14 +520,21 @@ where
|
||||
}
|
||||
);
|
||||
async move {
|
||||
let auth_status =
|
||||
match compute_auth_status(&name, &config, store_mode, has_runtime_auth).await {
|
||||
Ok(status) => status,
|
||||
Err(error) => {
|
||||
warn!("failed to determine auth status for MCP server `{name}`: {error:?}");
|
||||
McpAuthStatus::Unsupported
|
||||
}
|
||||
};
|
||||
let auth_status = match compute_auth_status(
|
||||
&name,
|
||||
&config,
|
||||
store_mode,
|
||||
has_runtime_auth,
|
||||
runtime_environment,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(status) => status,
|
||||
Err(error) => {
|
||||
warn!("failed to determine auth status for MCP server `{name}`: {error:?}");
|
||||
McpAuthStatus::Unsupported
|
||||
}
|
||||
};
|
||||
let entry = McpAuthStatusEntry {
|
||||
config,
|
||||
auth_status,
|
||||
@@ -170,6 +551,7 @@ async fn compute_auth_status(
|
||||
config: &McpServerConfig,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
has_runtime_auth: bool,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> Result<McpAuthStatus> {
|
||||
if !config.enabled {
|
||||
return Ok(McpAuthStatus::Unsupported);
|
||||
@@ -186,28 +568,88 @@ async fn compute_auth_status(
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} => {
|
||||
determine_streamable_http_auth_status(
|
||||
server_name,
|
||||
url,
|
||||
bearer_token_env_var.as_deref(),
|
||||
http_headers.clone(),
|
||||
env_http_headers.clone(),
|
||||
store_mode,
|
||||
)
|
||||
.await
|
||||
} => match config.experimental_environment.as_deref() {
|
||||
Some("remote") => {
|
||||
let http_client = http_client_for_server(config, runtime_environment)?;
|
||||
determine_streamable_http_auth_status_with_client(
|
||||
server_name,
|
||||
url,
|
||||
bearer_token_env_var.as_deref(),
|
||||
http_headers.clone(),
|
||||
env_http_headers.clone(),
|
||||
store_mode,
|
||||
http_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
None | Some("local") => {
|
||||
determine_streamable_http_auth_status(
|
||||
server_name,
|
||||
url,
|
||||
bearer_token_env_var.as_deref(),
|
||||
http_headers.clone(),
|
||||
env_http_headers.clone(),
|
||||
store_mode,
|
||||
)
|
||||
.await
|
||||
}
|
||||
Some(environment) => {
|
||||
anyhow::bail!("unsupported experimental_environment `{environment}`")
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn http_client_for_server(
|
||||
config: &McpServerConfig,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> Result<Arc<dyn HttpClient>> {
|
||||
match config.experimental_environment.as_deref() {
|
||||
None | Some("local") => Ok(Arc::new(ReqwestHttpClient)),
|
||||
Some("remote") => {
|
||||
let environment = runtime_environment.environment();
|
||||
if !environment.is_remote() {
|
||||
anyhow::bail!("remote MCP server requires a remote environment");
|
||||
}
|
||||
Ok(environment.get_http_client())
|
||||
}
|
||||
Some(environment) => anyhow::bail!("unsupported experimental_environment `{environment}`"),
|
||||
}
|
||||
}
|
||||
|
||||
fn oauth_http_client_for_server(
|
||||
config: &McpServerConfig,
|
||||
runtime_environment: McpRuntimeEnvironment,
|
||||
) -> Result<McpOAuthHttpClient> {
|
||||
match config.experimental_environment.as_deref() {
|
||||
None | Some("local") => Ok(McpOAuthHttpClient::LocalNoProxy),
|
||||
Some("remote") => Ok(McpOAuthHttpClient::Runtime(http_client_for_server(
|
||||
config,
|
||||
runtime_environment,
|
||||
)?)),
|
||||
Some(environment) => anyhow::bail!("unsupported experimental_environment `{environment}`"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use codex_config::McpServerConfig;
|
||||
use codex_config::McpServerTransportConfig;
|
||||
use codex_exec_server::Environment;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use crate::runtime::McpRuntimeEnvironment;
|
||||
|
||||
use super::McpOAuthScopesSource;
|
||||
use super::OAuthProviderError;
|
||||
use super::ResolvedMcpOAuthScopes;
|
||||
use super::http_client_for_server;
|
||||
use super::resolve_oauth_scopes;
|
||||
use super::should_retry_without_scopes;
|
||||
|
||||
@@ -318,4 +760,71 @@ mod tests {
|
||||
&anyhow!("timed out waiting for OAuth callback"),
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_server_uses_local_http_client_even_with_remote_runtime() {
|
||||
let config = streamable_config(/*experimental_environment*/ None);
|
||||
let runtime_environment = remote_runtime_environment();
|
||||
|
||||
assert!(http_client_for_server(&config, runtime_environment).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_server_uses_runtime_http_client_when_runtime_is_remote() {
|
||||
let config = streamable_config(Some("remote"));
|
||||
let runtime_environment = remote_runtime_environment();
|
||||
|
||||
assert!(http_client_for_server(&config, runtime_environment).is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_server_without_remote_runtime_returns_clear_error() {
|
||||
let config = streamable_config(Some("remote"));
|
||||
let runtime_environment = McpRuntimeEnvironment::new(
|
||||
Arc::new(Environment::default_for_tests()),
|
||||
PathBuf::from("/tmp"),
|
||||
);
|
||||
|
||||
let Err(error) = http_client_for_server(&config, runtime_environment) else {
|
||||
panic!("remote server should require remote runtime");
|
||||
};
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
"remote MCP server requires a remote environment"
|
||||
);
|
||||
}
|
||||
|
||||
fn streamable_config(experimental_environment: Option<&str>) -> McpServerConfig {
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: "http://mcp.example.test/mcp".to_string(),
|
||||
bearer_token_env_var: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
},
|
||||
experimental_environment: experimental_environment.map(str::to_string),
|
||||
enabled: true,
|
||||
required: false,
|
||||
supports_parallel_tool_calls: false,
|
||||
disabled_reason: None,
|
||||
startup_timeout_sec: Some(Duration::from_secs(30)),
|
||||
tool_timeout_sec: None,
|
||||
default_tools_approval_mode: None,
|
||||
enabled_tools: None,
|
||||
disabled_tools: None,
|
||||
scopes: None,
|
||||
oauth_resource: None,
|
||||
tools: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_runtime_environment() -> McpRuntimeEnvironment {
|
||||
McpRuntimeEnvironment::new(
|
||||
Arc::new(
|
||||
Environment::create_for_tests(Some("ws://127.0.0.1:65535".to_string()))
|
||||
.expect("create remote environment"),
|
||||
),
|
||||
PathBuf::from("/tmp"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
pub use auth::McpAuthStatusEntry;
|
||||
pub use auth::McpOAuthLoginConfig;
|
||||
pub use auth::McpOAuthLoginOutcome;
|
||||
pub use auth::McpOAuthLoginSupport;
|
||||
pub use auth::McpOAuthScopesSource;
|
||||
pub use auth::ResolvedMcpOAuthScopes;
|
||||
pub use auth::compute_auth_statuses;
|
||||
pub use auth::discover_supported_scopes;
|
||||
pub use auth::discover_supported_scopes_for_server;
|
||||
pub use auth::http_client_for_server;
|
||||
pub use auth::oauth_login_support;
|
||||
pub use auth::oauth_login_support_for_server;
|
||||
pub use auth::perform_oauth_login_for_server;
|
||||
pub use auth::perform_oauth_login_return_url_for_server;
|
||||
pub use auth::perform_oauth_login_silent_for_server;
|
||||
pub use auth::resolve_oauth_scopes;
|
||||
pub use auth::should_retry_without_scopes;
|
||||
|
||||
@@ -223,6 +230,7 @@ pub async fn read_mcp_resource(
|
||||
mcp_servers.iter(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth,
|
||||
runtime_environment.clone(),
|
||||
)
|
||||
.await;
|
||||
let (tx_event, rx_event) = unbounded();
|
||||
@@ -286,6 +294,7 @@ pub async fn collect_mcp_server_status_snapshot_with_detail(
|
||||
mcp_servers.iter(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth,
|
||||
runtime_environment.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
||||
@@ -253,20 +253,23 @@ pub async fn list_accessible_connectors_from_mcp_tools_with_environment_manager(
|
||||
});
|
||||
}
|
||||
|
||||
let environment = environment_manager
|
||||
.default_environment()
|
||||
.unwrap_or_else(|| environment_manager.local_environment());
|
||||
let runtime_environment =
|
||||
McpRuntimeEnvironment::new(environment.clone(), config.cwd.to_path_buf());
|
||||
|
||||
let auth_status_entries = compute_auth_statuses(
|
||||
mcp_servers.iter(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth.as_ref(),
|
||||
runtime_environment.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (tx_event, rx_event) = unbounded();
|
||||
drop(rx_event);
|
||||
|
||||
let environment = environment_manager
|
||||
.default_environment()
|
||||
.unwrap_or_else(|| environment_manager.local_environment());
|
||||
|
||||
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
|
||||
&mcp_servers,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
@@ -275,7 +278,7 @@ pub async fn list_accessible_connectors_from_mcp_tools_with_environment_manager(
|
||||
INITIAL_SUBMIT_ID.to_owned(),
|
||||
tx_event,
|
||||
PermissionProfile::default(),
|
||||
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf()),
|
||||
runtime_environment,
|
||||
config.codex_home.to_path_buf(),
|
||||
codex_apps_tools_cache_key(auth.as_ref()),
|
||||
ToolPluginProvenance::default(),
|
||||
|
||||
@@ -11,7 +11,6 @@ use codex_protocol::request_user_input::RequestUserInputArgs;
|
||||
use codex_protocol::request_user_input::RequestUserInputQuestion;
|
||||
use codex_protocol::request_user_input::RequestUserInputQuestionOption;
|
||||
use codex_protocol::request_user_input::RequestUserInputResponse;
|
||||
use codex_rmcp_client::perform_oauth_login;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
|
||||
@@ -19,11 +18,10 @@ use crate::SkillMetadata;
|
||||
use crate::session::session::Session;
|
||||
use crate::session::turn_context::TurnContext;
|
||||
use crate::skills::model::SkillToolDependency;
|
||||
use codex_mcp::McpOAuthLoginSupport;
|
||||
use codex_mcp::McpOAuthLoginOutcome;
|
||||
use codex_mcp::McpRuntimeEnvironment;
|
||||
use codex_mcp::mcp_permission_prompt_is_auto_approved;
|
||||
use codex_mcp::oauth_login_support;
|
||||
use codex_mcp::resolve_oauth_scopes;
|
||||
use codex_mcp::should_retry_without_scopes;
|
||||
use codex_mcp::perform_oauth_login_for_server;
|
||||
|
||||
const SKILL_MCP_DEPENDENCY_PROMPT_ID: &str = "skill_mcp_dependency_install";
|
||||
const MCP_DEPENDENCY_OPTION_INSTALL: &str = "Install";
|
||||
@@ -126,15 +124,13 @@ pub(crate) async fn maybe_install_mcp_dependencies(
|
||||
}
|
||||
|
||||
for (name, server_config) in added {
|
||||
let oauth_config = match oauth_login_support(&server_config.transport).await {
|
||||
McpOAuthLoginSupport::Supported(config) => config,
|
||||
McpOAuthLoginSupport::Unsupported => continue,
|
||||
McpOAuthLoginSupport::Unknown(err) => {
|
||||
warn!("MCP server may or may not require login for dependency {name}: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let runtime_environment = McpRuntimeEnvironment::new(
|
||||
turn_context
|
||||
.environment
|
||||
.clone()
|
||||
.unwrap_or_else(|| sess.services.environment_manager.local_environment()),
|
||||
turn_context.cwd.to_path_buf(),
|
||||
);
|
||||
sess.notify_background_event(
|
||||
turn_context,
|
||||
format!(
|
||||
@@ -143,52 +139,20 @@ pub(crate) async fn maybe_install_mcp_dependencies(
|
||||
)
|
||||
.await;
|
||||
|
||||
let resolved_scopes = resolve_oauth_scopes(
|
||||
/*explicit_scopes*/ None,
|
||||
server_config.scopes.clone(),
|
||||
oauth_config.discovered_scopes.clone(),
|
||||
);
|
||||
let first_attempt = perform_oauth_login(
|
||||
match perform_oauth_login_for_server(
|
||||
&name,
|
||||
&oauth_config.url,
|
||||
&server_config,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
oauth_config.http_headers.clone(),
|
||||
oauth_config.env_http_headers.clone(),
|
||||
&resolved_scopes.scopes,
|
||||
server_config.oauth_resource.as_deref(),
|
||||
/*explicit_scopes*/ None,
|
||||
config.mcp_oauth_callback_port,
|
||||
config.mcp_oauth_callback_url.as_deref(),
|
||||
runtime_environment,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(err) = first_attempt {
|
||||
if should_retry_without_scopes(&resolved_scopes, &err) {
|
||||
sess.notify_background_event(
|
||||
turn_context,
|
||||
format!(
|
||||
"Retrying MCP {name} authentication without scopes after provider rejection."
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(err) = perform_oauth_login(
|
||||
&name,
|
||||
&oauth_config.url,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
oauth_config.http_headers,
|
||||
oauth_config.env_http_headers,
|
||||
&[],
|
||||
server_config.oauth_resource.as_deref(),
|
||||
config.mcp_oauth_callback_port,
|
||||
config.mcp_oauth_callback_url.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!("failed to login to MCP dependency {name}: {err}");
|
||||
}
|
||||
} else {
|
||||
warn!("failed to login to MCP dependency {name}: {err}");
|
||||
}
|
||||
.await
|
||||
{
|
||||
Ok(McpOAuthLoginOutcome::Completed) => {}
|
||||
Ok(McpOAuthLoginOutcome::Unsupported) => {}
|
||||
Err(err) => warn!("failed to login to MCP dependency {name}: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ use crate::tasks::UndoTask;
|
||||
use crate::tasks::UserShellCommandMode;
|
||||
use crate::tasks::UserShellCommandTask;
|
||||
use crate::tasks::execute_user_shell_command;
|
||||
use codex_mcp::McpRuntimeEnvironment;
|
||||
use codex_mcp::collect_mcp_snapshot_from_manager;
|
||||
use codex_mcp::compute_auth_statuses;
|
||||
use codex_protocol::models::ContentItem;
|
||||
@@ -538,12 +539,18 @@ pub async fn list_mcp_tools(sess: &Session, config: &Arc<Config>, sub_id: String
|
||||
.mcp_manager
|
||||
.effective_servers(config, auth.as_ref())
|
||||
.await;
|
||||
let environment = sess
|
||||
.services
|
||||
.environment_manager
|
||||
.default_environment()
|
||||
.unwrap_or_else(|| sess.services.environment_manager.local_environment());
|
||||
let snapshot = collect_mcp_snapshot_from_manager(
|
||||
&mcp_connection_manager,
|
||||
compute_auth_statuses(
|
||||
mcp_servers.iter(),
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
auth.as_ref(),
|
||||
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf()),
|
||||
)
|
||||
.await,
|
||||
)
|
||||
|
||||
@@ -219,8 +219,20 @@ impl Session {
|
||||
.tool_plugin_provenance(config.as_ref())
|
||||
.await;
|
||||
let mcp_servers = with_codex_apps_mcp(mcp_servers, auth.as_ref(), &mcp_config);
|
||||
let auth_statuses =
|
||||
compute_auth_statuses(mcp_servers.iter(), store_mode, auth.as_ref()).await;
|
||||
let runtime_environment = McpRuntimeEnvironment::new(
|
||||
turn_context
|
||||
.environment
|
||||
.clone()
|
||||
.unwrap_or_else(|| self.services.environment_manager.local_environment()),
|
||||
turn_context.cwd.to_path_buf(),
|
||||
);
|
||||
let auth_statuses = compute_auth_statuses(
|
||||
mcp_servers.iter(),
|
||||
store_mode,
|
||||
auth.as_ref(),
|
||||
runtime_environment.clone(),
|
||||
)
|
||||
.await;
|
||||
{
|
||||
let mut guard = self.services.mcp_startup_cancellation_token.lock().await;
|
||||
guard.cancel();
|
||||
@@ -234,13 +246,7 @@ impl Session {
|
||||
turn_context.sub_id.clone(),
|
||||
self.get_tx_event(),
|
||||
turn_context.permission_profile(),
|
||||
McpRuntimeEnvironment::new(
|
||||
turn_context
|
||||
.environment
|
||||
.clone()
|
||||
.unwrap_or_else(|| self.services.environment_manager.local_environment()),
|
||||
turn_context.cwd.to_path_buf(),
|
||||
),
|
||||
runtime_environment,
|
||||
config.codex_home.to_path_buf(),
|
||||
codex_apps_tools_cache_key(auth.as_ref()),
|
||||
tool_plugin_provenance,
|
||||
|
||||
@@ -445,15 +445,20 @@ impl Session {
|
||||
let auth_manager_clone = Arc::clone(&auth_manager);
|
||||
let config_for_mcp = Arc::clone(&config);
|
||||
let mcp_manager_for_mcp = Arc::clone(&mcp_manager);
|
||||
let environment_manager_for_mcp = Arc::clone(&environment_manager);
|
||||
let auth_and_mcp_fut = async move {
|
||||
let auth = auth_manager_clone.auth().await;
|
||||
let mcp_servers = mcp_manager_for_mcp
|
||||
.effective_servers(&config_for_mcp, auth.as_ref())
|
||||
.await;
|
||||
let environment = environment_manager_for_mcp
|
||||
.default_environment()
|
||||
.unwrap_or_else(|| environment_manager_for_mcp.local_environment());
|
||||
let auth_statuses = compute_auth_statuses(
|
||||
mcp_servers.iter(),
|
||||
config_for_mcp.mcp_oauth_credentials_store_mode,
|
||||
auth.as_ref(),
|
||||
McpRuntimeEnvironment::new(environment, config_for_mcp.cwd.to_path_buf()),
|
||||
)
|
||||
.await;
|
||||
(auth, mcp_servers, auth_statuses)
|
||||
|
||||
@@ -14,7 +14,6 @@ axum = { workspace = true, default-features = false, features = [
|
||||
"tokio",
|
||||
] }
|
||||
codex-api = { workspace = true }
|
||||
codex-client = { workspace = true }
|
||||
codex-config = { workspace = true }
|
||||
codex-exec-server = { workspace = true }
|
||||
codex-keyring-store = { workspace = true }
|
||||
|
||||
@@ -1,31 +1,29 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Error;
|
||||
use anyhow::Result;
|
||||
use codex_exec_server::ExecServerError;
|
||||
use codex_exec_server::HttpClient;
|
||||
use codex_exec_server::HttpHeader;
|
||||
use codex_exec_server::HttpRequestParams;
|
||||
use codex_exec_server::HttpRequestResponse;
|
||||
use codex_exec_server::HttpResponseBodyStream;
|
||||
use codex_protocol::protocol::McpAuthStatus;
|
||||
use reqwest::Client;
|
||||
use reqwest::StatusCode;
|
||||
use reqwest::Url;
|
||||
use futures::FutureExt;
|
||||
use futures::future::BoxFuture;
|
||||
use reqwest::Method;
|
||||
use reqwest::header::AUTHORIZATION;
|
||||
use reqwest::header::HeaderMap;
|
||||
use serde::Deserialize;
|
||||
use reqwest::header::HeaderName;
|
||||
use reqwest::header::HeaderValue;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::mcp_oauth_http::OAuthHttpClient;
|
||||
use crate::mcp_oauth_http::StreamableHttpOAuthDiscovery;
|
||||
use crate::oauth::has_oauth_tokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
use codex_config::types::OAuthCredentialsStoreMode;
|
||||
|
||||
const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version";
|
||||
const OAUTH_DISCOVERY_VERSION: &str = "2024-11-05";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct StreamableHttpOAuthDiscovery {
|
||||
pub scopes_supported: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Determine the authentication status for a streamable HTTP MCP server.
|
||||
pub async fn determine_streamable_http_auth_status(
|
||||
server_name: &str,
|
||||
@@ -34,12 +32,36 @@ pub async fn determine_streamable_http_auth_status(
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<McpAuthStatus> {
|
||||
determine_streamable_http_auth_status_with_client(
|
||||
server_name,
|
||||
url,
|
||||
bearer_token_env_var,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
no_proxy_http_client(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Determine the authentication status for a streamable HTTP MCP server using
|
||||
/// the caller-selected runtime HTTP client.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn determine_streamable_http_auth_status_with_client(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
bearer_token_env_var: Option<&str>,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> Result<McpAuthStatus> {
|
||||
if bearer_token_env_var.is_some() {
|
||||
return Ok(McpAuthStatus::BearerToken);
|
||||
}
|
||||
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
let default_headers = build_default_headers(http_headers.clone(), env_http_headers.clone())?;
|
||||
if default_headers.contains_key(AUTHORIZATION) {
|
||||
return Ok(McpAuthStatus::BearerToken);
|
||||
}
|
||||
@@ -48,7 +70,8 @@ pub async fn determine_streamable_http_auth_status(
|
||||
return Ok(McpAuthStatus::OAuth);
|
||||
}
|
||||
|
||||
match discover_streamable_http_oauth_with_headers(url, &default_headers).await {
|
||||
let oauth_http = OAuthHttpClient::from_default_headers(http_client, default_headers);
|
||||
match oauth_http.discover(url).await {
|
||||
Ok(Some(_)) => Ok(McpAuthStatus::NotLoggedIn),
|
||||
Ok(None) => Ok(McpAuthStatus::Unsupported),
|
||||
Err(error) => {
|
||||
@@ -74,121 +97,107 @@ pub async fn discover_streamable_http_oauth(
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
) -> Result<Option<StreamableHttpOAuthDiscovery>> {
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
discover_streamable_http_oauth_with_headers(url, &default_headers).await
|
||||
discover_streamable_http_oauth_with_client(
|
||||
url,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
no_proxy_http_client(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn discover_streamable_http_oauth_with_headers(
|
||||
pub(crate) fn no_proxy_http_client() -> Arc<dyn HttpClient> {
|
||||
Arc::new(NoProxyReqwestHttpClient)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct NoProxyReqwestHttpClient;
|
||||
|
||||
impl HttpClient for NoProxyReqwestHttpClient {
|
||||
fn http_request(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, std::result::Result<HttpRequestResponse, ExecServerError>> {
|
||||
async move { no_proxy_http_request(params).await }.boxed()
|
||||
}
|
||||
|
||||
fn http_request_stream(
|
||||
&self,
|
||||
_params: HttpRequestParams,
|
||||
) -> BoxFuture<
|
||||
'_,
|
||||
std::result::Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>,
|
||||
> {
|
||||
async move {
|
||||
Err(ExecServerError::Protocol(
|
||||
"streaming is not supported for OAuth discovery".to_string(),
|
||||
))
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
async fn no_proxy_http_request(
|
||||
params: HttpRequestParams,
|
||||
) -> std::result::Result<HttpRequestResponse, ExecServerError> {
|
||||
let method = Method::from_bytes(params.method.as_bytes())
|
||||
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
|
||||
let mut builder = reqwest::Client::builder().no_proxy();
|
||||
if let Some(timeout_ms) = params.timeout_ms {
|
||||
builder = builder.timeout(Duration::from_millis(timeout_ms));
|
||||
}
|
||||
let client = builder
|
||||
.build()
|
||||
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
|
||||
|
||||
let mut request = client.request(method, ¶ms.url);
|
||||
for header in params.headers {
|
||||
let name = HeaderName::from_bytes(header.name.as_bytes())
|
||||
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
|
||||
let value = HeaderValue::from_str(&header.value)
|
||||
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
|
||||
request = request.header(name, value);
|
||||
}
|
||||
if let Some(body) = params.body {
|
||||
request = request.body(body.into_inner());
|
||||
}
|
||||
|
||||
let response = request
|
||||
.send()
|
||||
.await
|
||||
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
|
||||
let status = response.status().as_u16();
|
||||
let headers = response
|
||||
.headers()
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
value.to_str().ok().map(|value| HttpHeader {
|
||||
name: name.to_string(),
|
||||
value: value.to_string(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let body = response
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?
|
||||
.to_vec();
|
||||
|
||||
Ok(HttpRequestResponse {
|
||||
status,
|
||||
headers,
|
||||
body: body.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn discover_streamable_http_oauth_with_client(
|
||||
url: &str,
|
||||
default_headers: &HeaderMap,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> Result<Option<StreamableHttpOAuthDiscovery>> {
|
||||
let base_url = Url::parse(url)?;
|
||||
|
||||
// Use no_proxy to avoid a bug in the system-configuration crate that
|
||||
// can result in a panic. See #8912.
|
||||
let builder = Client::builder().timeout(DISCOVERY_TIMEOUT).no_proxy();
|
||||
let client = apply_default_headers(builder, default_headers).build()?;
|
||||
|
||||
let mut last_error: Option<Error> = None;
|
||||
for candidate_path in discovery_paths(base_url.path()) {
|
||||
let mut discovery_url = base_url.clone();
|
||||
discovery_url.set_path(&candidate_path);
|
||||
|
||||
let response = match client
|
||||
.get(discovery_url.clone())
|
||||
.header(OAUTH_DISCOVERY_HEADER, OAUTH_DISCOVERY_VERSION)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
last_error = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if response.status() != StatusCode::OK {
|
||||
continue;
|
||||
}
|
||||
|
||||
let metadata = match response.json::<OAuthDiscoveryMetadata>().await {
|
||||
Ok(metadata) => metadata,
|
||||
Err(err) => {
|
||||
last_error = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
|
||||
return Ok(Some(StreamableHttpOAuthDiscovery {
|
||||
scopes_supported: normalize_scopes(metadata.scopes_supported),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(err) = last_error {
|
||||
debug!("OAuth discovery requests failed for {url}: {err:?}");
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OAuthDiscoveryMetadata {
|
||||
#[serde(default)]
|
||||
authorization_endpoint: Option<String>,
|
||||
#[serde(default)]
|
||||
token_endpoint: Option<String>,
|
||||
#[serde(default)]
|
||||
scopes_supported: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn normalize_scopes(scopes_supported: Option<Vec<String>>) -> Option<Vec<String>> {
|
||||
let scopes_supported = scopes_supported?;
|
||||
|
||||
let mut normalized = Vec::new();
|
||||
for scope in scopes_supported {
|
||||
let scope = scope.trim();
|
||||
if scope.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let scope = scope.to_string();
|
||||
if !normalized.contains(&scope) {
|
||||
normalized.push(scope);
|
||||
}
|
||||
}
|
||||
|
||||
if normalized.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(normalized)
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements RFC 8414 section 3.1 for discovering well-known oauth endpoints.
|
||||
/// This is a requirement for MCP servers to support OAuth.
|
||||
/// https://datatracker.ietf.org/doc/html/rfc8414#section-3.1
|
||||
/// https://github.com/modelcontextprotocol/rust-sdk/blob/main/crates/rmcp/src/transport/auth.rs#L182
|
||||
fn discovery_paths(base_path: &str) -> Vec<String> {
|
||||
let trimmed = base_path.trim_start_matches('/').trim_end_matches('/');
|
||||
let canonical = "/.well-known/oauth-authorization-server".to_string();
|
||||
|
||||
if trimmed.is_empty() {
|
||||
return vec![canonical];
|
||||
}
|
||||
|
||||
let mut candidates = Vec::new();
|
||||
let mut push_unique = |candidate: String| {
|
||||
if !candidates.contains(&candidate) {
|
||||
candidates.push(candidate);
|
||||
}
|
||||
};
|
||||
|
||||
push_unique(format!("{canonical}/{trimmed}"));
|
||||
push_unique(format!("/{trimmed}/.well-known/oauth-authorization-server"));
|
||||
push_unique(canonical);
|
||||
|
||||
candidates
|
||||
let oauth_http = OAuthHttpClient::new(http_client, http_headers, env_http_headers)?;
|
||||
oauth_http.discover(url).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -8,6 +8,7 @@ use std::time::Duration;
|
||||
|
||||
use axum::Router;
|
||||
use axum::body::Body;
|
||||
use axum::body::Bytes;
|
||||
use axum::extract::Json;
|
||||
use axum::extract::State;
|
||||
use axum::http::HeaderMap;
|
||||
@@ -122,6 +123,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
SESSION_POST_FAILURE_CONTROL_PATH,
|
||||
post(arm_session_post_failure),
|
||||
)
|
||||
.route("/oauth/token", post(oauth_token))
|
||||
.route(
|
||||
"/.well-known/oauth-authorization-server/mcp",
|
||||
get({
|
||||
@@ -139,6 +141,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
serde_json::to_vec(&json!({
|
||||
"authorization_endpoint": format!("{metadata_base}/oauth/authorize"),
|
||||
"token_endpoint": format!("{metadata_base}/oauth/token"),
|
||||
"registration_endpoint": format!("{metadata_base}/oauth/register"),
|
||||
"scopes_supported": [""],
|
||||
})).expect("failed to serialize metadata"),
|
||||
))
|
||||
@@ -386,7 +389,8 @@ async fn require_bearer(
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if request.uri().path().contains("/.well-known/") {
|
||||
let path = request.uri().path();
|
||||
if path.contains("/.well-known/") || path.starts_with("/oauth/") {
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
if request
|
||||
@@ -400,6 +404,33 @@ async fn require_bearer(
|
||||
}
|
||||
}
|
||||
|
||||
async fn oauth_token(body: Bytes) -> Result<Response, StatusCode> {
|
||||
let form = String::from_utf8(body.to_vec()).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
if !form.contains("grant_type=refresh_token") {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
let access_token =
|
||||
std::env::var("MCP_OAUTH_ACCESS_TOKEN").unwrap_or_else(|_| "refreshed-oauth-token".into());
|
||||
let refresh_token =
|
||||
std::env::var("MCP_OAUTH_REFRESH_TOKEN").unwrap_or_else(|_| "refresh-token".into());
|
||||
|
||||
#[expect(clippy::expect_used)]
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_vec(&json!({
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": refresh_token,
|
||||
}))
|
||||
.expect("failed to serialize token response"),
|
||||
))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
async fn arm_session_post_failure(
|
||||
State(state): State<SessionFailureState>,
|
||||
Json(request): Json<ArmSessionPostFailureRequest>,
|
||||
|
||||
@@ -35,6 +35,8 @@ use rmcp::transport::streamable_http_client::StreamableHttpPostResponse;
|
||||
use sse_stream::Sse;
|
||||
use sse_stream::SseStream;
|
||||
|
||||
use crate::oauth::OAuthPersistor;
|
||||
|
||||
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
|
||||
const JSON_MIME_TYPE: &str = "application/json";
|
||||
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
|
||||
@@ -45,6 +47,7 @@ pub(crate) struct StreamableHttpClientAdapter {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
default_headers: HeaderMap,
|
||||
auth_provider: Option<SharedAuthProvider>,
|
||||
oauth_persistor: Option<OAuthPersistor>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -55,6 +58,8 @@ pub(crate) enum StreamableHttpClientAdapterError {
|
||||
HttpRequest(#[from] ExecServerError),
|
||||
#[error("invalid HTTP header: {0}")]
|
||||
Header(String),
|
||||
#[error("OAuth token error: {0}")]
|
||||
Auth(String),
|
||||
}
|
||||
|
||||
impl StreamableHttpClientAdapter {
|
||||
@@ -62,11 +67,13 @@ impl StreamableHttpClientAdapter {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
default_headers: HeaderMap,
|
||||
auth_provider: Option<SharedAuthProvider>,
|
||||
oauth_persistor: Option<OAuthPersistor>,
|
||||
) -> Self {
|
||||
Self {
|
||||
http_client,
|
||||
default_headers,
|
||||
auth_provider,
|
||||
oauth_persistor,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -82,7 +89,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
|
||||
auth_token: Option<String>,
|
||||
) -> std::result::Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
|
||||
let mut headers = self.default_headers.clone();
|
||||
self.add_auth_headers(&mut headers);
|
||||
self.add_auth_headers(&mut headers).await?;
|
||||
insert_header(
|
||||
&mut headers,
|
||||
ACCEPT,
|
||||
@@ -179,7 +186,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
|
||||
auth_token: Option<String>,
|
||||
) -> std::result::Result<(), StreamableHttpError<Self::Error>> {
|
||||
let mut headers = self.default_headers.clone();
|
||||
self.add_auth_headers(&mut headers);
|
||||
self.add_auth_headers(&mut headers).await?;
|
||||
if let Some(auth_token) = auth_token {
|
||||
insert_header(
|
||||
&mut headers,
|
||||
@@ -232,7 +239,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
|
||||
StreamableHttpError<Self::Error>,
|
||||
> {
|
||||
let mut headers = self.default_headers.clone();
|
||||
self.add_auth_headers(&mut headers);
|
||||
self.add_auth_headers(&mut headers).await?;
|
||||
insert_header(
|
||||
&mut headers,
|
||||
ACCEPT,
|
||||
@@ -308,10 +315,27 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
|
||||
}
|
||||
|
||||
impl StreamableHttpClientAdapter {
|
||||
fn add_auth_headers(&self, headers: &mut HeaderMap) {
|
||||
async fn add_auth_headers(
|
||||
&self,
|
||||
headers: &mut HeaderMap,
|
||||
) -> std::result::Result<(), StreamableHttpError<StreamableHttpClientAdapterError>> {
|
||||
if let Some(auth_provider) = &self.auth_provider {
|
||||
headers.extend(auth_provider.to_auth_headers());
|
||||
}
|
||||
if !headers.contains_key(AUTHORIZATION)
|
||||
&& let Some(oauth_persistor) = &self.oauth_persistor
|
||||
&& let Some(access_token) = oauth_persistor.access_token().await.map_err(|err| {
|
||||
StreamableHttpError::Client(StreamableHttpClientAdapterError::Auth(err.to_string()))
|
||||
})?
|
||||
{
|
||||
insert_header(
|
||||
headers,
|
||||
AUTHORIZATION,
|
||||
format!("Bearer {access_token}"),
|
||||
StreamableHttpClientAdapterError::Header,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ mod elicitation_client_service;
|
||||
mod executor_process_transport;
|
||||
mod http_client_adapter;
|
||||
mod logging_client_handler;
|
||||
mod mcp_oauth_http;
|
||||
mod oauth;
|
||||
mod perform_oauth_login;
|
||||
mod program_resolver;
|
||||
@@ -10,11 +11,13 @@ mod rmcp_client;
|
||||
mod stdio_server_launcher;
|
||||
mod utils;
|
||||
|
||||
pub use auth_status::StreamableHttpOAuthDiscovery;
|
||||
pub use auth_status::determine_streamable_http_auth_status;
|
||||
pub use auth_status::determine_streamable_http_auth_status_with_client;
|
||||
pub use auth_status::discover_streamable_http_oauth;
|
||||
pub use auth_status::discover_streamable_http_oauth_with_client;
|
||||
pub use auth_status::supports_oauth_login;
|
||||
pub use codex_protocol::protocol::McpAuthStatus;
|
||||
pub use mcp_oauth_http::StreamableHttpOAuthDiscovery;
|
||||
pub use oauth::StoredOAuthTokens;
|
||||
pub use oauth::WrappedOAuthTokenResponse;
|
||||
pub use oauth::delete_oauth_tokens;
|
||||
@@ -24,7 +27,10 @@ pub use perform_oauth_login::OAuthProviderError;
|
||||
pub use perform_oauth_login::OauthLoginHandle;
|
||||
pub use perform_oauth_login::perform_oauth_login;
|
||||
pub use perform_oauth_login::perform_oauth_login_return_url;
|
||||
pub use perform_oauth_login::perform_oauth_login_return_url_with_client;
|
||||
pub use perform_oauth_login::perform_oauth_login_silent;
|
||||
pub use perform_oauth_login::perform_oauth_login_silent_with_client;
|
||||
pub use perform_oauth_login::perform_oauth_login_with_client;
|
||||
pub use rmcp::model::ElicitationAction;
|
||||
pub use rmcp_client::Elicitation;
|
||||
pub use rmcp_client::ElicitationResponse;
|
||||
|
||||
690
codex-rs/rmcp-client/src/mcp_oauth_http.rs
Normal file
690
codex-rs/rmcp-client/src/mcp_oauth_http.rs
Normal file
@@ -0,0 +1,690 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_exec_server::HttpClient;
|
||||
use codex_exec_server::HttpHeader;
|
||||
use codex_exec_server::HttpRequestParams;
|
||||
use oauth2::AuthUrl;
|
||||
use oauth2::AuthorizationCode;
|
||||
use oauth2::ClientId;
|
||||
use oauth2::ClientSecret;
|
||||
use oauth2::CsrfToken;
|
||||
use oauth2::EmptyExtraTokenFields;
|
||||
use oauth2::PkceCodeChallenge;
|
||||
use oauth2::PkceCodeVerifier;
|
||||
use oauth2::RedirectUrl;
|
||||
use oauth2::RefreshToken;
|
||||
use oauth2::RequestTokenError;
|
||||
use oauth2::Scope;
|
||||
use oauth2::StandardErrorResponse;
|
||||
use oauth2::StandardTokenResponse;
|
||||
use oauth2::TokenResponse;
|
||||
use oauth2::TokenUrl;
|
||||
use oauth2::basic::BasicClient;
|
||||
use oauth2::basic::BasicErrorResponseType;
|
||||
use oauth2::basic::BasicTokenType;
|
||||
use reqwest::StatusCode;
|
||||
use reqwest::Url;
|
||||
use reqwest::header::HeaderMap;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::WrappedOAuthTokenResponse;
|
||||
use crate::oauth::StoredOAuthTokens;
|
||||
use crate::oauth::compute_expires_at_millis;
|
||||
use crate::utils::build_default_headers;
|
||||
|
||||
const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const OAUTH_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version";
|
||||
const OAUTH_DISCOVERY_VERSION: &str = "2024-11-05";
|
||||
|
||||
type OAuthTokenResponse = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
|
||||
type OAuthErrorResponse = StandardErrorResponse<BasicErrorResponseType>;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct StreamableHttpOAuthDiscovery {
|
||||
pub scopes_supported: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct OAuthHttpClient {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
default_headers: HeaderMap,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{0}")]
|
||||
pub(crate) struct OAuthHttpError(String);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct OAuthAuthorizationSession {
|
||||
pub authorization_url: String,
|
||||
pub csrf_state: CsrfToken,
|
||||
pkce_verifier: PkceCodeVerifier,
|
||||
client: BasicClient<
|
||||
oauth2::EndpointSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointSet,
|
||||
>,
|
||||
client_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct OAuthClientConfig {
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum DiscoveredOAuthMetadata {
|
||||
AuthorizationServer(OAuthDiscoveryMetadata),
|
||||
ProtectedResource {
|
||||
metadata_url: Url,
|
||||
authorization_servers: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl OAuthHttpClient {
|
||||
pub(crate) fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
) -> Result<Self> {
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
Ok(Self {
|
||||
http_client,
|
||||
default_headers,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn from_default_headers(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
default_headers: HeaderMap,
|
||||
) -> Self {
|
||||
Self {
|
||||
http_client,
|
||||
default_headers,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn discover(&self, url: &str) -> Result<Option<StreamableHttpOAuthDiscovery>> {
|
||||
let metadata = match self.discover_metadata(url).await? {
|
||||
Some(metadata) => metadata,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
Ok(Some(StreamableHttpOAuthDiscovery {
|
||||
scopes_supported: normalize_scopes(metadata.scopes_supported),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn discover_metadata(&self, url: &str) -> Result<Option<OAuthDiscoveryMetadata>> {
|
||||
match self.discover_metadata_from_paths(url).await? {
|
||||
Some(DiscoveredOAuthMetadata::AuthorizationServer(metadata)) => Ok(Some(metadata)),
|
||||
Some(DiscoveredOAuthMetadata::ProtectedResource {
|
||||
metadata_url,
|
||||
authorization_servers,
|
||||
}) => {
|
||||
for authorization_server in authorization_servers {
|
||||
let authorization_server = authorization_server.trim();
|
||||
if authorization_server.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let authorization_server_url = match Url::parse(authorization_server) {
|
||||
Ok(url) => url,
|
||||
Err(_) => match metadata_url.join(authorization_server) {
|
||||
Ok(url) => url,
|
||||
Err(err) => {
|
||||
debug!(
|
||||
"failed to resolve OAuth authorization server URL `{authorization_server}`: {err}"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let discovered = if authorization_server_url.path().contains("/.well-known/") {
|
||||
self.fetch_discovery_metadata(&authorization_server_url)
|
||||
.await?
|
||||
} else {
|
||||
self.discover_metadata_from_paths(authorization_server_url.as_str())
|
||||
.await?
|
||||
};
|
||||
|
||||
if let Some(DiscoveredOAuthMetadata::AuthorizationServer(metadata)) = discovered
|
||||
{
|
||||
return Ok(Some(metadata));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
async fn discover_metadata_from_paths(
|
||||
&self,
|
||||
url: &str,
|
||||
) -> Result<Option<DiscoveredOAuthMetadata>> {
|
||||
let base_url = Url::parse(url)?;
|
||||
let mut last_error: Option<anyhow::Error> = None;
|
||||
|
||||
for candidate_path in discovery_paths(base_url.path()) {
|
||||
let mut discovery_url = base_url.clone();
|
||||
discovery_url.set_query(None);
|
||||
discovery_url.set_fragment(None);
|
||||
discovery_url.set_path(&candidate_path);
|
||||
|
||||
match self.fetch_discovery_metadata(&discovery_url).await {
|
||||
Ok(Some(metadata)) => return Ok(Some(metadata)),
|
||||
Ok(None) => {}
|
||||
Err(err) => {
|
||||
last_error = Some(err);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(err) = last_error {
|
||||
debug!("OAuth discovery requests failed for {url}: {err:?}");
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn fetch_discovery_metadata(
|
||||
&self,
|
||||
discovery_url: &Url,
|
||||
) -> Result<Option<DiscoveredOAuthMetadata>> {
|
||||
let response = self
|
||||
.request(
|
||||
"GET",
|
||||
discovery_url.as_str(),
|
||||
vec![HttpHeader {
|
||||
name: OAUTH_DISCOVERY_HEADER.to_string(),
|
||||
value: OAUTH_DISCOVERY_VERSION.to_string(),
|
||||
}],
|
||||
/*body*/ None,
|
||||
Some(DISCOVERY_TIMEOUT),
|
||||
)
|
||||
.await?;
|
||||
|
||||
if response.status != StatusCode::OK.as_u16() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let metadata: OAuthDiscoveryMetadata = serde_json::from_slice(&response.body.0)?;
|
||||
|
||||
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
|
||||
return Ok(Some(DiscoveredOAuthMetadata::AuthorizationServer(metadata)));
|
||||
}
|
||||
|
||||
let authorization_servers = {
|
||||
let mut authorization_servers = Vec::new();
|
||||
let mut push_unique = |authorization_server: &String| {
|
||||
let authorization_server = authorization_server.trim();
|
||||
if !authorization_server.is_empty()
|
||||
&& !authorization_servers
|
||||
.iter()
|
||||
.any(|existing| existing == authorization_server)
|
||||
{
|
||||
authorization_servers.push(authorization_server.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(authorization_server) = metadata.authorization_server.as_ref() {
|
||||
push_unique(authorization_server);
|
||||
}
|
||||
if let Some(list) = metadata.authorization_servers.as_ref() {
|
||||
for authorization_server in list {
|
||||
push_unique(authorization_server);
|
||||
}
|
||||
}
|
||||
|
||||
authorization_servers
|
||||
};
|
||||
if authorization_servers.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(DiscoveredOAuthMetadata::ProtectedResource {
|
||||
metadata_url: discovery_url.clone(),
|
||||
authorization_servers,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn start_authorization(
|
||||
&self,
|
||||
server_url: &str,
|
||||
scopes: &[String],
|
||||
redirect_uri: &str,
|
||||
client_name: &str,
|
||||
) -> Result<OAuthAuthorizationSession> {
|
||||
let metadata = self
|
||||
.discover_metadata(server_url)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("No authorization support detected"))?;
|
||||
let client_config = self
|
||||
.register_client(&metadata, client_name, redirect_uri)
|
||||
.await?;
|
||||
let client = oauth_client(&metadata, &client_config, redirect_uri)?;
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
let mut auth_request = client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.set_pkce_challenge(pkce_challenge);
|
||||
for scope in scopes {
|
||||
auth_request = auth_request.add_scope(Scope::new(scope.clone()));
|
||||
}
|
||||
let (authorization_url, csrf_state) = auth_request.url();
|
||||
|
||||
Ok(OAuthAuthorizationSession {
|
||||
authorization_url: authorization_url.to_string(),
|
||||
csrf_state,
|
||||
pkce_verifier,
|
||||
client,
|
||||
client_id: client_config.client_id,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn exchange_code(
|
||||
&self,
|
||||
session: OAuthAuthorizationSession,
|
||||
code: &str,
|
||||
csrf_state: &str,
|
||||
) -> Result<(String, OAuthTokenResponse)> {
|
||||
if session.csrf_state.secret() != csrf_state {
|
||||
return Err(anyhow!(
|
||||
"OAuth callback state did not match authorization request"
|
||||
));
|
||||
}
|
||||
|
||||
let http_client = self.clone();
|
||||
let token = session
|
||||
.client
|
||||
.exchange_code(AuthorizationCode::new(code.to_string()))
|
||||
.set_pkce_verifier(session.pkce_verifier)
|
||||
.request_async(&|request| {
|
||||
let http_client = http_client.clone();
|
||||
async move { http_client.oauth_request(request).await }
|
||||
})
|
||||
.await
|
||||
.or_else(parse_token_from_parse_error)?;
|
||||
|
||||
Ok((session.client_id, token))
|
||||
}
|
||||
|
||||
pub(crate) async fn refresh_token(
|
||||
&self,
|
||||
tokens: &StoredOAuthTokens,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
let refresh_token = match tokens.token_response.0.refresh_token() {
|
||||
Some(refresh_token) => refresh_token.secret().to_string(),
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let metadata = self
|
||||
.discover_metadata(&tokens.url)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("No authorization support detected"))?;
|
||||
let client_config = OAuthClientConfig {
|
||||
client_id: tokens.client_id.clone(),
|
||||
client_secret: None,
|
||||
};
|
||||
let client = oauth_client(&metadata, &client_config, &tokens.url)?;
|
||||
let http_client = self.clone();
|
||||
let token = client
|
||||
.exchange_refresh_token(&RefreshToken::new(refresh_token))
|
||||
.request_async(&|request| {
|
||||
let http_client = http_client.clone();
|
||||
async move { http_client.oauth_request(request).await }
|
||||
})
|
||||
.await
|
||||
.or_else(parse_token_from_parse_error)?;
|
||||
|
||||
let expires_at = compute_expires_at_millis(&token);
|
||||
Ok(Some(StoredOAuthTokens {
|
||||
server_name: tokens.server_name.clone(),
|
||||
url: tokens.url.clone(),
|
||||
client_id: tokens.client_id.clone(),
|
||||
token_response: WrappedOAuthTokenResponse(token),
|
||||
expires_at,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn register_client(
|
||||
&self,
|
||||
metadata: &OAuthDiscoveryMetadata,
|
||||
client_name: &str,
|
||||
redirect_uri: &str,
|
||||
) -> Result<OAuthClientConfig> {
|
||||
let registration_url = metadata
|
||||
.registration_endpoint
|
||||
.as_deref()
|
||||
.ok_or_else(|| anyhow!("Dynamic client registration not supported"))?;
|
||||
|
||||
if let Some(response_types_supported) = metadata.response_types_supported.as_ref()
|
||||
&& !response_types_supported.iter().any(|value| value == "code")
|
||||
{
|
||||
return Err(anyhow!(
|
||||
"OAuth server does not support authorization code flow"
|
||||
));
|
||||
}
|
||||
|
||||
let body = serde_json::to_vec(&ClientRegistrationRequest {
|
||||
client_name: client_name.to_string(),
|
||||
redirect_uris: vec![redirect_uri.to_string()],
|
||||
grant_types: vec![
|
||||
"authorization_code".to_string(),
|
||||
"refresh_token".to_string(),
|
||||
],
|
||||
token_endpoint_auth_method: "none".to_string(),
|
||||
response_types: vec!["code".to_string()],
|
||||
})?;
|
||||
|
||||
let response = self
|
||||
.request(
|
||||
"POST",
|
||||
registration_url,
|
||||
vec![HttpHeader {
|
||||
name: reqwest::header::CONTENT_TYPE.to_string(),
|
||||
value: "application/json".to_string(),
|
||||
}],
|
||||
Some(body),
|
||||
Some(OAUTH_HTTP_TIMEOUT),
|
||||
)
|
||||
.await?;
|
||||
|
||||
if !status_is_success(response.status) {
|
||||
return Err(anyhow!(
|
||||
"Dynamic registration failed: HTTP {}: {}",
|
||||
response.status,
|
||||
String::from_utf8_lossy(&response.body.0)
|
||||
));
|
||||
}
|
||||
|
||||
let registration: ClientRegistrationResponse = serde_json::from_slice(&response.body.0)
|
||||
.context("failed to parse registration response")?;
|
||||
Ok(OAuthClientConfig {
|
||||
client_id: registration.client_id,
|
||||
client_secret: registration
|
||||
.client_secret
|
||||
.filter(|secret| !secret.is_empty()),
|
||||
})
|
||||
}
|
||||
|
||||
async fn oauth_request(
|
||||
&self,
|
||||
request: oauth2::HttpRequest,
|
||||
) -> std::result::Result<oauth2::HttpResponse, OAuthHttpError> {
|
||||
let (parts, body) = request.into_parts();
|
||||
let headers = parts
|
||||
.headers
|
||||
.iter()
|
||||
.map(|(name, value)| {
|
||||
let value = value
|
||||
.to_str()
|
||||
.map_err(|err| OAuthHttpError(format!("invalid OAuth header value: {err}")))?;
|
||||
Ok(HttpHeader {
|
||||
name: name.to_string(),
|
||||
value: value.to_string(),
|
||||
})
|
||||
})
|
||||
.collect::<std::result::Result<Vec<_>, OAuthHttpError>>()?;
|
||||
let response = self
|
||||
.request(
|
||||
parts.method.as_str(),
|
||||
parts.uri.to_string().as_str(),
|
||||
headers,
|
||||
Some(body),
|
||||
Some(OAUTH_HTTP_TIMEOUT),
|
||||
)
|
||||
.await
|
||||
.map_err(|err| OAuthHttpError(err.to_string()))?;
|
||||
|
||||
let mut oauth_response = oauth2::HttpResponse::new(response.body.0);
|
||||
*oauth_response.status_mut() = oauth2::http::StatusCode::from_u16(response.status)
|
||||
.map_err(|err| OAuthHttpError(format!("invalid OAuth response status: {err}")))?;
|
||||
for header in response.headers {
|
||||
let name = oauth2::http::HeaderName::from_bytes(header.name.as_bytes())
|
||||
.map_err(|err| OAuthHttpError(format!("invalid OAuth response header: {err}")))?;
|
||||
let value = oauth2::http::HeaderValue::from_str(&header.value).map_err(|err| {
|
||||
OAuthHttpError(format!("invalid OAuth response header value: {err}"))
|
||||
})?;
|
||||
oauth_response.headers_mut().append(name, value);
|
||||
}
|
||||
Ok(oauth_response)
|
||||
}
|
||||
|
||||
async fn request(
|
||||
&self,
|
||||
method: &str,
|
||||
url: &str,
|
||||
extra_headers: Vec<HttpHeader>,
|
||||
body: Option<Vec<u8>>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<codex_exec_server::HttpRequestResponse> {
|
||||
let mut headers = protocol_headers(&self.default_headers)?;
|
||||
headers.extend(extra_headers);
|
||||
self.http_client
|
||||
.http_request(HttpRequestParams {
|
||||
method: method.to_string(),
|
||||
url: url.to_string(),
|
||||
headers,
|
||||
body: body.map(Into::into),
|
||||
timeout_ms: timeout
|
||||
.map(|timeout| timeout.as_millis().clamp(1, u64::MAX as u128) as u64),
|
||||
request_id: "oauth-request".to_string(),
|
||||
stream_response: false,
|
||||
})
|
||||
.await
|
||||
.map_err(|err| anyhow!(err))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OAuthDiscoveryMetadata {
|
||||
#[serde(default)]
|
||||
authorization_endpoint: Option<String>,
|
||||
#[serde(default)]
|
||||
token_endpoint: Option<String>,
|
||||
#[serde(default)]
|
||||
registration_endpoint: Option<String>,
|
||||
#[serde(default)]
|
||||
scopes_supported: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
response_types_supported: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
authorization_server: Option<String>,
|
||||
#[serde(default)]
|
||||
authorization_servers: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ClientRegistrationRequest {
|
||||
client_name: String,
|
||||
redirect_uris: Vec<String>,
|
||||
grant_types: Vec<String>,
|
||||
token_endpoint_auth_method: String,
|
||||
response_types: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ClientRegistrationResponse {
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
}
|
||||
|
||||
fn oauth_client(
|
||||
metadata: &OAuthDiscoveryMetadata,
|
||||
config: &OAuthClientConfig,
|
||||
redirect_uri: &str,
|
||||
) -> Result<
|
||||
BasicClient<
|
||||
oauth2::EndpointSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointNotSet,
|
||||
oauth2::EndpointSet,
|
||||
>,
|
||||
> {
|
||||
let authorization_endpoint = metadata
|
||||
.authorization_endpoint
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("OAuth metadata did not include authorization endpoint"))?;
|
||||
let token_endpoint = metadata
|
||||
.token_endpoint
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("OAuth metadata did not include token endpoint"))?;
|
||||
let mut client = BasicClient::new(ClientId::new(config.client_id.clone()))
|
||||
.set_auth_uri(AuthUrl::new(authorization_endpoint)?)
|
||||
.set_token_uri(TokenUrl::new(token_endpoint)?)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_uri.to_string())?);
|
||||
if let Some(secret) = config.client_secret.clone() {
|
||||
client = client.set_client_secret(ClientSecret::new(secret));
|
||||
}
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn parse_token_from_parse_error(
|
||||
error: RequestTokenError<OAuthHttpError, OAuthErrorResponse>,
|
||||
) -> std::result::Result<OAuthTokenResponse, RequestTokenError<OAuthHttpError, OAuthErrorResponse>>
|
||||
{
|
||||
match error {
|
||||
RequestTokenError::Parse(parse_error, body) => {
|
||||
match serde_json::from_slice::<OAuthTokenResponse>(&body) {
|
||||
Ok(parsed) => Ok(parsed),
|
||||
Err(_) => Err(RequestTokenError::Parse(parse_error, body)),
|
||||
}
|
||||
}
|
||||
error => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
fn protocol_headers(headers: &HeaderMap) -> Result<Vec<HttpHeader>> {
|
||||
headers
|
||||
.iter()
|
||||
.map(|(name, value)| {
|
||||
let value = value
|
||||
.to_str()
|
||||
.with_context(|| format!("invalid HTTP header value for `{name}`"))?;
|
||||
Ok(HttpHeader {
|
||||
name: name.to_string(),
|
||||
value: value.to_string(),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn status_is_success(status: u16) -> bool {
|
||||
(200..300).contains(&status)
|
||||
}
|
||||
|
||||
pub(crate) fn normalize_scopes(scopes_supported: Option<Vec<String>>) -> Option<Vec<String>> {
|
||||
let scopes_supported = scopes_supported?;
|
||||
|
||||
let mut normalized = Vec::new();
|
||||
for scope in scopes_supported {
|
||||
let scope = scope.trim();
|
||||
if scope.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let scope = scope.to_string();
|
||||
if !normalized.contains(&scope) {
|
||||
normalized.push(scope);
|
||||
}
|
||||
}
|
||||
|
||||
if normalized.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(normalized)
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements RFC 8414 section 3.1 and RFC 9728 section 3 for discovering
|
||||
/// well-known OAuth endpoints.
|
||||
pub(crate) fn discovery_paths(base_path: &str) -> Vec<String> {
|
||||
let trimmed = base_path.trim_start_matches('/').trim_end_matches('/');
|
||||
let mut candidates = Vec::new();
|
||||
let mut push_unique = |candidate: String| {
|
||||
if !candidates.contains(&candidate) {
|
||||
candidates.push(candidate);
|
||||
}
|
||||
};
|
||||
|
||||
let mut push_well_known = |well_known: &str| {
|
||||
let canonical = format!("/.well-known/{well_known}");
|
||||
if trimmed.is_empty() {
|
||||
push_unique(canonical);
|
||||
} else {
|
||||
push_unique(format!("{canonical}/{trimmed}"));
|
||||
push_unique(format!("/{trimmed}/.well-known/{well_known}"));
|
||||
push_unique(canonical);
|
||||
}
|
||||
};
|
||||
|
||||
push_well_known("oauth-authorization-server");
|
||||
push_well_known("openid-configuration");
|
||||
push_well_known("oauth-protected-resource");
|
||||
|
||||
candidates
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn discovery_paths_prefer_rfc_8414_resource_path() {
|
||||
assert_eq!(
|
||||
discovery_paths("/mcp"),
|
||||
vec![
|
||||
"/.well-known/oauth-authorization-server/mcp".to_string(),
|
||||
"/mcp/.well-known/oauth-authorization-server".to_string(),
|
||||
"/.well-known/oauth-authorization-server".to_string(),
|
||||
"/.well-known/openid-configuration/mcp".to_string(),
|
||||
"/mcp/.well-known/openid-configuration".to_string(),
|
||||
"/.well-known/openid-configuration".to_string(),
|
||||
"/.well-known/oauth-protected-resource/mcp".to_string(),
|
||||
"/mcp/.well-known/oauth-protected-resource".to_string(),
|
||||
"/.well-known/oauth-protected-resource".to_string(),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discovery_paths_deduplicate_root_path() {
|
||||
assert_eq!(
|
||||
discovery_paths("/"),
|
||||
vec![
|
||||
"/.well-known/oauth-authorization-server".to_string(),
|
||||
"/.well-known/openid-configuration".to_string(),
|
||||
"/.well-known/oauth-protected-resource".to_string(),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_scopes_trims_empties_and_deduplicates() {
|
||||
assert_eq!(
|
||||
normalize_scopes(Some(vec![
|
||||
"read".to_string(),
|
||||
" write ".to_string(),
|
||||
"".to_string(),
|
||||
"read".to_string(),
|
||||
])),
|
||||
Some(vec!["read".to_string(), "write".to_string()])
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -45,9 +45,10 @@ use tracing::warn;
|
||||
|
||||
use codex_keyring_store::DefaultKeyringStore;
|
||||
use codex_keyring_store::KeyringStore;
|
||||
use rmcp::transport::auth::AuthorizationManager;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
use crate::mcp_oauth_http::OAuthHttpClient;
|
||||
use codex_utils_home_dir::find_codex_home;
|
||||
|
||||
const KEYRING_SERVICE: &str = "Codex MCP Credentials";
|
||||
@@ -94,6 +95,43 @@ pub(crate) fn load_oauth_tokens(
|
||||
}
|
||||
}
|
||||
|
||||
enum OAuthTokensStorageState {
|
||||
Found(StoredOAuthTokens),
|
||||
Missing,
|
||||
Unavailable,
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_for_persist(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<OAuthTokensStorageState> {
|
||||
let keyring_store = DefaultKeyringStore;
|
||||
match store_mode {
|
||||
OAuthCredentialsStoreMode::Auto => {
|
||||
load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file(
|
||||
&keyring_store,
|
||||
server_name,
|
||||
url,
|
||||
)
|
||||
}
|
||||
OAuthCredentialsStoreMode::File => {
|
||||
Ok(match load_oauth_tokens_from_file(server_name, url)? {
|
||||
Some(tokens) => OAuthTokensStorageState::Found(tokens),
|
||||
None => OAuthTokensStorageState::Missing,
|
||||
})
|
||||
}
|
||||
OAuthCredentialsStoreMode::Keyring => Ok(
|
||||
match load_oauth_tokens_from_keyring(&keyring_store, server_name, url)
|
||||
.with_context(|| "failed to read OAuth tokens from keyring".to_string())?
|
||||
{
|
||||
Some(tokens) => OAuthTokensStorageState::Found(tokens),
|
||||
None => OAuthTokensStorageState::Missing,
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn has_oauth_tokens(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
@@ -134,6 +172,48 @@ fn load_oauth_tokens_from_keyring_with_fallback_to_file<K: KeyringStore>(
|
||||
}
|
||||
}
|
||||
|
||||
fn stored_tokens_match_without_expires_in(
|
||||
actual: &StoredOAuthTokens,
|
||||
expected: &StoredOAuthTokens,
|
||||
) -> bool {
|
||||
let actual_response = &actual.token_response.0;
|
||||
let expected_response = &expected.token_response.0;
|
||||
|
||||
actual.server_name == expected.server_name
|
||||
&& actual.url == expected.url
|
||||
&& actual.client_id == expected.client_id
|
||||
&& actual.expires_at == expected.expires_at
|
||||
&& actual_response.access_token().secret() == expected_response.access_token().secret()
|
||||
&& actual_response.token_type() == expected_response.token_type()
|
||||
&& actual_response.refresh_token().map(RefreshToken::secret)
|
||||
== expected_response.refresh_token().map(RefreshToken::secret)
|
||||
&& actual_response.scopes() == expected_response.scopes()
|
||||
&& actual_response.extra_fields() == expected_response.extra_fields()
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<OAuthTokensStorageState> {
|
||||
match load_oauth_tokens_from_keyring(keyring_store, server_name, url) {
|
||||
Ok(Some(tokens)) => Ok(OAuthTokensStorageState::Found(tokens)),
|
||||
Ok(None) => Ok(match load_oauth_tokens_from_file(server_name, url)? {
|
||||
Some(tokens) => OAuthTokensStorageState::Found(tokens),
|
||||
None => OAuthTokensStorageState::Missing,
|
||||
}),
|
||||
Err(error) => {
|
||||
warn!("failed to read OAuth tokens from keyring: {error}");
|
||||
match load_oauth_tokens_from_file(server_name, url)
|
||||
.with_context(|| format!("failed to read OAuth tokens from keyring: {error}"))?
|
||||
{
|
||||
Some(tokens) => Ok(OAuthTokensStorageState::Found(tokens)),
|
||||
None => Ok(OAuthTokensStorageState::Unavailable),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_from_keyring<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
@@ -256,116 +336,161 @@ pub(crate) struct OAuthPersistor {
|
||||
|
||||
struct OAuthPersistorInner {
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
oauth_http: OAuthHttpClient,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
last_credentials: Mutex<Option<StoredOAuthTokens>>,
|
||||
credentials_state: Mutex<OAuthCredentialsState>,
|
||||
refresh_gate: Semaphore,
|
||||
}
|
||||
|
||||
struct OAuthCredentialsState {
|
||||
current: Option<StoredOAuthTokens>,
|
||||
last_persisted: Option<StoredOAuthTokens>,
|
||||
}
|
||||
|
||||
impl OAuthPersistor {
|
||||
pub(crate) fn new(
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
oauth_http: OAuthHttpClient,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
initial_credentials: Option<StoredOAuthTokens>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(OAuthPersistorInner {
|
||||
server_name,
|
||||
url,
|
||||
authorization_manager,
|
||||
oauth_http,
|
||||
store_mode,
|
||||
last_credentials: Mutex::new(initial_credentials),
|
||||
credentials_state: Mutex::new(OAuthCredentialsState {
|
||||
current: initial_credentials.clone(),
|
||||
last_persisted: initial_credentials,
|
||||
}),
|
||||
refresh_gate: Semaphore::new(1),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Persists the latest stored credentials if they have changed.
|
||||
/// Deletes the credentials if they are no longer present.
|
||||
#[expect(
|
||||
clippy::await_holding_invalid_type,
|
||||
reason = "AuthorizationManager async access must be serialized through its mutex"
|
||||
)]
|
||||
pub(crate) async fn persist_if_needed(&self) -> Result<()> {
|
||||
let (client_id, maybe_credentials) = {
|
||||
let manager = self.inner.authorization_manager.clone();
|
||||
let guard = manager.lock().await;
|
||||
guard.get_credentials().await
|
||||
}?;
|
||||
|
||||
match maybe_credentials {
|
||||
Some(credentials) => {
|
||||
let mut last_credentials = self.inner.last_credentials.lock().await;
|
||||
let new_token_response = WrappedOAuthTokenResponse(credentials.clone());
|
||||
let same_token = last_credentials
|
||||
.as_ref()
|
||||
.map(|prev| prev.token_response == new_token_response)
|
||||
.unwrap_or(false);
|
||||
let expires_at = if same_token {
|
||||
last_credentials.as_ref().and_then(|prev| prev.expires_at)
|
||||
} else {
|
||||
compute_expires_at_millis(&credentials)
|
||||
let mut credentials_state = self.inner.credentials_state.lock().await;
|
||||
let Some(stored) = credentials_state.current.clone() else {
|
||||
return Ok(());
|
||||
};
|
||||
match load_oauth_tokens_for_persist(
|
||||
&self.inner.server_name,
|
||||
&stored.url,
|
||||
self.inner.store_mode,
|
||||
)? {
|
||||
OAuthTokensStorageState::Missing => {
|
||||
credentials_state.current = None;
|
||||
credentials_state.last_persisted = None;
|
||||
}
|
||||
OAuthTokensStorageState::Unavailable => {
|
||||
let should_save = match credentials_state.last_persisted.as_ref() {
|
||||
Some(last_persisted) => {
|
||||
!stored_tokens_match_without_expires_in(&stored, last_persisted)
|
||||
}
|
||||
None => true,
|
||||
};
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: self.inner.server_name.clone(),
|
||||
url: self.inner.url.clone(),
|
||||
client_id,
|
||||
token_response: new_token_response,
|
||||
expires_at,
|
||||
};
|
||||
if last_credentials.as_ref() != Some(&stored) {
|
||||
if should_save {
|
||||
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
|
||||
*last_credentials = Some(stored);
|
||||
credentials_state.last_persisted = Some(stored);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let mut last_serialized = self.inner.last_credentials.lock().await;
|
||||
if last_serialized.take().is_some()
|
||||
&& let Err(error) = delete_oauth_tokens(
|
||||
&self.inner.server_name,
|
||||
&self.inner.url,
|
||||
self.inner.store_mode,
|
||||
)
|
||||
{
|
||||
warn!(
|
||||
"failed to remove OAuth tokens for server {}: {error}",
|
||||
self.inner.server_name
|
||||
);
|
||||
OAuthTokensStorageState::Found(current_store)
|
||||
if stored_tokens_match_without_expires_in(¤t_store, &stored) =>
|
||||
{
|
||||
credentials_state.last_persisted = Some(current_store);
|
||||
}
|
||||
OAuthTokensStorageState::Found(current_store) => {
|
||||
let store_matches_last_persisted = credentials_state
|
||||
.last_persisted
|
||||
.as_ref()
|
||||
.is_some_and(|last_persisted| {
|
||||
stored_tokens_match_without_expires_in(¤t_store, last_persisted)
|
||||
});
|
||||
if store_matches_last_persisted {
|
||||
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
|
||||
credentials_state.last_persisted = Some(stored);
|
||||
} else {
|
||||
credentials_state.current = Some(current_store.clone());
|
||||
credentials_state.last_persisted = Some(current_store);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[expect(
|
||||
clippy::await_holding_invalid_type,
|
||||
reason = "AuthorizationManager async access must be serialized through its mutex"
|
||||
)]
|
||||
pub(crate) async fn refresh_if_needed(&self) -> Result<()> {
|
||||
let expires_at = {
|
||||
let guard = self.inner.last_credentials.lock().await;
|
||||
guard.as_ref().and_then(|tokens| tokens.expires_at)
|
||||
let _permit = self
|
||||
.inner
|
||||
.refresh_gate
|
||||
.acquire()
|
||||
.await
|
||||
.context("OAuth refresh gate was closed")?;
|
||||
let tokens = {
|
||||
let state = self.inner.credentials_state.lock().await;
|
||||
let Some(tokens) = state.current.as_ref() else {
|
||||
return Ok(());
|
||||
};
|
||||
if !token_needs_refresh(tokens.expires_at) {
|
||||
return Ok(());
|
||||
}
|
||||
tokens.clone()
|
||||
};
|
||||
|
||||
if !token_needs_refresh(expires_at) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
{
|
||||
let manager = self.inner.authorization_manager.clone();
|
||||
let guard = manager.lock().await;
|
||||
guard.refresh_token().await.with_context(|| {
|
||||
if let Some(refreshed) = self
|
||||
.inner
|
||||
.oauth_http
|
||||
.refresh_token(&tokens)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to refresh OAuth tokens for server {}",
|
||||
self.inner.server_name
|
||||
)
|
||||
})?;
|
||||
})?
|
||||
{
|
||||
let mut state = self.inner.credentials_state.lock().await;
|
||||
if state.current.as_ref() == Some(&tokens) {
|
||||
state.current = Some(refreshed);
|
||||
}
|
||||
}
|
||||
|
||||
self.persist_if_needed().await
|
||||
}
|
||||
|
||||
pub(crate) async fn access_token(&self) -> Result<Option<String>> {
|
||||
let cached = {
|
||||
let state = self.inner.credentials_state.lock().await;
|
||||
state.current.as_ref().map(|tokens| {
|
||||
(
|
||||
tokens.token_response.0.access_token().secret().to_string(),
|
||||
tokens.expires_at,
|
||||
)
|
||||
})
|
||||
};
|
||||
if let Err(err) = self.refresh_if_needed().await {
|
||||
if let Some((access_token, expires_at)) = cached
|
||||
&& expires_in_from_timestamp(expires_at.unwrap_or(u64::MAX)).is_some()
|
||||
{
|
||||
warn!(
|
||||
"failed to refresh OAuth tokens for server {}; using cached access token: {err:#}",
|
||||
self.inner.server_name
|
||||
);
|
||||
return Ok(Some(access_token));
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
let state = self.inner.credentials_state.lock().await;
|
||||
let Some(tokens) = state.current.as_ref() else {
|
||||
return Ok(None);
|
||||
};
|
||||
if expires_in_from_timestamp(tokens.expires_at.unwrap_or(u64::MAX)).is_none() {
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(
|
||||
tokens.token_response.0.access_token().secret().to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
const FALLBACK_FILENAME: &str = ".credentials.json";
|
||||
@@ -595,6 +720,12 @@ fn sha_256_prefix(value: &Value) -> Result<String> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::Result;
|
||||
use codex_exec_server::ExecServerError;
|
||||
use codex_exec_server::HttpClient;
|
||||
use codex_exec_server::HttpRequestParams;
|
||||
use codex_exec_server::HttpRequestResponse;
|
||||
use codex_exec_server::HttpResponseBodyStream;
|
||||
use futures::future::BoxFuture;
|
||||
use keyring::Error as KeyringError;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::sync::Mutex;
|
||||
@@ -693,6 +824,24 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_oauth_tokens_for_persist_keeps_keyring_errors_distinct() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockKeyringStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
|
||||
|
||||
let state = super::load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?;
|
||||
|
||||
assert!(matches!(state, super::OAuthTokensStorageState::Unavailable));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_oauth_tokens_prefers_keyring_when_available() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
@@ -846,6 +995,197 @@ mod tests {
|
||||
assert!(tokens.token_response.0.expires_in().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persistor_does_not_recreate_deleted_file_credentials() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let tokens = sample_tokens();
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
assert!(super::fallback_file_path()?.exists());
|
||||
super::delete_oauth_tokens(
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
)?;
|
||||
|
||||
let persistor = OAuthPersistor::new(
|
||||
tokens.server_name.clone(),
|
||||
OAuthHttpClient::new(
|
||||
Arc::new(FailingHttpClient),
|
||||
/*http_headers*/ None,
|
||||
/*env_http_headers*/ None,
|
||||
)?,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
Some(tokens.clone()),
|
||||
);
|
||||
|
||||
persistor.persist_if_needed().await?;
|
||||
|
||||
assert!(
|
||||
super::load_oauth_tokens(
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
OAuthCredentialsStoreMode::File
|
||||
)?
|
||||
.is_none()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persistor_does_not_clobber_newer_file_credentials() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let tokens = sample_tokens();
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
let persistor = OAuthPersistor::new(
|
||||
tokens.server_name.clone(),
|
||||
OAuthHttpClient::new(
|
||||
Arc::new(FailingHttpClient),
|
||||
/*http_headers*/ None,
|
||||
/*env_http_headers*/ None,
|
||||
)?,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
Some(tokens.clone()),
|
||||
);
|
||||
|
||||
let newer_tokens = sample_tokens_with_access_token("newer-access-token");
|
||||
super::save_oauth_tokens_to_file(&newer_tokens)?;
|
||||
|
||||
persistor.persist_if_needed().await?;
|
||||
|
||||
let loaded = super::load_oauth_tokens(
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
)?
|
||||
.expect("newer tokens should remain stored");
|
||||
assert_tokens_match_without_expiry(&loaded, &newer_tokens);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persistor_saves_newer_memory_credentials_despite_expires_in_drift() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let mut stored_tokens = sample_tokens();
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| Duration::from_secs(0));
|
||||
stored_tokens.expires_at = Some(now.as_millis() as u64 + 10_000);
|
||||
let expires_in = Duration::from_secs(3600);
|
||||
stored_tokens
|
||||
.token_response
|
||||
.0
|
||||
.set_expires_in(Some(&expires_in));
|
||||
super::save_oauth_tokens_to_file(&stored_tokens)?;
|
||||
|
||||
let persistor = OAuthPersistor::new(
|
||||
stored_tokens.server_name.clone(),
|
||||
OAuthHttpClient::new(
|
||||
Arc::new(FailingHttpClient),
|
||||
/*http_headers*/ None,
|
||||
/*env_http_headers*/ None,
|
||||
)?,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
Some(stored_tokens.clone()),
|
||||
);
|
||||
|
||||
let newer_tokens = sample_tokens_with_access_token("newer-access-token");
|
||||
{
|
||||
let mut state = persistor.inner.credentials_state.lock().await;
|
||||
state.current = Some(newer_tokens.clone());
|
||||
state.last_persisted = Some(stored_tokens);
|
||||
}
|
||||
|
||||
persistor.persist_if_needed().await?;
|
||||
|
||||
let loaded = super::load_oauth_tokens(
|
||||
&newer_tokens.server_name,
|
||||
&newer_tokens.url,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
)?
|
||||
.expect("newer tokens should be stored");
|
||||
assert_tokens_match_without_expiry(&loaded, &newer_tokens);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn access_token_uses_cached_token_when_refresh_fails_before_expiry() -> Result<()> {
|
||||
let mut tokens = sample_tokens();
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| Duration::from_secs(0));
|
||||
tokens.expires_at = Some(now.as_millis() as u64 + 10_000);
|
||||
|
||||
let persistor = OAuthPersistor::new(
|
||||
tokens.server_name.clone(),
|
||||
OAuthHttpClient::new(
|
||||
Arc::new(FailingHttpClient),
|
||||
/*http_headers*/ None,
|
||||
/*env_http_headers*/ None,
|
||||
)?,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
Some(tokens),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
persistor.access_token().await?,
|
||||
Some("access-token".to_string())
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn access_token_rejects_expired_token_without_refresh_token() -> Result<()> {
|
||||
let mut tokens = sample_tokens();
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| Duration::from_secs(0));
|
||||
tokens.expires_at = Some((now.as_millis() as u64).saturating_sub(1000));
|
||||
tokens.token_response.0.set_refresh_token(None);
|
||||
|
||||
let persistor = OAuthPersistor::new(
|
||||
tokens.server_name.clone(),
|
||||
OAuthHttpClient::new(
|
||||
Arc::new(FailingHttpClient),
|
||||
/*http_headers*/ None,
|
||||
/*env_http_headers*/ None,
|
||||
)?,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
Some(tokens),
|
||||
);
|
||||
|
||||
assert_eq!(persistor.access_token().await?, None);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct FailingHttpClient;
|
||||
|
||||
impl HttpClient for FailingHttpClient {
|
||||
fn http_request(
|
||||
&self,
|
||||
_params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, std::result::Result<HttpRequestResponse, ExecServerError>> {
|
||||
Box::pin(async {
|
||||
Err(ExecServerError::Disconnected(
|
||||
"forced refresh failure".to_string(),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn http_request_stream(
|
||||
&self,
|
||||
_params: HttpRequestParams,
|
||||
) -> BoxFuture<
|
||||
'_,
|
||||
std::result::Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>,
|
||||
> {
|
||||
Box::pin(async {
|
||||
Err(ExecServerError::Disconnected(
|
||||
"forced refresh failure".to_string(),
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_tokens_match_without_expiry(
|
||||
actual: &StoredOAuthTokens,
|
||||
expected: &StoredOAuthTokens,
|
||||
@@ -888,8 +1228,12 @@ mod tests {
|
||||
}
|
||||
|
||||
fn sample_tokens() -> StoredOAuthTokens {
|
||||
sample_tokens_with_access_token("access-token")
|
||||
}
|
||||
|
||||
fn sample_tokens_with_access_token(access_token: &str) -> StoredOAuthTokens {
|
||||
let mut response = OAuthTokenResponse::new(
|
||||
AccessToken::new("access-token".to_string()),
|
||||
AccessToken::new(access_token.to_string()),
|
||||
BasicTokenType::Bearer,
|
||||
EmptyExtraTokenFields {},
|
||||
);
|
||||
|
||||
@@ -7,9 +7,8 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use anyhow::bail;
|
||||
use reqwest::ClientBuilder;
|
||||
use codex_exec_server::HttpClient;
|
||||
use reqwest::Url;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use tiny_http::Response;
|
||||
use tiny_http::Server;
|
||||
use tokio::sync::oneshot;
|
||||
@@ -18,10 +17,11 @@ use urlencoding::decode;
|
||||
|
||||
use crate::StoredOAuthTokens;
|
||||
use crate::WrappedOAuthTokenResponse;
|
||||
use crate::auth_status::no_proxy_http_client;
|
||||
use crate::mcp_oauth_http::OAuthAuthorizationSession;
|
||||
use crate::mcp_oauth_http::OAuthHttpClient;
|
||||
use crate::oauth::compute_expires_at_millis;
|
||||
use crate::save_oauth_tokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
use codex_config::types::OAuthCredentialsStoreMode;
|
||||
|
||||
struct OauthHeaders {
|
||||
@@ -80,6 +80,34 @@ pub async fn perform_oauth_login(
|
||||
oauth_resource: Option<&str>,
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
) -> Result<()> {
|
||||
perform_oauth_login_with_client(
|
||||
server_name,
|
||||
server_url,
|
||||
store_mode,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
scopes,
|
||||
oauth_resource,
|
||||
callback_port,
|
||||
callback_url,
|
||||
no_proxy_http_client(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn perform_oauth_login_with_client(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
scopes: &[String],
|
||||
oauth_resource: Option<&str>,
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> Result<()> {
|
||||
perform_oauth_login_with_browser_output(
|
||||
server_name,
|
||||
@@ -92,6 +120,7 @@ pub async fn perform_oauth_login(
|
||||
callback_port,
|
||||
callback_url,
|
||||
/*emit_browser_url*/ true,
|
||||
http_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -107,6 +136,34 @@ pub async fn perform_oauth_login_silent(
|
||||
oauth_resource: Option<&str>,
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
) -> Result<()> {
|
||||
perform_oauth_login_silent_with_client(
|
||||
server_name,
|
||||
server_url,
|
||||
store_mode,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
scopes,
|
||||
oauth_resource,
|
||||
callback_port,
|
||||
callback_url,
|
||||
no_proxy_http_client(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn perform_oauth_login_silent_with_client(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
scopes: &[String],
|
||||
oauth_resource: Option<&str>,
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> Result<()> {
|
||||
perform_oauth_login_with_browser_output(
|
||||
server_name,
|
||||
@@ -119,6 +176,7 @@ pub async fn perform_oauth_login_silent(
|
||||
callback_port,
|
||||
callback_url,
|
||||
/*emit_browser_url*/ false,
|
||||
http_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -135,6 +193,7 @@ async fn perform_oauth_login_with_browser_output(
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
emit_browser_url: bool,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> Result<()> {
|
||||
let headers = OauthHeaders {
|
||||
http_headers,
|
||||
@@ -151,6 +210,7 @@ async fn perform_oauth_login_with_browser_output(
|
||||
callback_port,
|
||||
callback_url,
|
||||
/*timeout_secs*/ None,
|
||||
http_client,
|
||||
)
|
||||
.await?
|
||||
.finish(emit_browser_url)
|
||||
@@ -169,6 +229,36 @@ pub async fn perform_oauth_login_return_url(
|
||||
timeout_secs: Option<i64>,
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
) -> Result<OauthLoginHandle> {
|
||||
perform_oauth_login_return_url_with_client(
|
||||
server_name,
|
||||
server_url,
|
||||
store_mode,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
scopes,
|
||||
oauth_resource,
|
||||
timeout_secs,
|
||||
callback_port,
|
||||
callback_url,
|
||||
no_proxy_http_client(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn perform_oauth_login_return_url_with_client(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
scopes: &[String],
|
||||
oauth_resource: Option<&str>,
|
||||
timeout_secs: Option<i64>,
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> Result<OauthLoginHandle> {
|
||||
let headers = OauthHeaders {
|
||||
http_headers,
|
||||
@@ -185,6 +275,7 @@ pub async fn perform_oauth_login_return_url(
|
||||
callback_port,
|
||||
callback_url,
|
||||
timeout_secs,
|
||||
http_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -329,7 +420,8 @@ impl OauthLoginHandle {
|
||||
|
||||
struct OauthLoginFlow {
|
||||
auth_url: String,
|
||||
oauth_state: OAuthState,
|
||||
oauth_http: OAuthHttpClient,
|
||||
oauth_session: Option<OAuthAuthorizationSession>,
|
||||
rx: oneshot::Receiver<CallbackResult>,
|
||||
guard: CallbackServerGuard,
|
||||
server_name: String,
|
||||
@@ -412,6 +504,7 @@ impl OauthLoginFlow {
|
||||
callback_port: Option<u16>,
|
||||
callback_url: Option<&str>,
|
||||
timeout_secs: Option<i64>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> Result<Self> {
|
||||
const DEFAULT_OAUTH_TIMEOUT_SECS: i64 = 300;
|
||||
|
||||
@@ -437,25 +530,19 @@ impl OauthLoginFlow {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} = headers;
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
|
||||
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
|
||||
oauth_state
|
||||
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
|
||||
let oauth_http = OAuthHttpClient::new(http_client, http_headers, env_http_headers)?;
|
||||
let oauth_session = oauth_http
|
||||
.start_authorization(server_url, scopes, &redirect_uri, "Codex")
|
||||
.await?;
|
||||
let auth_url = append_query_param(
|
||||
&oauth_state.get_authorization_url().await?,
|
||||
"resource",
|
||||
oauth_resource,
|
||||
);
|
||||
let auth_url =
|
||||
append_query_param(&oauth_session.authorization_url, "resource", oauth_resource);
|
||||
let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1);
|
||||
let timeout = Duration::from_secs(timeout_secs as u64);
|
||||
|
||||
Ok(Self {
|
||||
auth_url,
|
||||
oauth_state,
|
||||
oauth_http,
|
||||
oauth_session: Some(oauth_session),
|
||||
rx,
|
||||
guard,
|
||||
server_name: server_name.to_string(),
|
||||
@@ -503,19 +590,16 @@ impl OauthLoginFlow {
|
||||
CallbackResult::Error(error) => return Err(anyhow!(error)),
|
||||
};
|
||||
|
||||
self.oauth_state
|
||||
.handle_callback(&code, &csrf_state)
|
||||
let oauth_session = self
|
||||
.oauth_session
|
||||
.take()
|
||||
.ok_or_else(|| anyhow!("OAuth login flow was already completed"))?;
|
||||
let (client_id, credentials) = self
|
||||
.oauth_http
|
||||
.exchange_code(oauth_session, &code, &csrf_state)
|
||||
.await
|
||||
.context("failed to handle OAuth callback")?;
|
||||
|
||||
let (client_id, credentials_opt) = self
|
||||
.oauth_state
|
||||
.get_credentials()
|
||||
.await
|
||||
.context("failed to retrieve OAuth credentials")?;
|
||||
let credentials = credentials_opt
|
||||
.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
|
||||
|
||||
let expires_at = compute_expires_at_millis(&credentials);
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: self.server_name.clone(),
|
||||
|
||||
@@ -12,7 +12,6 @@ use std::time::Instant;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_api::SharedAuthProvider;
|
||||
use codex_client::build_reqwest_client_with_custom_ca;
|
||||
use codex_config::types::McpServerEnvVar;
|
||||
use codex_exec_server::HttpClient;
|
||||
use futures::FutureExt;
|
||||
@@ -45,9 +44,6 @@ use rmcp::service::RoleClient;
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::service::{self};
|
||||
use rmcp::transport::StreamableHttpClientTransport;
|
||||
use rmcp::transport::auth::AuthClient;
|
||||
use rmcp::transport::auth::AuthError;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpError;
|
||||
use serde::Deserialize;
|
||||
@@ -63,12 +59,12 @@ use crate::elicitation_client_service::ElicitationClientService;
|
||||
use crate::http_client_adapter::StreamableHttpClientAdapter;
|
||||
use crate::http_client_adapter::StreamableHttpClientAdapterError;
|
||||
use crate::load_oauth_tokens;
|
||||
use crate::mcp_oauth_http::OAuthHttpClient;
|
||||
use crate::oauth::OAuthPersistor;
|
||||
use crate::oauth::StoredOAuthTokens;
|
||||
use crate::stdio_server_launcher::StdioServerCommand;
|
||||
use crate::stdio_server_launcher::StdioServerLauncher;
|
||||
use crate::stdio_server_launcher::StdioServerTransport;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
use codex_config::types::OAuthCredentialsStoreMode;
|
||||
|
||||
@@ -80,7 +76,7 @@ enum PendingTransport {
|
||||
transport: StreamableHttpClientTransport<StreamableHttpClientAdapter>,
|
||||
},
|
||||
StreamableHttpWithOAuth {
|
||||
transport: StreamableHttpClientTransport<AuthClient<StreamableHttpClientAdapter>>,
|
||||
transport: StreamableHttpClientTransport<StreamableHttpClientAdapter>,
|
||||
oauth_persistor: OAuthPersistor,
|
||||
},
|
||||
}
|
||||
@@ -708,11 +704,7 @@ impl RmcpClient {
|
||||
oauth_persistor,
|
||||
})
|
||||
}
|
||||
Err(err)
|
||||
if err.downcast_ref::<AuthError>().is_some_and(|auth_err| {
|
||||
matches!(auth_err, AuthError::NoAuthorizationSupport)
|
||||
}) =>
|
||||
{
|
||||
Err(err) if err.to_string().contains("No authorization support") => {
|
||||
let access_token = initial_tokens
|
||||
.token_response
|
||||
.0
|
||||
@@ -730,6 +722,7 @@ impl RmcpClient {
|
||||
Arc::clone(http_client),
|
||||
default_headers,
|
||||
/*auth_provider*/ None,
|
||||
/*oauth_persistor*/ None,
|
||||
),
|
||||
http_config,
|
||||
);
|
||||
@@ -749,6 +742,7 @@ impl RmcpClient {
|
||||
Arc::clone(http_client),
|
||||
default_headers,
|
||||
auth_provider.clone(),
|
||||
/*oauth_persistor*/ None,
|
||||
),
|
||||
http_config,
|
||||
);
|
||||
@@ -943,52 +937,32 @@ async fn create_oauth_transport_and_runtime(
|
||||
default_headers: HeaderMap,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> Result<(
|
||||
StreamableHttpClientTransport<AuthClient<StreamableHttpClientAdapter>>,
|
||||
StreamableHttpClientTransport<StreamableHttpClientAdapter>,
|
||||
OAuthPersistor,
|
||||
)> {
|
||||
let builder = apply_default_headers(reqwest::Client::builder(), &default_headers);
|
||||
let oauth_metadata_client = build_reqwest_client_with_custom_ca(builder)?;
|
||||
// TODO(aibrahim): teach OAuth bootstrap and refresh to use the same
|
||||
// shared HTTP client abstraction instead of always creating the local
|
||||
// reqwest metadata client here.
|
||||
let mut oauth_state =
|
||||
OAuthState::new(url.to_string(), Some(oauth_metadata_client.clone())).await?;
|
||||
let oauth_http =
|
||||
OAuthHttpClient::from_default_headers(Arc::clone(&http_client), default_headers.clone());
|
||||
if oauth_http.discover(url).await?.is_none() {
|
||||
anyhow::bail!("No authorization support detected");
|
||||
}
|
||||
|
||||
oauth_state
|
||||
.set_credentials(
|
||||
&initial_tokens.client_id,
|
||||
initial_tokens.token_response.0.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let manager = match oauth_state {
|
||||
OAuthState::Authorized(manager) => manager,
|
||||
OAuthState::Unauthorized(manager) => manager,
|
||||
OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => {
|
||||
return Err(anyhow!("unexpected OAuth state during client setup"));
|
||||
}
|
||||
};
|
||||
|
||||
let auth_client = AuthClient::new(
|
||||
StreamableHttpClientAdapter::new(http_client, default_headers, /*auth_provider*/ None),
|
||||
manager,
|
||||
);
|
||||
let auth_manager = auth_client.auth_manager.clone();
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
auth_client,
|
||||
StreamableHttpClientTransportConfig::with_uri(url.to_string()),
|
||||
);
|
||||
|
||||
let runtime = OAuthPersistor::new(
|
||||
let oauth_persistor = OAuthPersistor::new(
|
||||
server_name.to_string(),
|
||||
url.to_string(),
|
||||
auth_manager,
|
||||
oauth_http,
|
||||
credentials_store,
|
||||
Some(initial_tokens),
|
||||
);
|
||||
|
||||
Ok((transport, runtime))
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
StreamableHttpClientAdapter::new(
|
||||
http_client,
|
||||
default_headers,
|
||||
/*auth_provider*/ None,
|
||||
Some(oauth_persistor.clone()),
|
||||
),
|
||||
StreamableHttpClientTransportConfig::with_uri(url.to_string()),
|
||||
);
|
||||
Ok((transport, oauth_persistor))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_config::types::McpServerEnvVar;
|
||||
use reqwest::ClientBuilder;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderName;
|
||||
use reqwest::header::HeaderValue;
|
||||
@@ -115,17 +114,6 @@ pub(crate) fn build_default_headers(
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
pub(crate) fn apply_default_headers(
|
||||
builder: ClientBuilder,
|
||||
default_headers: &HeaderMap,
|
||||
) -> ClientBuilder {
|
||||
if default_headers.is_empty() {
|
||||
builder
|
||||
} else {
|
||||
builder.default_headers(default_headers.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
|
||||
"HOME",
|
||||
|
||||
309
codex-rs/rmcp-client/tests/streamable_http_remote_oauth.rs
Normal file
309
codex-rs/rmcp-client/tests/streamable_http_remote_oauth.rs
Normal file
@@ -0,0 +1,309 @@
|
||||
//! Regression coverage for remote-environment Streamable HTTP OAuth.
|
||||
//!
|
||||
//! The OAuth issuer in this test uses an unresolvable hostname. If any
|
||||
//! discovery, registration, or token-exchange request bypasses the injected
|
||||
//! `HttpClient`, the test fails before the callback can complete.
|
||||
|
||||
mod streamable_http_test_support;
|
||||
|
||||
use std::ffi::OsString;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::MutexGuard;
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::PoisonError;
|
||||
|
||||
use codex_config::types::OAuthCredentialsStoreMode;
|
||||
use codex_exec_server::ExecServerError;
|
||||
use codex_exec_server::HttpClient;
|
||||
use codex_exec_server::HttpHeader;
|
||||
use codex_exec_server::HttpRequestParams;
|
||||
use codex_exec_server::HttpRequestResponse;
|
||||
use codex_exec_server::HttpResponseBodyStream;
|
||||
use codex_rmcp_client::StoredOAuthTokens;
|
||||
use codex_rmcp_client::WrappedOAuthTokenResponse;
|
||||
use codex_rmcp_client::perform_oauth_login_return_url_with_client;
|
||||
use codex_rmcp_client::save_oauth_tokens;
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::BoxFuture;
|
||||
use oauth2::AccessToken;
|
||||
use oauth2::EmptyExtraTokenFields;
|
||||
use oauth2::RefreshToken;
|
||||
use oauth2::basic::BasicTokenType;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rmcp::transport::auth::OAuthTokenResponse;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use serial_test::serial;
|
||||
use tempfile::TempDir;
|
||||
|
||||
use streamable_http_test_support::call_echo_tool;
|
||||
use streamable_http_test_support::create_remote_oauth_client;
|
||||
use streamable_http_test_support::expected_echo_result;
|
||||
use streamable_http_test_support::spawn_exec_server;
|
||||
use streamable_http_test_support::spawn_streamable_http_server_with_oauth_bearer;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct RecordingHttpClient {
|
||||
requests: Arc<Mutex<Vec<HttpRequestParams>>>,
|
||||
}
|
||||
|
||||
impl RecordingHttpClient {
|
||||
fn recorded_requests(&self) -> Vec<HttpRequestParams> {
|
||||
self.requests
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient for RecordingHttpClient {
|
||||
fn http_request(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
|
||||
let requests = Arc::clone(&self.requests);
|
||||
async move {
|
||||
requests
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.push(params.clone());
|
||||
|
||||
let response = match (params.method.as_str(), params.url.as_str()) {
|
||||
("GET", "http://oauth.test/.well-known/oauth-authorization-server/mcp") => {
|
||||
json_response(json!({
|
||||
"authorization_endpoint": "http://oauth.test/authorize",
|
||||
"token_endpoint": "http://oauth.test/token",
|
||||
"registration_endpoint": "http://oauth.test/register",
|
||||
"scopes_supported": ["tools.read", "tools.write"],
|
||||
"response_types_supported": ["code"],
|
||||
}))?
|
||||
}
|
||||
("POST", "http://oauth.test/register") => json_response(json!({
|
||||
"client_id": "registered-client",
|
||||
}))?,
|
||||
("POST", "http://oauth.test/token") => json_response(json!({
|
||||
"access_token": "remote-access-token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "remote-refresh-token",
|
||||
}))?,
|
||||
_ => {
|
||||
return Err(ExecServerError::HttpRequest(format!(
|
||||
"unexpected HTTP request: {} {}",
|
||||
params.method, params.url
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn http_request_stream(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
|
||||
async move {
|
||||
Err(ExecServerError::HttpRequest(format!(
|
||||
"unexpected streaming HTTP request: {} {}",
|
||||
params.method, params.url
|
||||
)))
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
fn json_response(body: Value) -> Result<HttpRequestResponse, ExecServerError> {
|
||||
let body = serde_json::to_vec(&body)
|
||||
.map_err(|err| ExecServerError::HttpRequest(format!("serialize JSON response: {err}")))?;
|
||||
Ok(HttpRequestResponse {
|
||||
status: 200,
|
||||
headers: vec![HttpHeader {
|
||||
name: "content-type".to_string(),
|
||||
value: "application/json".to_string(),
|
||||
}],
|
||||
body: body.into(),
|
||||
})
|
||||
}
|
||||
|
||||
struct TempCodexHome {
|
||||
_guard: MutexGuard<'static, ()>,
|
||||
previous: Option<OsString>,
|
||||
_dir: TempDir,
|
||||
}
|
||||
|
||||
impl TempCodexHome {
|
||||
fn new() -> anyhow::Result<Self> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
let guard = LOCK
|
||||
.get_or_init(Mutex::default)
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner);
|
||||
let previous = std::env::var_os("CODEX_HOME");
|
||||
let dir = TempDir::new()?;
|
||||
unsafe {
|
||||
std::env::set_var("CODEX_HOME", dir.path());
|
||||
}
|
||||
Ok(Self {
|
||||
_guard: guard,
|
||||
previous,
|
||||
_dir: dir,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TempCodexHome {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
match self.previous.as_ref() {
|
||||
Some(value) => std::env::set_var("CODEX_HOME", value),
|
||||
None => std::env::remove_var("CODEX_HOME"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[serial]
|
||||
async fn browser_callback_flow_uses_injected_http_client_for_oauth_requests() -> anyhow::Result<()>
|
||||
{
|
||||
let _codex_home = TempCodexHome::new()?;
|
||||
let http_client = RecordingHttpClient::default();
|
||||
|
||||
let handle = perform_oauth_login_return_url_with_client(
|
||||
"remote-oauth-test",
|
||||
"http://oauth.test/mcp",
|
||||
OAuthCredentialsStoreMode::File,
|
||||
/*http_headers*/ None,
|
||||
/*env_http_headers*/ None,
|
||||
&["tools.read".to_string()],
|
||||
/*oauth_resource*/ None,
|
||||
Some(10),
|
||||
/*callback_port*/ None,
|
||||
/*callback_url*/ None,
|
||||
Arc::new(http_client.clone()),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let authorization_url = reqwest::Url::parse(handle.authorization_url())?;
|
||||
let mut state = None;
|
||||
let mut redirect_uri = None;
|
||||
for (name, value) in authorization_url.query_pairs() {
|
||||
match name.as_ref() {
|
||||
"state" => state = Some(value.into_owned()),
|
||||
"redirect_uri" => redirect_uri = Some(value.into_owned()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let state = state.expect("authorization URL includes state");
|
||||
let redirect_uri = redirect_uri.expect("authorization URL includes redirect_uri");
|
||||
let callback_url = format!("{redirect_uri}?code=provider-code&state={state}");
|
||||
let callback_response = reqwest::get(callback_url).await?;
|
||||
assert_eq!(callback_response.status(), reqwest::StatusCode::OK);
|
||||
|
||||
handle.wait().await?;
|
||||
|
||||
let requests = http_client.recorded_requests();
|
||||
assert_eq!(
|
||||
requests
|
||||
.iter()
|
||||
.map(|request| (request.method.as_str(), request.url.as_str()))
|
||||
.collect::<Vec<_>>(),
|
||||
vec![
|
||||
(
|
||||
"GET",
|
||||
"http://oauth.test/.well-known/oauth-authorization-server/mcp"
|
||||
),
|
||||
("POST", "http://oauth.test/register"),
|
||||
("POST", "http://oauth.test/token"),
|
||||
]
|
||||
);
|
||||
|
||||
let registration_body: Value = serde_json::from_slice(
|
||||
&requests[1]
|
||||
.body
|
||||
.as_ref()
|
||||
.expect("registration request has a body")
|
||||
.0,
|
||||
)?;
|
||||
assert_eq!(
|
||||
registration_body,
|
||||
json!({
|
||||
"client_name": "Codex",
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
"response_types": ["code"],
|
||||
})
|
||||
);
|
||||
|
||||
let token_body = String::from_utf8(
|
||||
requests[2]
|
||||
.body
|
||||
.as_ref()
|
||||
.expect("token request has a body")
|
||||
.0
|
||||
.clone(),
|
||||
)?;
|
||||
assert!(token_body.contains("grant_type=authorization_code"));
|
||||
assert!(token_body.contains("code=provider-code"));
|
||||
assert!(token_body.contains("client_id=registered-client"));
|
||||
assert!(token_body.contains("code_verifier="));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[serial]
|
||||
async fn stored_oauth_refreshes_and_authenticates_remote_streamable_http() -> anyhow::Result<()> {
|
||||
let _codex_home = TempCodexHome::new()?;
|
||||
let refresh_token = "remote-refresh-token";
|
||||
let refreshed_access_token = "refreshed-remote-access-token";
|
||||
let (_server, base_url) =
|
||||
spawn_streamable_http_server_with_oauth_bearer(refreshed_access_token, refresh_token)
|
||||
.await?;
|
||||
let exec_server = spawn_exec_server().await?;
|
||||
|
||||
save_oauth_tokens(
|
||||
"test-streamable-http-remote-oauth",
|
||||
&expired_oauth_tokens(
|
||||
"test-streamable-http-remote-oauth",
|
||||
&format!("{base_url}/mcp"),
|
||||
"expired-access-token",
|
||||
refresh_token,
|
||||
),
|
||||
OAuthCredentialsStoreMode::File,
|
||||
)?;
|
||||
|
||||
let client = create_remote_oauth_client(&base_url, exec_server.client.clone()).await?;
|
||||
let result = call_echo_tool(&client, "remote-oauth").await?;
|
||||
|
||||
assert_eq!(result, expected_echo_result("remote-oauth"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn expired_oauth_tokens(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
access_token: &str,
|
||||
refresh_token: &str,
|
||||
) -> StoredOAuthTokens {
|
||||
let mut token_response = OAuthTokenResponse::new(
|
||||
AccessToken::new(access_token.to_string()),
|
||||
BasicTokenType::Bearer,
|
||||
EmptyExtraTokenFields {},
|
||||
);
|
||||
token_response.set_refresh_token(Some(RefreshToken::new(refresh_token.to_string())));
|
||||
|
||||
StoredOAuthTokens {
|
||||
server_name: server_name.to_string(),
|
||||
url: url.to_string(),
|
||||
client_id: "stored-client".to_string(),
|
||||
token_response: WrappedOAuthTokenResponse(token_response),
|
||||
expires_at: Some(0),
|
||||
}
|
||||
}
|
||||
@@ -128,10 +128,40 @@ pub(crate) async fn create_remote_client(
|
||||
base_url: &str,
|
||||
http_client: ExecServerClient,
|
||||
) -> anyhow::Result<RmcpClient> {
|
||||
let client = RmcpClient::new_streamable_http_client(
|
||||
create_remote_client_with_bearer(
|
||||
"test-streamable-http-remote",
|
||||
&format!("{base_url}/mcp"),
|
||||
base_url,
|
||||
Some("test-bearer".to_string()),
|
||||
http_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Creates a Streamable HTTP RMCP client that authenticates using stored OAuth
|
||||
/// credentials through the remote runtime HTTP API.
|
||||
pub(crate) async fn create_remote_oauth_client(
|
||||
base_url: &str,
|
||||
http_client: ExecServerClient,
|
||||
) -> anyhow::Result<RmcpClient> {
|
||||
create_remote_client_with_bearer(
|
||||
"test-streamable-http-remote-oauth",
|
||||
base_url,
|
||||
/*bearer_token*/ None,
|
||||
http_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn create_remote_client_with_bearer(
|
||||
server_name: &str,
|
||||
base_url: &str,
|
||||
bearer_token: Option<String>,
|
||||
http_client: ExecServerClient,
|
||||
) -> anyhow::Result<RmcpClient> {
|
||||
let client = RmcpClient::new_streamable_http_client(
|
||||
server_name,
|
||||
&format!("{base_url}/mcp"),
|
||||
bearer_token,
|
||||
/*http_headers*/ None,
|
||||
/*env_http_headers*/ None,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
@@ -193,16 +223,38 @@ pub(crate) async fn arm_session_post_failure(
|
||||
}
|
||||
|
||||
pub(crate) async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> {
|
||||
spawn_streamable_http_server_with_env(&[]).await
|
||||
}
|
||||
|
||||
pub(crate) async fn spawn_streamable_http_server_with_oauth_bearer(
|
||||
bearer_token: &str,
|
||||
refresh_token: &str,
|
||||
) -> anyhow::Result<(Child, String)> {
|
||||
spawn_streamable_http_server_with_env(&[
|
||||
("MCP_EXPECT_BEARER", bearer_token),
|
||||
("MCP_OAUTH_ACCESS_TOKEN", bearer_token),
|
||||
("MCP_OAUTH_REFRESH_TOKEN", refresh_token),
|
||||
])
|
||||
.await
|
||||
}
|
||||
|
||||
async fn spawn_streamable_http_server_with_env(
|
||||
env: &[(&str, &str)],
|
||||
) -> anyhow::Result<(Child, String)> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let port = listener.local_addr()?.port();
|
||||
drop(listener);
|
||||
|
||||
let bind_addr = format!("127.0.0.1:{port}");
|
||||
let base_url = format!("http://{bind_addr}");
|
||||
let mut child = Command::new(streamable_http_server_bin()?)
|
||||
let mut command = Command::new(streamable_http_server_bin()?);
|
||||
command
|
||||
.kill_on_drop(true)
|
||||
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
|
||||
.spawn()?;
|
||||
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr);
|
||||
for (name, value) in env {
|
||||
command.env(name, value);
|
||||
}
|
||||
let mut child = command.spawn()?;
|
||||
|
||||
wait_for_streamable_http_server(&mut child, &bind_addr, Duration::from_secs(5)).await?;
|
||||
Ok((child, base_url))
|
||||
|
||||
Reference in New Issue
Block a user