Compare commits

...

12 Commits

Author SHA1 Message Date
Ahmed Ibrahim
f8e2beeafa Address OAuth persistence review feedback
Keep local OAuth login flows on the no-proxy client while remote flows continue to use the runtime-selected HTTP client. Harden OAuth persistence so expires_in drift does not look like an external token update and transient keyring read failures do not clear in-memory credentials.

Co-authored-by: Codex <noreply@openai.com>
2026-04-27 11:58:06 +00:00
Ahmed Ibrahim
35080d5cd0 Address MCP OAuth review feedback
Restore stored-token bearer fallback by probing OAuth metadata through the injected runtime HTTP client before creating the OAuth transport. Also broaden discovery to OpenID Connect and protected-resource metadata paths, and avoid writing stale refresh results over newer in-memory credentials.

Co-authored-by: Codex <noreply@openai.com>
2026-04-27 11:34:20 +00:00
Ahmed Ibrahim
9df8e3d935 Use no-proxy client for local OAuth preflight
Co-authored-by: Codex <noreply@openai.com>
2026-04-27 11:13:44 +00:00
Ahmed Ibrahim
da1000bbb4 Use no-proxy client for local MCP auth status
Co-authored-by: Codex <noreply@openai.com>
2026-04-27 11:01:49 +00:00
Ahmed Ibrahim
ec6227c6cd Stabilize OAuth persistence test
Co-authored-by: Codex <noreply@openai.com>
2026-04-27 10:45:14 +00:00
Ahmed Ibrahim
520ee70b73 Address remote OAuth review feedback
Co-authored-by: Codex <noreply@openai.com>
2026-04-27 10:38:51 +00:00
Ahmed Ibrahim
276a97340c Fix OAuth argument comments
Co-authored-by: Codex <noreply@openai.com>
2026-04-27 10:25:29 +00:00
Ahmed Ibrahim
2b68005d97 Remove stale app-server OAuth dependency
Drop the direct codex-rmcp-client dependency from app-server now that OAuth orchestration is routed through codex-mcp.

Co-authored-by: Codex <noreply@openai.com>
2026-04-27 10:14:35 +00:00
Ahmed Ibrahim
c2a2ffa512 Address MCP OAuth review feedback
Use the runtime environment manager for MCP list auth status, keep cached tokens available when refresh fails before expiry, avoid restoring deleted OAuth credentials, and serialize refresh attempts.

Co-authored-by: Codex <noreply@openai.com>
2026-04-27 10:11:16 +00:00
Ahmed Ibrahim
a3abb8ba6e Move MCP OAuth orchestration downstream
Keep app-server OAuth callsites thin by moving server/runtime-aware OAuth discovery, scope resolution, and runtime HTTP client selection into codex-mcp helpers.

Co-authored-by: Codex <noreply@openai.com>
2026-04-27 10:03:34 +00:00
Ahmed Ibrahim
a4de53c661 Remove unused rmcp client dependency
Drop the stale codex-client dependency from codex-rmcp-client after the OAuth transport changes stopped using that crate.

Co-authored-by: Codex <noreply@openai.com>
2026-04-27 09:38:20 +00:00
Ahmed Ibrahim
177e39f4b7 Add remote runtime OAuth for MCP
Route Streamable HTTP OAuth discovery, registration, token exchange, and refresh through the selected MCP runtime HTTP client so remote MCP servers do not silently fall back to local network access. Keep browser launch and callback handling local while storing tokens in the existing local MCP OAuth store.

Add regression coverage for injected OAuth HTTP requests and stored-token refresh through the exec-server transport.

Co-authored-by: Codex <noreply@openai.com>
2026-04-27 09:33:42 +00:00
25 changed files with 2504 additions and 554 deletions

2
codex-rs/Cargo.lock generated
View File

@@ -1864,7 +1864,6 @@ dependencies = [
"codex-models-manager",
"codex-otel",
"codex-protocol",
"codex-rmcp-client",
"codex-rollout",
"codex-sandboxing",
"codex-shell-command",
@@ -3168,7 +3167,6 @@ dependencies = [
"axum",
"bytes",
"codex-api",
"codex-client",
"codex-config",
"codex-exec-server",
"codex-keyring-store",

View File

@@ -54,7 +54,6 @@ codex-models-manager = { workspace = true }
codex-protocol = { workspace = true }
codex-app-server-protocol = { workspace = true }
codex-feedback = { workspace = true }
codex-rmcp-client = { workspace = true }
codex-rollout = { workspace = true }
codex-sandboxing = { workspace = true }
codex-state = { workspace = true }

View File

@@ -310,10 +310,9 @@ use codex_mcp::McpRuntimeEnvironment;
use codex_mcp::McpServerStatusSnapshot;
use codex_mcp::McpSnapshotDetail;
use codex_mcp::collect_mcp_server_status_snapshot_with_detail;
use codex_mcp::discover_supported_scopes;
use codex_mcp::effective_mcp_servers;
use codex_mcp::perform_oauth_login_return_url_for_server;
use codex_mcp::read_mcp_resource as read_mcp_resource_without_thread;
use codex_mcp::resolve_oauth_scopes;
use codex_model_provider::ProviderAccountError;
use codex_model_provider::create_model_provider;
use codex_models_manager::collaboration_mode_presets::CollaborationModesConfig;
@@ -355,7 +354,6 @@ use codex_protocol::protocol::USER_MESSAGE_BEGIN;
use codex_protocol::protocol::W3cTraceContext;
use codex_protocol::user_input::MAX_USER_INPUT_TEXT_CHARS;
use codex_protocol::user_input::UserInput as CoreInputItem;
use codex_rmcp_client::perform_oauth_login_return_url;
use codex_rollout::state_db::StateDbHandle;
use codex_rollout::state_db::get_state_db;
use codex_rollout::state_db::reconcile_rollout;
@@ -5941,13 +5939,8 @@ impl CodexMessageProcessor {
return;
};
let (url, http_headers, env_http_headers) = match &server.transport {
McpServerTransportConfig::StreamableHttp {
url,
http_headers,
env_http_headers,
..
} => (url.clone(), http_headers.clone(), env_http_headers.clone()),
match &server.transport {
McpServerTransportConfig::StreamableHttp { .. } => {}
_ => {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
@@ -5960,25 +5953,24 @@ impl CodexMessageProcessor {
}
};
let discovered_scopes = if scopes.is_none() && server.scopes.is_none() {
discover_supported_scopes(&server.transport).await
} else {
None
let environment_manager = self.thread_manager.environment_manager();
let runtime_environment = match environment_manager.default_environment() {
Some(environment) => McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf()),
None => McpRuntimeEnvironment::new(
environment_manager.local_environment(),
config.cwd.to_path_buf(),
),
};
let resolved_scopes =
resolve_oauth_scopes(scopes, server.scopes.clone(), discovered_scopes);
match perform_oauth_login_return_url(
match perform_oauth_login_return_url_for_server(
&name,
&url,
server,
config.mcp_oauth_credentials_store_mode,
http_headers,
env_http_headers,
&resolved_scopes.scopes,
server.oauth_resource.as_deref(),
scopes,
timeout_secs,
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.as_deref(),
runtime_environment,
)
.await
{

View File

@@ -5,12 +5,9 @@ use codex_app_server_protocol::McpServerOauthLoginCompletedNotification;
use codex_app_server_protocol::ServerNotification;
use codex_config::types::McpServerConfig;
use codex_core::config::Config;
use codex_mcp::McpOAuthLoginSupport;
use codex_mcp::oauth_login_support;
use codex_mcp::resolve_oauth_scopes;
use codex_mcp::should_retry_without_scopes;
use codex_rmcp_client::perform_oauth_login_silent;
use tracing::warn;
use codex_mcp::McpOAuthLoginOutcome;
use codex_mcp::McpRuntimeEnvironment;
use codex_mcp::perform_oauth_login_silent_for_server;
use super::CodexMessageProcessor;
@@ -21,23 +18,17 @@ impl CodexMessageProcessor {
plugin_mcp_servers: HashMap<String, McpServerConfig>,
) {
for (name, server) in plugin_mcp_servers {
let oauth_config = match oauth_login_support(&server.transport).await {
McpOAuthLoginSupport::Supported(config) => config,
McpOAuthLoginSupport::Unsupported => continue,
McpOAuthLoginSupport::Unknown(err) => {
warn!(
"MCP server may or may not require login for plugin install {name}: {err}"
);
continue;
let environment_manager = self.thread_manager.environment_manager();
let runtime_environment = match environment_manager.default_environment() {
Some(environment) => {
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf())
}
None => McpRuntimeEnvironment::new(
environment_manager.local_environment(),
config.cwd.to_path_buf(),
),
};
let resolved_scopes = resolve_oauth_scopes(
/*explicit_scopes*/ None,
server.scopes.clone(),
oauth_config.discovered_scopes.clone(),
);
let store_mode = config.mcp_oauth_credentials_store_mode;
let callback_port = config.mcp_oauth_callback_port;
let callback_url = config.mcp_oauth_callback_url.clone();
@@ -45,39 +36,20 @@ impl CodexMessageProcessor {
let notification_name = name.clone();
tokio::spawn(async move {
let first_attempt = perform_oauth_login_silent(
let final_result = perform_oauth_login_silent_for_server(
&name,
&oauth_config.url,
&server,
store_mode,
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
server.oauth_resource.as_deref(),
/*explicit_scopes*/ None,
callback_port,
callback_url.as_deref(),
runtime_environment,
)
.await;
let final_result = match first_attempt {
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
perform_oauth_login_silent(
&name,
&oauth_config.url,
store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
server.oauth_resource.as_deref(),
callback_port,
callback_url.as_deref(),
)
.await
}
result => result,
};
let (success, error) = match final_result {
Ok(()) => (true, None),
Ok(McpOAuthLoginOutcome::Completed) => (true, None),
Ok(McpOAuthLoginOutcome::Unsupported) => return,
Err(err) => (false, Some(err.to_string())),
};

View File

@@ -15,16 +15,19 @@ use codex_core::config::edit::ConfigEditsBuilder;
use codex_core::config::find_codex_home;
use codex_core::config::load_global_mcp_servers;
use codex_core::plugins::PluginsManager;
use codex_exec_server::CODEX_EXEC_SERVER_URL_ENV_VAR;
use codex_exec_server::Environment;
use codex_exec_server::EnvironmentManager;
use codex_exec_server::EnvironmentManagerArgs;
use codex_exec_server::ExecServerRuntimePaths;
use codex_mcp::McpOAuthLoginOutcome;
use codex_mcp::McpOAuthLoginSupport;
use codex_mcp::ResolvedMcpOAuthScopes;
use codex_mcp::McpRuntimeEnvironment;
use codex_mcp::compute_auth_statuses;
use codex_mcp::discover_supported_scopes;
use codex_mcp::oauth_login_support;
use codex_mcp::resolve_oauth_scopes;
use codex_mcp::should_retry_without_scopes;
use codex_mcp::oauth_login_support_for_server;
use codex_mcp::perform_oauth_login_for_server;
use codex_protocol::protocol::McpAuthStatus;
use codex_rmcp_client::delete_oauth_tokens;
use codex_rmcp_client::perform_oauth_login;
use codex_utils_cli::CliConfigOverrides;
use codex_utils_cli::format_env_display;
@@ -188,54 +191,6 @@ impl McpCli {
}
}
/// Preserve compatibility with servers that still expect the legacy empty-scope
/// OAuth request. If a discovered-scope request is rejected by the provider,
/// retry the login flow once without scopes.
#[allow(clippy::too_many_arguments)]
async fn perform_oauth_login_retry_without_scopes(
name: &str,
url: &str,
store_mode: codex_config::types::OAuthCredentialsStoreMode,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
resolved_scopes: &ResolvedMcpOAuthScopes,
oauth_resource: Option<&str>,
callback_port: Option<u16>,
callback_url: Option<&str>,
) -> Result<()> {
match perform_oauth_login(
name,
url,
store_mode,
http_headers.clone(),
env_http_headers.clone(),
&resolved_scopes.scopes,
oauth_resource,
callback_port,
callback_url,
)
.await
{
Ok(()) => Ok(()),
Err(err) if should_retry_without_scopes(resolved_scopes, &err) => {
println!("OAuth provider rejected discovered scopes. Retrying without scopes…");
perform_oauth_login(
name,
url,
store_mode,
http_headers,
env_http_headers,
&[],
oauth_resource,
callback_port,
callback_url,
)
.await
}
Err(err) => Err(err),
}
}
async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> {
// Validate any provided overrides even though they are not currently applied.
let overrides = config_overrides
@@ -297,7 +252,7 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re
};
let new_entry = McpServerConfig {
transport: transport.clone(),
transport,
experimental_environment: None,
enabled: true,
required: false,
@@ -313,7 +268,7 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re
tools: HashMap::new(),
};
servers.insert(name.clone(), new_entry);
servers.insert(name.clone(), new_entry.clone());
ConfigEditsBuilder::new(&codex_home)
.replace_mcp_servers(&servers)
@@ -323,27 +278,24 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re
println!("Added global MCP server '{name}'.");
match oauth_login_support(&transport).await {
McpOAuthLoginSupport::Supported(oauth_config) => {
let runtime_environment = runtime_environment_from_config(&config);
match oauth_login_support_for_server(&new_entry, runtime_environment.clone()).await {
McpOAuthLoginSupport::Supported(_) => {
println!("Detected OAuth support. Starting OAuth flow…");
let resolved_scopes = resolve_oauth_scopes(
/*explicit_scopes*/ None,
/*configured_scopes*/ None,
oauth_config.discovered_scopes.clone(),
);
perform_oauth_login_retry_without_scopes(
match perform_oauth_login_for_server(
&name,
&oauth_config.url,
&new_entry,
config.mcp_oauth_credentials_store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&resolved_scopes,
/*oauth_resource*/ None,
/*explicit_scopes*/ None,
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.as_deref(),
runtime_environment,
)
.await?;
println!("Successfully logged in.");
.await?
{
McpOAuthLoginOutcome::Completed => println!("Successfully logged in."),
McpOAuthLoginOutcome::Unsupported => {}
}
}
McpOAuthLoginSupport::Unsupported => {}
McpOAuthLoginSupport::Unknown(_) => println!(
@@ -405,38 +357,32 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs)
bail!("No MCP server named '{name}' found.");
};
let (url, http_headers, env_http_headers) = match &server.transport {
McpServerTransportConfig::StreamableHttp {
url,
http_headers,
env_http_headers,
..
} => (url.clone(), http_headers.clone(), env_http_headers.clone()),
_ => bail!("OAuth login is only supported for streamable HTTP servers."),
if !matches!(
&server.transport,
McpServerTransportConfig::StreamableHttp { .. }
) {
bail!("OAuth login is only supported for streamable HTTP servers.");
};
let explicit_scopes = (!scopes.is_empty()).then_some(scopes);
let discovered_scopes = if explicit_scopes.is_none() && server.scopes.is_none() {
discover_supported_scopes(&server.transport).await
} else {
None
};
let resolved_scopes =
resolve_oauth_scopes(explicit_scopes, server.scopes.clone(), discovered_scopes);
perform_oauth_login_retry_without_scopes(
match perform_oauth_login_for_server(
&name,
&url,
server,
config.mcp_oauth_credentials_store_mode,
http_headers,
env_http_headers,
&resolved_scopes,
server.oauth_resource.as_deref(),
explicit_scopes,
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.as_deref(),
runtime_environment_from_config(&config),
)
.await?;
println!("Successfully logged in to MCP server '{name}'.");
.await?
{
McpOAuthLoginOutcome::Completed => {
println!("Successfully logged in to MCP server '{name}'.")
}
McpOAuthLoginOutcome::Unsupported => {
bail!("MCP server '{name}' does not advertise OAuth support.")
}
}
Ok(())
}
@@ -472,6 +418,27 @@ async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutAr
Ok(())
}
fn runtime_environment_from_config(config: &Config) -> McpRuntimeEnvironment {
let local_runtime_paths = ExecServerRuntimePaths::from_optional_paths(
config.codex_self_exe.clone(),
config.codex_linux_sandbox_exe.clone(),
);
let environment = match local_runtime_paths {
Ok(local_runtime_paths) => {
let environment_manager =
EnvironmentManager::new(EnvironmentManagerArgs::from_env(local_runtime_paths));
environment_manager
.default_environment()
.unwrap_or_else(|| environment_manager.local_environment())
}
Err(_) => Arc::new(
Environment::create_for_tests(std::env::var(CODEX_EXEC_SERVER_URL_ENV_VAR).ok())
.unwrap_or_else(|_| Environment::default_for_tests()),
),
};
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf())
}
async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Result<()> {
let overrides = config_overrides
.parse_overrides()
@@ -483,6 +450,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
config.codex_home.to_path_buf(),
)));
let mcp_servers = mcp_manager.effective_servers(&config, /*auth*/ None).await;
let runtime_environment = runtime_environment_from_config(&config);
let mut entries: Vec<_> = mcp_servers.iter().collect();
entries.sort_by(|(a, _), (b, _)| a.cmp(b));
@@ -490,6 +458,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) ->
mcp_servers.iter(),
config.mcp_oauth_credentials_store_mode,
/*auth*/ None,
runtime_environment,
)
.await;

View File

@@ -24,12 +24,19 @@ pub use mcp::read_mcp_resource;
pub use mcp::McpAuthStatusEntry;
pub use mcp::McpOAuthLoginConfig;
pub use mcp::McpOAuthLoginOutcome;
pub use mcp::McpOAuthLoginSupport;
pub use mcp::McpOAuthScopesSource;
pub use mcp::ResolvedMcpOAuthScopes;
pub use mcp::compute_auth_statuses;
pub use mcp::discover_supported_scopes;
pub use mcp::discover_supported_scopes_for_server;
pub use mcp::http_client_for_server;
pub use mcp::oauth_login_support;
pub use mcp::oauth_login_support_for_server;
pub use mcp::perform_oauth_login_for_server;
pub use mcp::perform_oauth_login_return_url_for_server;
pub use mcp::perform_oauth_login_silent_for_server;
pub use mcp::resolve_oauth_scopes;
pub use mcp::should_retry_without_scopes;

View File

@@ -4,14 +4,28 @@ use anyhow::Result;
use codex_config::McpServerConfig;
use codex_config::McpServerTransportConfig;
use codex_config::types::OAuthCredentialsStoreMode;
use codex_exec_server::HttpClient;
use codex_exec_server::ReqwestHttpClient;
use codex_login::CodexAuth;
use codex_protocol::protocol::McpAuthStatus;
use codex_rmcp_client::OAuthProviderError;
use codex_rmcp_client::OauthLoginHandle;
use codex_rmcp_client::determine_streamable_http_auth_status;
use codex_rmcp_client::determine_streamable_http_auth_status_with_client;
use codex_rmcp_client::discover_streamable_http_oauth;
use codex_rmcp_client::discover_streamable_http_oauth_with_client;
use codex_rmcp_client::perform_oauth_login;
use codex_rmcp_client::perform_oauth_login_return_url;
use codex_rmcp_client::perform_oauth_login_return_url_with_client;
use codex_rmcp_client::perform_oauth_login_silent;
use codex_rmcp_client::perform_oauth_login_silent_with_client;
use codex_rmcp_client::perform_oauth_login_with_client;
use futures::future::join_all;
use std::sync::Arc;
use tracing::warn;
use crate::runtime::McpRuntimeEnvironment;
use super::CODEX_APPS_MCP_SERVER_NAME;
#[derive(Debug, Clone)]
@@ -49,7 +63,65 @@ pub struct McpAuthStatusEntry {
pub auth_status: McpAuthStatus,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpOAuthLoginOutcome {
Completed,
Unsupported,
}
#[derive(Clone)]
enum McpOAuthHttpClient {
LocalNoProxy,
Runtime(Arc<dyn HttpClient>),
}
impl McpOAuthHttpClient {
async fn login_support(&self, transport: &McpServerTransportConfig) -> McpOAuthLoginSupport {
match self {
Self::LocalNoProxy => oauth_login_support(transport).await,
Self::Runtime(http_client) => {
oauth_login_support_with_client(transport, http_client.clone()).await
}
}
}
async fn discover_supported_scopes(
&self,
transport: &McpServerTransportConfig,
) -> Option<Vec<String>> {
match self.login_support(transport).await {
McpOAuthLoginSupport::Supported(config) => config.discovered_scopes,
McpOAuthLoginSupport::Unsupported | McpOAuthLoginSupport::Unknown(_) => None,
}
}
}
pub async fn oauth_login_support(transport: &McpServerTransportConfig) -> McpOAuthLoginSupport {
oauth_login_support_with_discovery(transport).await
}
pub async fn oauth_login_support_for_server(
config: &McpServerConfig,
runtime_environment: McpRuntimeEnvironment,
) -> McpOAuthLoginSupport {
match config.experimental_environment.as_deref() {
Some("remote") => {
let http_client = match http_client_for_server(config, runtime_environment) {
Ok(http_client) => http_client,
Err(err) => return McpOAuthLoginSupport::Unknown(err),
};
oauth_login_support_with_client(&config.transport, http_client).await
}
None | Some("local") => oauth_login_support(&config.transport).await,
Some(environment) => McpOAuthLoginSupport::Unknown(anyhow::anyhow!(
"unsupported experimental_environment `{environment}`"
)),
}
}
async fn oauth_login_support_with_discovery(
transport: &McpServerTransportConfig,
) -> McpOAuthLoginSupport {
let McpServerTransportConfig::StreamableHttp {
url,
bearer_token_env_var,
@@ -77,6 +149,43 @@ pub async fn oauth_login_support(transport: &McpServerTransportConfig) -> McpOAu
}
}
async fn oauth_login_support_with_client(
transport: &McpServerTransportConfig,
http_client: Arc<dyn HttpClient>,
) -> McpOAuthLoginSupport {
let McpServerTransportConfig::StreamableHttp {
url,
bearer_token_env_var,
http_headers,
env_http_headers,
} = transport
else {
return McpOAuthLoginSupport::Unsupported;
};
if bearer_token_env_var.is_some() {
return McpOAuthLoginSupport::Unsupported;
}
match discover_streamable_http_oauth_with_client(
url,
http_headers.clone(),
env_http_headers.clone(),
http_client,
)
.await
{
Ok(Some(discovery)) => McpOAuthLoginSupport::Supported(McpOAuthLoginConfig {
url: url.clone(),
http_headers: http_headers.clone(),
env_http_headers: env_http_headers.clone(),
discovered_scopes: discovery.scopes_supported,
}),
Ok(None) => McpOAuthLoginSupport::Unsupported,
Err(err) => McpOAuthLoginSupport::Unknown(err),
}
}
pub async fn discover_supported_scopes(
transport: &McpServerTransportConfig,
) -> Option<Vec<String>> {
@@ -86,6 +195,16 @@ pub async fn discover_supported_scopes(
}
}
pub async fn discover_supported_scopes_for_server(
config: &McpServerConfig,
runtime_environment: McpRuntimeEnvironment,
) -> Option<Vec<String>> {
match oauth_login_support_for_server(config, runtime_environment).await {
McpOAuthLoginSupport::Supported(config) => config.discovered_scopes,
McpOAuthLoginSupport::Unsupported | McpOAuthLoginSupport::Unknown(_) => None,
}
}
pub fn resolve_oauth_scopes(
explicit_scopes: Option<Vec<String>>,
configured_scopes: Option<Vec<String>>,
@@ -125,10 +244,264 @@ pub fn should_retry_without_scopes(scopes: &ResolvedMcpOAuthScopes, error: &anyh
&& error.downcast_ref::<OAuthProviderError>().is_some()
}
#[allow(clippy::too_many_arguments)]
pub async fn perform_oauth_login_return_url_for_server(
server_name: &str,
config: &McpServerConfig,
store_mode: OAuthCredentialsStoreMode,
explicit_scopes: Option<Vec<String>>,
timeout_secs: Option<i64>,
callback_port: Option<u16>,
callback_url: Option<&str>,
runtime_environment: McpRuntimeEnvironment,
) -> Result<OauthLoginHandle> {
let McpServerTransportConfig::StreamableHttp {
url,
http_headers,
env_http_headers,
..
} = &config.transport
else {
anyhow::bail!("OAuth login is only supported for streamable HTTP servers.");
};
let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?;
let discovered_scopes = if explicit_scopes.is_none() && config.scopes.is_none() {
oauth_http_client
.discover_supported_scopes(&config.transport)
.await
} else {
None
};
let resolved_scopes =
resolve_oauth_scopes(explicit_scopes, config.scopes.clone(), discovered_scopes);
match oauth_http_client {
McpOAuthHttpClient::LocalNoProxy => {
perform_oauth_login_return_url(
server_name,
url,
store_mode,
http_headers.clone(),
env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
timeout_secs,
callback_port,
callback_url,
)
.await
}
McpOAuthHttpClient::Runtime(http_client) => {
perform_oauth_login_return_url_with_client(
server_name,
url,
store_mode,
http_headers.clone(),
env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
timeout_secs,
callback_port,
callback_url,
http_client,
)
.await
}
}
}
#[allow(clippy::too_many_arguments)]
pub async fn perform_oauth_login_silent_for_server(
server_name: &str,
config: &McpServerConfig,
store_mode: OAuthCredentialsStoreMode,
explicit_scopes: Option<Vec<String>>,
callback_port: Option<u16>,
callback_url: Option<&str>,
runtime_environment: McpRuntimeEnvironment,
) -> Result<McpOAuthLoginOutcome> {
let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?;
let oauth_config = match oauth_http_client.login_support(&config.transport).await {
McpOAuthLoginSupport::Supported(config) => config,
McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported),
McpOAuthLoginSupport::Unknown(err) => return Err(err),
};
let resolved_scopes = resolve_oauth_scopes(
explicit_scopes,
config.scopes.clone(),
oauth_config.discovered_scopes.clone(),
);
let final_result = match oauth_http_client {
McpOAuthHttpClient::LocalNoProxy => {
let first_attempt = perform_oauth_login_silent(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
callback_port,
callback_url,
)
.await;
match first_attempt {
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
perform_oauth_login_silent(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
config.oauth_resource.as_deref(),
callback_port,
callback_url,
)
.await
}
result => result,
}
}
McpOAuthHttpClient::Runtime(http_client) => {
let first_attempt = perform_oauth_login_silent_with_client(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client.clone(),
)
.await;
match first_attempt {
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
perform_oauth_login_silent_with_client(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client,
)
.await
}
result => result,
}
}
};
final_result.map(|()| McpOAuthLoginOutcome::Completed)
}
#[allow(clippy::too_many_arguments)]
pub async fn perform_oauth_login_for_server(
server_name: &str,
config: &McpServerConfig,
store_mode: OAuthCredentialsStoreMode,
explicit_scopes: Option<Vec<String>>,
callback_port: Option<u16>,
callback_url: Option<&str>,
runtime_environment: McpRuntimeEnvironment,
) -> Result<McpOAuthLoginOutcome> {
let oauth_http_client = oauth_http_client_for_server(config, runtime_environment)?;
let oauth_config = match oauth_http_client.login_support(&config.transport).await {
McpOAuthLoginSupport::Supported(config) => config,
McpOAuthLoginSupport::Unsupported => return Ok(McpOAuthLoginOutcome::Unsupported),
McpOAuthLoginSupport::Unknown(err) => return Err(err),
};
let resolved_scopes = resolve_oauth_scopes(
explicit_scopes,
config.scopes.clone(),
oauth_config.discovered_scopes.clone(),
);
let final_result = match oauth_http_client {
McpOAuthHttpClient::LocalNoProxy => {
let first_attempt = perform_oauth_login(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
callback_port,
callback_url,
)
.await;
match first_attempt {
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
perform_oauth_login(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
config.oauth_resource.as_deref(),
callback_port,
callback_url,
)
.await
}
result => result,
}
}
McpOAuthHttpClient::Runtime(http_client) => {
let first_attempt = perform_oauth_login_with_client(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client.clone(),
)
.await;
match first_attempt {
Err(err) if should_retry_without_scopes(&resolved_scopes, &err) => {
perform_oauth_login_with_client(
server_name,
&oauth_config.url,
store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
config.oauth_resource.as_deref(),
callback_port,
callback_url,
http_client,
)
.await
}
result => result,
}
}
};
final_result.map(|()| McpOAuthLoginOutcome::Completed)
}
pub async fn compute_auth_statuses<'a, I>(
servers: I,
store_mode: OAuthCredentialsStoreMode,
auth: Option<&CodexAuth>,
runtime_environment: McpRuntimeEnvironment,
) -> HashMap<String, McpAuthStatusEntry>
where
I: IntoIterator<Item = (&'a String, &'a McpServerConfig)>,
@@ -136,6 +509,7 @@ where
let futures = servers.into_iter().map(|(name, config)| {
let name = name.clone();
let config = config.clone();
let runtime_environment = runtime_environment.clone();
let has_runtime_auth = name == CODEX_APPS_MCP_SERVER_NAME
&& auth.is_some_and(CodexAuth::uses_codex_backend)
&& matches!(
@@ -146,14 +520,21 @@ where
}
);
async move {
let auth_status =
match compute_auth_status(&name, &config, store_mode, has_runtime_auth).await {
Ok(status) => status,
Err(error) => {
warn!("failed to determine auth status for MCP server `{name}`: {error:?}");
McpAuthStatus::Unsupported
}
};
let auth_status = match compute_auth_status(
&name,
&config,
store_mode,
has_runtime_auth,
runtime_environment,
)
.await
{
Ok(status) => status,
Err(error) => {
warn!("failed to determine auth status for MCP server `{name}`: {error:?}");
McpAuthStatus::Unsupported
}
};
let entry = McpAuthStatusEntry {
config,
auth_status,
@@ -170,6 +551,7 @@ async fn compute_auth_status(
config: &McpServerConfig,
store_mode: OAuthCredentialsStoreMode,
has_runtime_auth: bool,
runtime_environment: McpRuntimeEnvironment,
) -> Result<McpAuthStatus> {
if !config.enabled {
return Ok(McpAuthStatus::Unsupported);
@@ -186,28 +568,88 @@ async fn compute_auth_status(
bearer_token_env_var,
http_headers,
env_http_headers,
} => {
determine_streamable_http_auth_status(
server_name,
url,
bearer_token_env_var.as_deref(),
http_headers.clone(),
env_http_headers.clone(),
store_mode,
)
.await
} => match config.experimental_environment.as_deref() {
Some("remote") => {
let http_client = http_client_for_server(config, runtime_environment)?;
determine_streamable_http_auth_status_with_client(
server_name,
url,
bearer_token_env_var.as_deref(),
http_headers.clone(),
env_http_headers.clone(),
store_mode,
http_client,
)
.await
}
None | Some("local") => {
determine_streamable_http_auth_status(
server_name,
url,
bearer_token_env_var.as_deref(),
http_headers.clone(),
env_http_headers.clone(),
store_mode,
)
.await
}
Some(environment) => {
anyhow::bail!("unsupported experimental_environment `{environment}`")
}
},
}
}
pub fn http_client_for_server(
config: &McpServerConfig,
runtime_environment: McpRuntimeEnvironment,
) -> Result<Arc<dyn HttpClient>> {
match config.experimental_environment.as_deref() {
None | Some("local") => Ok(Arc::new(ReqwestHttpClient)),
Some("remote") => {
let environment = runtime_environment.environment();
if !environment.is_remote() {
anyhow::bail!("remote MCP server requires a remote environment");
}
Ok(environment.get_http_client())
}
Some(environment) => anyhow::bail!("unsupported experimental_environment `{environment}`"),
}
}
fn oauth_http_client_for_server(
config: &McpServerConfig,
runtime_environment: McpRuntimeEnvironment,
) -> Result<McpOAuthHttpClient> {
match config.experimental_environment.as_deref() {
None | Some("local") => Ok(McpOAuthHttpClient::LocalNoProxy),
Some("remote") => Ok(McpOAuthHttpClient::Runtime(http_client_for_server(
config,
runtime_environment,
)?)),
Some(environment) => anyhow::bail!("unsupported experimental_environment `{environment}`"),
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use anyhow::anyhow;
use codex_config::McpServerConfig;
use codex_config::McpServerTransportConfig;
use codex_exec_server::Environment;
use pretty_assertions::assert_eq;
use crate::runtime::McpRuntimeEnvironment;
use super::McpOAuthScopesSource;
use super::OAuthProviderError;
use super::ResolvedMcpOAuthScopes;
use super::http_client_for_server;
use super::resolve_oauth_scopes;
use super::should_retry_without_scopes;
@@ -318,4 +760,71 @@ mod tests {
&anyhow!("timed out waiting for OAuth callback"),
));
}
#[test]
fn local_server_uses_local_http_client_even_with_remote_runtime() {
let config = streamable_config(/*experimental_environment*/ None);
let runtime_environment = remote_runtime_environment();
assert!(http_client_for_server(&config, runtime_environment).is_ok());
}
#[test]
fn remote_server_uses_runtime_http_client_when_runtime_is_remote() {
let config = streamable_config(Some("remote"));
let runtime_environment = remote_runtime_environment();
assert!(http_client_for_server(&config, runtime_environment).is_ok());
}
#[tokio::test]
async fn remote_server_without_remote_runtime_returns_clear_error() {
let config = streamable_config(Some("remote"));
let runtime_environment = McpRuntimeEnvironment::new(
Arc::new(Environment::default_for_tests()),
PathBuf::from("/tmp"),
);
let Err(error) = http_client_for_server(&config, runtime_environment) else {
panic!("remote server should require remote runtime");
};
assert_eq!(
error.to_string(),
"remote MCP server requires a remote environment"
);
}
fn streamable_config(experimental_environment: Option<&str>) -> McpServerConfig {
McpServerConfig {
transport: McpServerTransportConfig::StreamableHttp {
url: "http://mcp.example.test/mcp".to_string(),
bearer_token_env_var: None,
http_headers: None,
env_http_headers: None,
},
experimental_environment: experimental_environment.map(str::to_string),
enabled: true,
required: false,
supports_parallel_tool_calls: false,
disabled_reason: None,
startup_timeout_sec: Some(Duration::from_secs(30)),
tool_timeout_sec: None,
default_tools_approval_mode: None,
enabled_tools: None,
disabled_tools: None,
scopes: None,
oauth_resource: None,
tools: HashMap::new(),
}
}
fn remote_runtime_environment() -> McpRuntimeEnvironment {
McpRuntimeEnvironment::new(
Arc::new(
Environment::create_for_tests(Some("ws://127.0.0.1:65535".to_string()))
.expect("create remote environment"),
),
PathBuf::from("/tmp"),
)
}
}

View File

@@ -1,11 +1,18 @@
pub use auth::McpAuthStatusEntry;
pub use auth::McpOAuthLoginConfig;
pub use auth::McpOAuthLoginOutcome;
pub use auth::McpOAuthLoginSupport;
pub use auth::McpOAuthScopesSource;
pub use auth::ResolvedMcpOAuthScopes;
pub use auth::compute_auth_statuses;
pub use auth::discover_supported_scopes;
pub use auth::discover_supported_scopes_for_server;
pub use auth::http_client_for_server;
pub use auth::oauth_login_support;
pub use auth::oauth_login_support_for_server;
pub use auth::perform_oauth_login_for_server;
pub use auth::perform_oauth_login_return_url_for_server;
pub use auth::perform_oauth_login_silent_for_server;
pub use auth::resolve_oauth_scopes;
pub use auth::should_retry_without_scopes;
@@ -223,6 +230,7 @@ pub async fn read_mcp_resource(
mcp_servers.iter(),
config.mcp_oauth_credentials_store_mode,
auth,
runtime_environment.clone(),
)
.await;
let (tx_event, rx_event) = unbounded();
@@ -286,6 +294,7 @@ pub async fn collect_mcp_server_status_snapshot_with_detail(
mcp_servers.iter(),
config.mcp_oauth_credentials_store_mode,
auth,
runtime_environment.clone(),
)
.await;

View File

@@ -253,20 +253,23 @@ pub async fn list_accessible_connectors_from_mcp_tools_with_environment_manager(
});
}
let environment = environment_manager
.default_environment()
.unwrap_or_else(|| environment_manager.local_environment());
let runtime_environment =
McpRuntimeEnvironment::new(environment.clone(), config.cwd.to_path_buf());
let auth_status_entries = compute_auth_statuses(
mcp_servers.iter(),
config.mcp_oauth_credentials_store_mode,
auth.as_ref(),
runtime_environment.clone(),
)
.await;
let (tx_event, rx_event) = unbounded();
drop(rx_event);
let environment = environment_manager
.default_environment()
.unwrap_or_else(|| environment_manager.local_environment());
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
&mcp_servers,
config.mcp_oauth_credentials_store_mode,
@@ -275,7 +278,7 @@ pub async fn list_accessible_connectors_from_mcp_tools_with_environment_manager(
INITIAL_SUBMIT_ID.to_owned(),
tx_event,
PermissionProfile::default(),
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf()),
runtime_environment,
config.codex_home.to_path_buf(),
codex_apps_tools_cache_key(auth.as_ref()),
ToolPluginProvenance::default(),

View File

@@ -11,7 +11,6 @@ use codex_protocol::request_user_input::RequestUserInputArgs;
use codex_protocol::request_user_input::RequestUserInputQuestion;
use codex_protocol::request_user_input::RequestUserInputQuestionOption;
use codex_protocol::request_user_input::RequestUserInputResponse;
use codex_rmcp_client::perform_oauth_login;
use tokio_util::sync::CancellationToken;
use tracing::warn;
@@ -19,11 +18,10 @@ use crate::SkillMetadata;
use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::skills::model::SkillToolDependency;
use codex_mcp::McpOAuthLoginSupport;
use codex_mcp::McpOAuthLoginOutcome;
use codex_mcp::McpRuntimeEnvironment;
use codex_mcp::mcp_permission_prompt_is_auto_approved;
use codex_mcp::oauth_login_support;
use codex_mcp::resolve_oauth_scopes;
use codex_mcp::should_retry_without_scopes;
use codex_mcp::perform_oauth_login_for_server;
const SKILL_MCP_DEPENDENCY_PROMPT_ID: &str = "skill_mcp_dependency_install";
const MCP_DEPENDENCY_OPTION_INSTALL: &str = "Install";
@@ -126,15 +124,13 @@ pub(crate) async fn maybe_install_mcp_dependencies(
}
for (name, server_config) in added {
let oauth_config = match oauth_login_support(&server_config.transport).await {
McpOAuthLoginSupport::Supported(config) => config,
McpOAuthLoginSupport::Unsupported => continue,
McpOAuthLoginSupport::Unknown(err) => {
warn!("MCP server may or may not require login for dependency {name}: {err}");
continue;
}
};
let runtime_environment = McpRuntimeEnvironment::new(
turn_context
.environment
.clone()
.unwrap_or_else(|| sess.services.environment_manager.local_environment()),
turn_context.cwd.to_path_buf(),
);
sess.notify_background_event(
turn_context,
format!(
@@ -143,52 +139,20 @@ pub(crate) async fn maybe_install_mcp_dependencies(
)
.await;
let resolved_scopes = resolve_oauth_scopes(
/*explicit_scopes*/ None,
server_config.scopes.clone(),
oauth_config.discovered_scopes.clone(),
);
let first_attempt = perform_oauth_login(
match perform_oauth_login_for_server(
&name,
&oauth_config.url,
&server_config,
config.mcp_oauth_credentials_store_mode,
oauth_config.http_headers.clone(),
oauth_config.env_http_headers.clone(),
&resolved_scopes.scopes,
server_config.oauth_resource.as_deref(),
/*explicit_scopes*/ None,
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.as_deref(),
runtime_environment,
)
.await;
if let Err(err) = first_attempt {
if should_retry_without_scopes(&resolved_scopes, &err) {
sess.notify_background_event(
turn_context,
format!(
"Retrying MCP {name} authentication without scopes after provider rejection."
),
)
.await;
if let Err(err) = perform_oauth_login(
&name,
&oauth_config.url,
config.mcp_oauth_credentials_store_mode,
oauth_config.http_headers,
oauth_config.env_http_headers,
&[],
server_config.oauth_resource.as_deref(),
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.as_deref(),
)
.await
{
warn!("failed to login to MCP dependency {name}: {err}");
}
} else {
warn!("failed to login to MCP dependency {name}: {err}");
}
.await
{
Ok(McpOAuthLoginOutcome::Completed) => {}
Ok(McpOAuthLoginOutcome::Unsupported) => {}
Err(err) => warn!("failed to login to MCP dependency {name}: {err}"),
}
}

View File

@@ -32,6 +32,7 @@ use crate::tasks::UndoTask;
use crate::tasks::UserShellCommandMode;
use crate::tasks::UserShellCommandTask;
use crate::tasks::execute_user_shell_command;
use codex_mcp::McpRuntimeEnvironment;
use codex_mcp::collect_mcp_snapshot_from_manager;
use codex_mcp::compute_auth_statuses;
use codex_protocol::models::ContentItem;
@@ -538,12 +539,18 @@ pub async fn list_mcp_tools(sess: &Session, config: &Arc<Config>, sub_id: String
.mcp_manager
.effective_servers(config, auth.as_ref())
.await;
let environment = sess
.services
.environment_manager
.default_environment()
.unwrap_or_else(|| sess.services.environment_manager.local_environment());
let snapshot = collect_mcp_snapshot_from_manager(
&mcp_connection_manager,
compute_auth_statuses(
mcp_servers.iter(),
config.mcp_oauth_credentials_store_mode,
auth.as_ref(),
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf()),
)
.await,
)

View File

@@ -219,8 +219,20 @@ impl Session {
.tool_plugin_provenance(config.as_ref())
.await;
let mcp_servers = with_codex_apps_mcp(mcp_servers, auth.as_ref(), &mcp_config);
let auth_statuses =
compute_auth_statuses(mcp_servers.iter(), store_mode, auth.as_ref()).await;
let runtime_environment = McpRuntimeEnvironment::new(
turn_context
.environment
.clone()
.unwrap_or_else(|| self.services.environment_manager.local_environment()),
turn_context.cwd.to_path_buf(),
);
let auth_statuses = compute_auth_statuses(
mcp_servers.iter(),
store_mode,
auth.as_ref(),
runtime_environment.clone(),
)
.await;
{
let mut guard = self.services.mcp_startup_cancellation_token.lock().await;
guard.cancel();
@@ -234,13 +246,7 @@ impl Session {
turn_context.sub_id.clone(),
self.get_tx_event(),
turn_context.permission_profile(),
McpRuntimeEnvironment::new(
turn_context
.environment
.clone()
.unwrap_or_else(|| self.services.environment_manager.local_environment()),
turn_context.cwd.to_path_buf(),
),
runtime_environment,
config.codex_home.to_path_buf(),
codex_apps_tools_cache_key(auth.as_ref()),
tool_plugin_provenance,

View File

@@ -445,15 +445,20 @@ impl Session {
let auth_manager_clone = Arc::clone(&auth_manager);
let config_for_mcp = Arc::clone(&config);
let mcp_manager_for_mcp = Arc::clone(&mcp_manager);
let environment_manager_for_mcp = Arc::clone(&environment_manager);
let auth_and_mcp_fut = async move {
let auth = auth_manager_clone.auth().await;
let mcp_servers = mcp_manager_for_mcp
.effective_servers(&config_for_mcp, auth.as_ref())
.await;
let environment = environment_manager_for_mcp
.default_environment()
.unwrap_or_else(|| environment_manager_for_mcp.local_environment());
let auth_statuses = compute_auth_statuses(
mcp_servers.iter(),
config_for_mcp.mcp_oauth_credentials_store_mode,
auth.as_ref(),
McpRuntimeEnvironment::new(environment, config_for_mcp.cwd.to_path_buf()),
)
.await;
(auth, mcp_servers, auth_statuses)

View File

@@ -14,7 +14,6 @@ axum = { workspace = true, default-features = false, features = [
"tokio",
] }
codex-api = { workspace = true }
codex-client = { workspace = true }
codex-config = { workspace = true }
codex-exec-server = { workspace = true }
codex-keyring-store = { workspace = true }

View File

@@ -1,31 +1,29 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Error;
use anyhow::Result;
use codex_exec_server::ExecServerError;
use codex_exec_server::HttpClient;
use codex_exec_server::HttpHeader;
use codex_exec_server::HttpRequestParams;
use codex_exec_server::HttpRequestResponse;
use codex_exec_server::HttpResponseBodyStream;
use codex_protocol::protocol::McpAuthStatus;
use reqwest::Client;
use reqwest::StatusCode;
use reqwest::Url;
use futures::FutureExt;
use futures::future::BoxFuture;
use reqwest::Method;
use reqwest::header::AUTHORIZATION;
use reqwest::header::HeaderMap;
use serde::Deserialize;
use reqwest::header::HeaderName;
use reqwest::header::HeaderValue;
use tracing::debug;
use crate::mcp_oauth_http::OAuthHttpClient;
use crate::mcp_oauth_http::StreamableHttpOAuthDiscovery;
use crate::oauth::has_oauth_tokens;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use codex_config::types::OAuthCredentialsStoreMode;
const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);
const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version";
const OAUTH_DISCOVERY_VERSION: &str = "2024-11-05";
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamableHttpOAuthDiscovery {
pub scopes_supported: Option<Vec<String>>,
}
/// Determine the authentication status for a streamable HTTP MCP server.
pub async fn determine_streamable_http_auth_status(
server_name: &str,
@@ -34,12 +32,36 @@ pub async fn determine_streamable_http_auth_status(
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
store_mode: OAuthCredentialsStoreMode,
) -> Result<McpAuthStatus> {
determine_streamable_http_auth_status_with_client(
server_name,
url,
bearer_token_env_var,
http_headers,
env_http_headers,
store_mode,
no_proxy_http_client(),
)
.await
}
/// Determine the authentication status for a streamable HTTP MCP server using
/// the caller-selected runtime HTTP client.
#[allow(clippy::too_many_arguments)]
pub async fn determine_streamable_http_auth_status_with_client(
server_name: &str,
url: &str,
bearer_token_env_var: Option<&str>,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
store_mode: OAuthCredentialsStoreMode,
http_client: Arc<dyn HttpClient>,
) -> Result<McpAuthStatus> {
if bearer_token_env_var.is_some() {
return Ok(McpAuthStatus::BearerToken);
}
let default_headers = build_default_headers(http_headers, env_http_headers)?;
let default_headers = build_default_headers(http_headers.clone(), env_http_headers.clone())?;
if default_headers.contains_key(AUTHORIZATION) {
return Ok(McpAuthStatus::BearerToken);
}
@@ -48,7 +70,8 @@ pub async fn determine_streamable_http_auth_status(
return Ok(McpAuthStatus::OAuth);
}
match discover_streamable_http_oauth_with_headers(url, &default_headers).await {
let oauth_http = OAuthHttpClient::from_default_headers(http_client, default_headers);
match oauth_http.discover(url).await {
Ok(Some(_)) => Ok(McpAuthStatus::NotLoggedIn),
Ok(None) => Ok(McpAuthStatus::Unsupported),
Err(error) => {
@@ -74,121 +97,107 @@ pub async fn discover_streamable_http_oauth(
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
) -> Result<Option<StreamableHttpOAuthDiscovery>> {
let default_headers = build_default_headers(http_headers, env_http_headers)?;
discover_streamable_http_oauth_with_headers(url, &default_headers).await
discover_streamable_http_oauth_with_client(
url,
http_headers,
env_http_headers,
no_proxy_http_client(),
)
.await
}
async fn discover_streamable_http_oauth_with_headers(
pub(crate) fn no_proxy_http_client() -> Arc<dyn HttpClient> {
Arc::new(NoProxyReqwestHttpClient)
}
#[derive(Debug, Clone, Default)]
struct NoProxyReqwestHttpClient;
impl HttpClient for NoProxyReqwestHttpClient {
fn http_request(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, std::result::Result<HttpRequestResponse, ExecServerError>> {
async move { no_proxy_http_request(params).await }.boxed()
}
fn http_request_stream(
&self,
_params: HttpRequestParams,
) -> BoxFuture<
'_,
std::result::Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>,
> {
async move {
Err(ExecServerError::Protocol(
"streaming is not supported for OAuth discovery".to_string(),
))
}
.boxed()
}
}
async fn no_proxy_http_request(
params: HttpRequestParams,
) -> std::result::Result<HttpRequestResponse, ExecServerError> {
let method = Method::from_bytes(params.method.as_bytes())
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
let mut builder = reqwest::Client::builder().no_proxy();
if let Some(timeout_ms) = params.timeout_ms {
builder = builder.timeout(Duration::from_millis(timeout_ms));
}
let client = builder
.build()
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
let mut request = client.request(method, &params.url);
for header in params.headers {
let name = HeaderName::from_bytes(header.name.as_bytes())
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
let value = HeaderValue::from_str(&header.value)
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
request = request.header(name, value);
}
if let Some(body) = params.body {
request = request.body(body.into_inner());
}
let response = request
.send()
.await
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?;
let status = response.status().as_u16();
let headers = response
.headers()
.iter()
.filter_map(|(name, value)| {
value.to_str().ok().map(|value| HttpHeader {
name: name.to_string(),
value: value.to_string(),
})
})
.collect();
let body = response
.bytes()
.await
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))?
.to_vec();
Ok(HttpRequestResponse {
status,
headers,
body: body.into(),
})
}
pub async fn discover_streamable_http_oauth_with_client(
url: &str,
default_headers: &HeaderMap,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
http_client: Arc<dyn HttpClient>,
) -> Result<Option<StreamableHttpOAuthDiscovery>> {
let base_url = Url::parse(url)?;
// Use no_proxy to avoid a bug in the system-configuration crate that
// can result in a panic. See #8912.
let builder = Client::builder().timeout(DISCOVERY_TIMEOUT).no_proxy();
let client = apply_default_headers(builder, default_headers).build()?;
let mut last_error: Option<Error> = None;
for candidate_path in discovery_paths(base_url.path()) {
let mut discovery_url = base_url.clone();
discovery_url.set_path(&candidate_path);
let response = match client
.get(discovery_url.clone())
.header(OAUTH_DISCOVERY_HEADER, OAUTH_DISCOVERY_VERSION)
.send()
.await
{
Ok(response) => response,
Err(err) => {
last_error = Some(err.into());
continue;
}
};
if response.status() != StatusCode::OK {
continue;
}
let metadata = match response.json::<OAuthDiscoveryMetadata>().await {
Ok(metadata) => metadata,
Err(err) => {
last_error = Some(err.into());
continue;
}
};
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
return Ok(Some(StreamableHttpOAuthDiscovery {
scopes_supported: normalize_scopes(metadata.scopes_supported),
}));
}
}
if let Some(err) = last_error {
debug!("OAuth discovery requests failed for {url}: {err:?}");
}
Ok(None)
}
#[derive(Debug, Deserialize)]
struct OAuthDiscoveryMetadata {
#[serde(default)]
authorization_endpoint: Option<String>,
#[serde(default)]
token_endpoint: Option<String>,
#[serde(default)]
scopes_supported: Option<Vec<String>>,
}
fn normalize_scopes(scopes_supported: Option<Vec<String>>) -> Option<Vec<String>> {
let scopes_supported = scopes_supported?;
let mut normalized = Vec::new();
for scope in scopes_supported {
let scope = scope.trim();
if scope.is_empty() {
continue;
}
let scope = scope.to_string();
if !normalized.contains(&scope) {
normalized.push(scope);
}
}
if normalized.is_empty() {
None
} else {
Some(normalized)
}
}
/// Implements RFC 8414 section 3.1 for discovering well-known oauth endpoints.
/// This is a requirement for MCP servers to support OAuth.
/// https://datatracker.ietf.org/doc/html/rfc8414#section-3.1
/// https://github.com/modelcontextprotocol/rust-sdk/blob/main/crates/rmcp/src/transport/auth.rs#L182
fn discovery_paths(base_path: &str) -> Vec<String> {
let trimmed = base_path.trim_start_matches('/').trim_end_matches('/');
let canonical = "/.well-known/oauth-authorization-server".to_string();
if trimmed.is_empty() {
return vec![canonical];
}
let mut candidates = Vec::new();
let mut push_unique = |candidate: String| {
if !candidates.contains(&candidate) {
candidates.push(candidate);
}
};
push_unique(format!("{canonical}/{trimmed}"));
push_unique(format!("/{trimmed}/.well-known/oauth-authorization-server"));
push_unique(canonical);
candidates
let oauth_http = OAuthHttpClient::new(http_client, http_headers, env_http_headers)?;
oauth_http.discover(url).await
}
#[cfg(test)]

View File

@@ -8,6 +8,7 @@ use std::time::Duration;
use axum::Router;
use axum::body::Body;
use axum::body::Bytes;
use axum::extract::Json;
use axum::extract::State;
use axum::http::HeaderMap;
@@ -122,6 +123,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
SESSION_POST_FAILURE_CONTROL_PATH,
post(arm_session_post_failure),
)
.route("/oauth/token", post(oauth_token))
.route(
"/.well-known/oauth-authorization-server/mcp",
get({
@@ -139,6 +141,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
serde_json::to_vec(&json!({
"authorization_endpoint": format!("{metadata_base}/oauth/authorize"),
"token_endpoint": format!("{metadata_base}/oauth/token"),
"registration_endpoint": format!("{metadata_base}/oauth/register"),
"scopes_supported": [""],
})).expect("failed to serialize metadata"),
))
@@ -386,7 +389,8 @@ async fn require_bearer(
request: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
if request.uri().path().contains("/.well-known/") {
let path = request.uri().path();
if path.contains("/.well-known/") || path.starts_with("/oauth/") {
return Ok(next.run(request).await);
}
if request
@@ -400,6 +404,33 @@ async fn require_bearer(
}
}
async fn oauth_token(body: Bytes) -> Result<Response, StatusCode> {
let form = String::from_utf8(body.to_vec()).map_err(|_| StatusCode::BAD_REQUEST)?;
if !form.contains("grant_type=refresh_token") {
return Err(StatusCode::BAD_REQUEST);
}
let access_token =
std::env::var("MCP_OAUTH_ACCESS_TOKEN").unwrap_or_else(|_| "refreshed-oauth-token".into());
let refresh_token =
std::env::var("MCP_OAUTH_REFRESH_TOKEN").unwrap_or_else(|_| "refresh-token".into());
#[expect(clippy::expect_used)]
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "application/json")
.body(Body::from(
serde_json::to_vec(&json!({
"access_token": access_token,
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": refresh_token,
}))
.expect("failed to serialize token response"),
))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
async fn arm_session_post_failure(
State(state): State<SessionFailureState>,
Json(request): Json<ArmSessionPostFailureRequest>,

View File

@@ -35,6 +35,8 @@ use rmcp::transport::streamable_http_client::StreamableHttpPostResponse;
use sse_stream::Sse;
use sse_stream::SseStream;
use crate::oauth::OAuthPersistor;
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
const JSON_MIME_TYPE: &str = "application/json";
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
@@ -45,6 +47,7 @@ pub(crate) struct StreamableHttpClientAdapter {
http_client: Arc<dyn HttpClient>,
default_headers: HeaderMap,
auth_provider: Option<SharedAuthProvider>,
oauth_persistor: Option<OAuthPersistor>,
}
#[derive(Debug, thiserror::Error)]
@@ -55,6 +58,8 @@ pub(crate) enum StreamableHttpClientAdapterError {
HttpRequest(#[from] ExecServerError),
#[error("invalid HTTP header: {0}")]
Header(String),
#[error("OAuth token error: {0}")]
Auth(String),
}
impl StreamableHttpClientAdapter {
@@ -62,11 +67,13 @@ impl StreamableHttpClientAdapter {
http_client: Arc<dyn HttpClient>,
default_headers: HeaderMap,
auth_provider: Option<SharedAuthProvider>,
oauth_persistor: Option<OAuthPersistor>,
) -> Self {
Self {
http_client,
default_headers,
auth_provider,
oauth_persistor,
}
}
}
@@ -82,7 +89,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
auth_token: Option<String>,
) -> std::result::Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
let mut headers = self.default_headers.clone();
self.add_auth_headers(&mut headers);
self.add_auth_headers(&mut headers).await?;
insert_header(
&mut headers,
ACCEPT,
@@ -179,7 +186,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
auth_token: Option<String>,
) -> std::result::Result<(), StreamableHttpError<Self::Error>> {
let mut headers = self.default_headers.clone();
self.add_auth_headers(&mut headers);
self.add_auth_headers(&mut headers).await?;
if let Some(auth_token) = auth_token {
insert_header(
&mut headers,
@@ -232,7 +239,7 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
StreamableHttpError<Self::Error>,
> {
let mut headers = self.default_headers.clone();
self.add_auth_headers(&mut headers);
self.add_auth_headers(&mut headers).await?;
insert_header(
&mut headers,
ACCEPT,
@@ -308,10 +315,27 @@ impl StreamableHttpClient for StreamableHttpClientAdapter {
}
impl StreamableHttpClientAdapter {
fn add_auth_headers(&self, headers: &mut HeaderMap) {
async fn add_auth_headers(
&self,
headers: &mut HeaderMap,
) -> std::result::Result<(), StreamableHttpError<StreamableHttpClientAdapterError>> {
if let Some(auth_provider) = &self.auth_provider {
headers.extend(auth_provider.to_auth_headers());
}
if !headers.contains_key(AUTHORIZATION)
&& let Some(oauth_persistor) = &self.oauth_persistor
&& let Some(access_token) = oauth_persistor.access_token().await.map_err(|err| {
StreamableHttpError::Client(StreamableHttpClientAdapterError::Auth(err.to_string()))
})?
{
insert_header(
headers,
AUTHORIZATION,
format!("Bearer {access_token}"),
StreamableHttpClientAdapterError::Header,
)?;
}
Ok(())
}
}

View File

@@ -3,6 +3,7 @@ mod elicitation_client_service;
mod executor_process_transport;
mod http_client_adapter;
mod logging_client_handler;
mod mcp_oauth_http;
mod oauth;
mod perform_oauth_login;
mod program_resolver;
@@ -10,11 +11,13 @@ mod rmcp_client;
mod stdio_server_launcher;
mod utils;
pub use auth_status::StreamableHttpOAuthDiscovery;
pub use auth_status::determine_streamable_http_auth_status;
pub use auth_status::determine_streamable_http_auth_status_with_client;
pub use auth_status::discover_streamable_http_oauth;
pub use auth_status::discover_streamable_http_oauth_with_client;
pub use auth_status::supports_oauth_login;
pub use codex_protocol::protocol::McpAuthStatus;
pub use mcp_oauth_http::StreamableHttpOAuthDiscovery;
pub use oauth::StoredOAuthTokens;
pub use oauth::WrappedOAuthTokenResponse;
pub use oauth::delete_oauth_tokens;
@@ -24,7 +27,10 @@ pub use perform_oauth_login::OAuthProviderError;
pub use perform_oauth_login::OauthLoginHandle;
pub use perform_oauth_login::perform_oauth_login;
pub use perform_oauth_login::perform_oauth_login_return_url;
pub use perform_oauth_login::perform_oauth_login_return_url_with_client;
pub use perform_oauth_login::perform_oauth_login_silent;
pub use perform_oauth_login::perform_oauth_login_silent_with_client;
pub use perform_oauth_login::perform_oauth_login_with_client;
pub use rmcp::model::ElicitationAction;
pub use rmcp_client::Elicitation;
pub use rmcp_client::ElicitationResponse;

View File

@@ -0,0 +1,690 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use codex_exec_server::HttpClient;
use codex_exec_server::HttpHeader;
use codex_exec_server::HttpRequestParams;
use oauth2::AuthUrl;
use oauth2::AuthorizationCode;
use oauth2::ClientId;
use oauth2::ClientSecret;
use oauth2::CsrfToken;
use oauth2::EmptyExtraTokenFields;
use oauth2::PkceCodeChallenge;
use oauth2::PkceCodeVerifier;
use oauth2::RedirectUrl;
use oauth2::RefreshToken;
use oauth2::RequestTokenError;
use oauth2::Scope;
use oauth2::StandardErrorResponse;
use oauth2::StandardTokenResponse;
use oauth2::TokenResponse;
use oauth2::TokenUrl;
use oauth2::basic::BasicClient;
use oauth2::basic::BasicErrorResponseType;
use oauth2::basic::BasicTokenType;
use reqwest::StatusCode;
use reqwest::Url;
use reqwest::header::HeaderMap;
use serde::Deserialize;
use serde::Serialize;
use tracing::debug;
use crate::WrappedOAuthTokenResponse;
use crate::oauth::StoredOAuthTokens;
use crate::oauth::compute_expires_at_millis;
use crate::utils::build_default_headers;
const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);
const OAUTH_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version";
const OAUTH_DISCOVERY_VERSION: &str = "2024-11-05";
type OAuthTokenResponse = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
type OAuthErrorResponse = StandardErrorResponse<BasicErrorResponseType>;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamableHttpOAuthDiscovery {
pub scopes_supported: Option<Vec<String>>,
}
#[derive(Clone)]
pub(crate) struct OAuthHttpClient {
http_client: Arc<dyn HttpClient>,
default_headers: HeaderMap,
}
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub(crate) struct OAuthHttpError(String);
#[derive(Debug)]
pub(crate) struct OAuthAuthorizationSession {
pub authorization_url: String,
pub csrf_state: CsrfToken,
pkce_verifier: PkceCodeVerifier,
client: BasicClient<
oauth2::EndpointSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointSet,
>,
client_id: String,
}
#[derive(Debug, Clone)]
struct OAuthClientConfig {
client_id: String,
client_secret: Option<String>,
}
#[derive(Debug)]
enum DiscoveredOAuthMetadata {
AuthorizationServer(OAuthDiscoveryMetadata),
ProtectedResource {
metadata_url: Url,
authorization_servers: Vec<String>,
},
}
impl OAuthHttpClient {
pub(crate) fn new(
http_client: Arc<dyn HttpClient>,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
) -> Result<Self> {
let default_headers = build_default_headers(http_headers, env_http_headers)?;
Ok(Self {
http_client,
default_headers,
})
}
pub(crate) fn from_default_headers(
http_client: Arc<dyn HttpClient>,
default_headers: HeaderMap,
) -> Self {
Self {
http_client,
default_headers,
}
}
pub(crate) async fn discover(&self, url: &str) -> Result<Option<StreamableHttpOAuthDiscovery>> {
let metadata = match self.discover_metadata(url).await? {
Some(metadata) => metadata,
None => return Ok(None),
};
Ok(Some(StreamableHttpOAuthDiscovery {
scopes_supported: normalize_scopes(metadata.scopes_supported),
}))
}
async fn discover_metadata(&self, url: &str) -> Result<Option<OAuthDiscoveryMetadata>> {
match self.discover_metadata_from_paths(url).await? {
Some(DiscoveredOAuthMetadata::AuthorizationServer(metadata)) => Ok(Some(metadata)),
Some(DiscoveredOAuthMetadata::ProtectedResource {
metadata_url,
authorization_servers,
}) => {
for authorization_server in authorization_servers {
let authorization_server = authorization_server.trim();
if authorization_server.is_empty() {
continue;
}
let authorization_server_url = match Url::parse(authorization_server) {
Ok(url) => url,
Err(_) => match metadata_url.join(authorization_server) {
Ok(url) => url,
Err(err) => {
debug!(
"failed to resolve OAuth authorization server URL `{authorization_server}`: {err}"
);
continue;
}
},
};
let discovered = if authorization_server_url.path().contains("/.well-known/") {
self.fetch_discovery_metadata(&authorization_server_url)
.await?
} else {
self.discover_metadata_from_paths(authorization_server_url.as_str())
.await?
};
if let Some(DiscoveredOAuthMetadata::AuthorizationServer(metadata)) = discovered
{
return Ok(Some(metadata));
}
}
Ok(None)
}
None => Ok(None),
}
}
async fn discover_metadata_from_paths(
&self,
url: &str,
) -> Result<Option<DiscoveredOAuthMetadata>> {
let base_url = Url::parse(url)?;
let mut last_error: Option<anyhow::Error> = None;
for candidate_path in discovery_paths(base_url.path()) {
let mut discovery_url = base_url.clone();
discovery_url.set_query(None);
discovery_url.set_fragment(None);
discovery_url.set_path(&candidate_path);
match self.fetch_discovery_metadata(&discovery_url).await {
Ok(Some(metadata)) => return Ok(Some(metadata)),
Ok(None) => {}
Err(err) => {
last_error = Some(err);
continue;
}
}
}
if let Some(err) = last_error {
debug!("OAuth discovery requests failed for {url}: {err:?}");
}
Ok(None)
}
async fn fetch_discovery_metadata(
&self,
discovery_url: &Url,
) -> Result<Option<DiscoveredOAuthMetadata>> {
let response = self
.request(
"GET",
discovery_url.as_str(),
vec![HttpHeader {
name: OAUTH_DISCOVERY_HEADER.to_string(),
value: OAUTH_DISCOVERY_VERSION.to_string(),
}],
/*body*/ None,
Some(DISCOVERY_TIMEOUT),
)
.await?;
if response.status != StatusCode::OK.as_u16() {
return Ok(None);
}
let metadata: OAuthDiscoveryMetadata = serde_json::from_slice(&response.body.0)?;
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
return Ok(Some(DiscoveredOAuthMetadata::AuthorizationServer(metadata)));
}
let authorization_servers = {
let mut authorization_servers = Vec::new();
let mut push_unique = |authorization_server: &String| {
let authorization_server = authorization_server.trim();
if !authorization_server.is_empty()
&& !authorization_servers
.iter()
.any(|existing| existing == authorization_server)
{
authorization_servers.push(authorization_server.to_string());
}
};
if let Some(authorization_server) = metadata.authorization_server.as_ref() {
push_unique(authorization_server);
}
if let Some(list) = metadata.authorization_servers.as_ref() {
for authorization_server in list {
push_unique(authorization_server);
}
}
authorization_servers
};
if authorization_servers.is_empty() {
Ok(None)
} else {
Ok(Some(DiscoveredOAuthMetadata::ProtectedResource {
metadata_url: discovery_url.clone(),
authorization_servers,
}))
}
}
pub(crate) async fn start_authorization(
&self,
server_url: &str,
scopes: &[String],
redirect_uri: &str,
client_name: &str,
) -> Result<OAuthAuthorizationSession> {
let metadata = self
.discover_metadata(server_url)
.await?
.ok_or_else(|| anyhow!("No authorization support detected"))?;
let client_config = self
.register_client(&metadata, client_name, redirect_uri)
.await?;
let client = oauth_client(&metadata, &client_config, redirect_uri)?;
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let mut auth_request = client
.authorize_url(CsrfToken::new_random)
.set_pkce_challenge(pkce_challenge);
for scope in scopes {
auth_request = auth_request.add_scope(Scope::new(scope.clone()));
}
let (authorization_url, csrf_state) = auth_request.url();
Ok(OAuthAuthorizationSession {
authorization_url: authorization_url.to_string(),
csrf_state,
pkce_verifier,
client,
client_id: client_config.client_id,
})
}
pub(crate) async fn exchange_code(
&self,
session: OAuthAuthorizationSession,
code: &str,
csrf_state: &str,
) -> Result<(String, OAuthTokenResponse)> {
if session.csrf_state.secret() != csrf_state {
return Err(anyhow!(
"OAuth callback state did not match authorization request"
));
}
let http_client = self.clone();
let token = session
.client
.exchange_code(AuthorizationCode::new(code.to_string()))
.set_pkce_verifier(session.pkce_verifier)
.request_async(&|request| {
let http_client = http_client.clone();
async move { http_client.oauth_request(request).await }
})
.await
.or_else(parse_token_from_parse_error)?;
Ok((session.client_id, token))
}
pub(crate) async fn refresh_token(
&self,
tokens: &StoredOAuthTokens,
) -> Result<Option<StoredOAuthTokens>> {
let refresh_token = match tokens.token_response.0.refresh_token() {
Some(refresh_token) => refresh_token.secret().to_string(),
None => return Ok(None),
};
let metadata = self
.discover_metadata(&tokens.url)
.await?
.ok_or_else(|| anyhow!("No authorization support detected"))?;
let client_config = OAuthClientConfig {
client_id: tokens.client_id.clone(),
client_secret: None,
};
let client = oauth_client(&metadata, &client_config, &tokens.url)?;
let http_client = self.clone();
let token = client
.exchange_refresh_token(&RefreshToken::new(refresh_token))
.request_async(&|request| {
let http_client = http_client.clone();
async move { http_client.oauth_request(request).await }
})
.await
.or_else(parse_token_from_parse_error)?;
let expires_at = compute_expires_at_millis(&token);
Ok(Some(StoredOAuthTokens {
server_name: tokens.server_name.clone(),
url: tokens.url.clone(),
client_id: tokens.client_id.clone(),
token_response: WrappedOAuthTokenResponse(token),
expires_at,
}))
}
async fn register_client(
&self,
metadata: &OAuthDiscoveryMetadata,
client_name: &str,
redirect_uri: &str,
) -> Result<OAuthClientConfig> {
let registration_url = metadata
.registration_endpoint
.as_deref()
.ok_or_else(|| anyhow!("Dynamic client registration not supported"))?;
if let Some(response_types_supported) = metadata.response_types_supported.as_ref()
&& !response_types_supported.iter().any(|value| value == "code")
{
return Err(anyhow!(
"OAuth server does not support authorization code flow"
));
}
let body = serde_json::to_vec(&ClientRegistrationRequest {
client_name: client_name.to_string(),
redirect_uris: vec![redirect_uri.to_string()],
grant_types: vec![
"authorization_code".to_string(),
"refresh_token".to_string(),
],
token_endpoint_auth_method: "none".to_string(),
response_types: vec!["code".to_string()],
})?;
let response = self
.request(
"POST",
registration_url,
vec![HttpHeader {
name: reqwest::header::CONTENT_TYPE.to_string(),
value: "application/json".to_string(),
}],
Some(body),
Some(OAUTH_HTTP_TIMEOUT),
)
.await?;
if !status_is_success(response.status) {
return Err(anyhow!(
"Dynamic registration failed: HTTP {}: {}",
response.status,
String::from_utf8_lossy(&response.body.0)
));
}
let registration: ClientRegistrationResponse = serde_json::from_slice(&response.body.0)
.context("failed to parse registration response")?;
Ok(OAuthClientConfig {
client_id: registration.client_id,
client_secret: registration
.client_secret
.filter(|secret| !secret.is_empty()),
})
}
async fn oauth_request(
&self,
request: oauth2::HttpRequest,
) -> std::result::Result<oauth2::HttpResponse, OAuthHttpError> {
let (parts, body) = request.into_parts();
let headers = parts
.headers
.iter()
.map(|(name, value)| {
let value = value
.to_str()
.map_err(|err| OAuthHttpError(format!("invalid OAuth header value: {err}")))?;
Ok(HttpHeader {
name: name.to_string(),
value: value.to_string(),
})
})
.collect::<std::result::Result<Vec<_>, OAuthHttpError>>()?;
let response = self
.request(
parts.method.as_str(),
parts.uri.to_string().as_str(),
headers,
Some(body),
Some(OAUTH_HTTP_TIMEOUT),
)
.await
.map_err(|err| OAuthHttpError(err.to_string()))?;
let mut oauth_response = oauth2::HttpResponse::new(response.body.0);
*oauth_response.status_mut() = oauth2::http::StatusCode::from_u16(response.status)
.map_err(|err| OAuthHttpError(format!("invalid OAuth response status: {err}")))?;
for header in response.headers {
let name = oauth2::http::HeaderName::from_bytes(header.name.as_bytes())
.map_err(|err| OAuthHttpError(format!("invalid OAuth response header: {err}")))?;
let value = oauth2::http::HeaderValue::from_str(&header.value).map_err(|err| {
OAuthHttpError(format!("invalid OAuth response header value: {err}"))
})?;
oauth_response.headers_mut().append(name, value);
}
Ok(oauth_response)
}
async fn request(
&self,
method: &str,
url: &str,
extra_headers: Vec<HttpHeader>,
body: Option<Vec<u8>>,
timeout: Option<Duration>,
) -> Result<codex_exec_server::HttpRequestResponse> {
let mut headers = protocol_headers(&self.default_headers)?;
headers.extend(extra_headers);
self.http_client
.http_request(HttpRequestParams {
method: method.to_string(),
url: url.to_string(),
headers,
body: body.map(Into::into),
timeout_ms: timeout
.map(|timeout| timeout.as_millis().clamp(1, u64::MAX as u128) as u64),
request_id: "oauth-request".to_string(),
stream_response: false,
})
.await
.map_err(|err| anyhow!(err))
}
}
#[derive(Debug, Deserialize)]
struct OAuthDiscoveryMetadata {
#[serde(default)]
authorization_endpoint: Option<String>,
#[serde(default)]
token_endpoint: Option<String>,
#[serde(default)]
registration_endpoint: Option<String>,
#[serde(default)]
scopes_supported: Option<Vec<String>>,
#[serde(default)]
response_types_supported: Option<Vec<String>>,
#[serde(default)]
authorization_server: Option<String>,
#[serde(default)]
authorization_servers: Option<Vec<String>>,
}
#[derive(Debug, Serialize)]
struct ClientRegistrationRequest {
client_name: String,
redirect_uris: Vec<String>,
grant_types: Vec<String>,
token_endpoint_auth_method: String,
response_types: Vec<String>,
}
#[derive(Debug, Deserialize)]
struct ClientRegistrationResponse {
client_id: String,
client_secret: Option<String>,
}
fn oauth_client(
metadata: &OAuthDiscoveryMetadata,
config: &OAuthClientConfig,
redirect_uri: &str,
) -> Result<
BasicClient<
oauth2::EndpointSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointSet,
>,
> {
let authorization_endpoint = metadata
.authorization_endpoint
.clone()
.ok_or_else(|| anyhow!("OAuth metadata did not include authorization endpoint"))?;
let token_endpoint = metadata
.token_endpoint
.clone()
.ok_or_else(|| anyhow!("OAuth metadata did not include token endpoint"))?;
let mut client = BasicClient::new(ClientId::new(config.client_id.clone()))
.set_auth_uri(AuthUrl::new(authorization_endpoint)?)
.set_token_uri(TokenUrl::new(token_endpoint)?)
.set_redirect_uri(RedirectUrl::new(redirect_uri.to_string())?);
if let Some(secret) = config.client_secret.clone() {
client = client.set_client_secret(ClientSecret::new(secret));
}
Ok(client)
}
fn parse_token_from_parse_error(
error: RequestTokenError<OAuthHttpError, OAuthErrorResponse>,
) -> std::result::Result<OAuthTokenResponse, RequestTokenError<OAuthHttpError, OAuthErrorResponse>>
{
match error {
RequestTokenError::Parse(parse_error, body) => {
match serde_json::from_slice::<OAuthTokenResponse>(&body) {
Ok(parsed) => Ok(parsed),
Err(_) => Err(RequestTokenError::Parse(parse_error, body)),
}
}
error => Err(error),
}
}
fn protocol_headers(headers: &HeaderMap) -> Result<Vec<HttpHeader>> {
headers
.iter()
.map(|(name, value)| {
let value = value
.to_str()
.with_context(|| format!("invalid HTTP header value for `{name}`"))?;
Ok(HttpHeader {
name: name.to_string(),
value: value.to_string(),
})
})
.collect()
}
fn status_is_success(status: u16) -> bool {
(200..300).contains(&status)
}
pub(crate) fn normalize_scopes(scopes_supported: Option<Vec<String>>) -> Option<Vec<String>> {
let scopes_supported = scopes_supported?;
let mut normalized = Vec::new();
for scope in scopes_supported {
let scope = scope.trim();
if scope.is_empty() {
continue;
}
let scope = scope.to_string();
if !normalized.contains(&scope) {
normalized.push(scope);
}
}
if normalized.is_empty() {
None
} else {
Some(normalized)
}
}
/// Implements RFC 8414 section 3.1 and RFC 9728 section 3 for discovering
/// well-known OAuth endpoints.
pub(crate) fn discovery_paths(base_path: &str) -> Vec<String> {
let trimmed = base_path.trim_start_matches('/').trim_end_matches('/');
let mut candidates = Vec::new();
let mut push_unique = |candidate: String| {
if !candidates.contains(&candidate) {
candidates.push(candidate);
}
};
let mut push_well_known = |well_known: &str| {
let canonical = format!("/.well-known/{well_known}");
if trimmed.is_empty() {
push_unique(canonical);
} else {
push_unique(format!("{canonical}/{trimmed}"));
push_unique(format!("/{trimmed}/.well-known/{well_known}"));
push_unique(canonical);
}
};
push_well_known("oauth-authorization-server");
push_well_known("openid-configuration");
push_well_known("oauth-protected-resource");
candidates
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn discovery_paths_prefer_rfc_8414_resource_path() {
assert_eq!(
discovery_paths("/mcp"),
vec![
"/.well-known/oauth-authorization-server/mcp".to_string(),
"/mcp/.well-known/oauth-authorization-server".to_string(),
"/.well-known/oauth-authorization-server".to_string(),
"/.well-known/openid-configuration/mcp".to_string(),
"/mcp/.well-known/openid-configuration".to_string(),
"/.well-known/openid-configuration".to_string(),
"/.well-known/oauth-protected-resource/mcp".to_string(),
"/mcp/.well-known/oauth-protected-resource".to_string(),
"/.well-known/oauth-protected-resource".to_string(),
]
);
}
#[test]
fn discovery_paths_deduplicate_root_path() {
assert_eq!(
discovery_paths("/"),
vec![
"/.well-known/oauth-authorization-server".to_string(),
"/.well-known/openid-configuration".to_string(),
"/.well-known/oauth-protected-resource".to_string(),
]
);
}
#[test]
fn normalize_scopes_trims_empties_and_deduplicates() {
assert_eq!(
normalize_scopes(Some(vec![
"read".to_string(),
" write ".to_string(),
"".to_string(),
"read".to_string(),
])),
Some(vec!["read".to_string(), "write".to_string()])
);
}
}

View File

@@ -45,9 +45,10 @@ use tracing::warn;
use codex_keyring_store::DefaultKeyringStore;
use codex_keyring_store::KeyringStore;
use rmcp::transport::auth::AuthorizationManager;
use tokio::sync::Mutex;
use tokio::sync::Semaphore;
use crate::mcp_oauth_http::OAuthHttpClient;
use codex_utils_home_dir::find_codex_home;
const KEYRING_SERVICE: &str = "Codex MCP Credentials";
@@ -94,6 +95,43 @@ pub(crate) fn load_oauth_tokens(
}
}
enum OAuthTokensStorageState {
Found(StoredOAuthTokens),
Missing,
Unavailable,
}
fn load_oauth_tokens_for_persist(
server_name: &str,
url: &str,
store_mode: OAuthCredentialsStoreMode,
) -> Result<OAuthTokensStorageState> {
let keyring_store = DefaultKeyringStore;
match store_mode {
OAuthCredentialsStoreMode::Auto => {
load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file(
&keyring_store,
server_name,
url,
)
}
OAuthCredentialsStoreMode::File => {
Ok(match load_oauth_tokens_from_file(server_name, url)? {
Some(tokens) => OAuthTokensStorageState::Found(tokens),
None => OAuthTokensStorageState::Missing,
})
}
OAuthCredentialsStoreMode::Keyring => Ok(
match load_oauth_tokens_from_keyring(&keyring_store, server_name, url)
.with_context(|| "failed to read OAuth tokens from keyring".to_string())?
{
Some(tokens) => OAuthTokensStorageState::Found(tokens),
None => OAuthTokensStorageState::Missing,
},
),
}
}
pub(crate) fn has_oauth_tokens(
server_name: &str,
url: &str,
@@ -134,6 +172,48 @@ fn load_oauth_tokens_from_keyring_with_fallback_to_file<K: KeyringStore>(
}
}
fn stored_tokens_match_without_expires_in(
actual: &StoredOAuthTokens,
expected: &StoredOAuthTokens,
) -> bool {
let actual_response = &actual.token_response.0;
let expected_response = &expected.token_response.0;
actual.server_name == expected.server_name
&& actual.url == expected.url
&& actual.client_id == expected.client_id
&& actual.expires_at == expected.expires_at
&& actual_response.access_token().secret() == expected_response.access_token().secret()
&& actual_response.token_type() == expected_response.token_type()
&& actual_response.refresh_token().map(RefreshToken::secret)
== expected_response.refresh_token().map(RefreshToken::secret)
&& actual_response.scopes() == expected_response.scopes()
&& actual_response.extra_fields() == expected_response.extra_fields()
}
fn load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file<K: KeyringStore>(
keyring_store: &K,
server_name: &str,
url: &str,
) -> Result<OAuthTokensStorageState> {
match load_oauth_tokens_from_keyring(keyring_store, server_name, url) {
Ok(Some(tokens)) => Ok(OAuthTokensStorageState::Found(tokens)),
Ok(None) => Ok(match load_oauth_tokens_from_file(server_name, url)? {
Some(tokens) => OAuthTokensStorageState::Found(tokens),
None => OAuthTokensStorageState::Missing,
}),
Err(error) => {
warn!("failed to read OAuth tokens from keyring: {error}");
match load_oauth_tokens_from_file(server_name, url)
.with_context(|| format!("failed to read OAuth tokens from keyring: {error}"))?
{
Some(tokens) => Ok(OAuthTokensStorageState::Found(tokens)),
None => Ok(OAuthTokensStorageState::Unavailable),
}
}
}
}
fn load_oauth_tokens_from_keyring<K: KeyringStore>(
keyring_store: &K,
server_name: &str,
@@ -256,116 +336,161 @@ pub(crate) struct OAuthPersistor {
struct OAuthPersistorInner {
server_name: String,
url: String,
authorization_manager: Arc<Mutex<AuthorizationManager>>,
oauth_http: OAuthHttpClient,
store_mode: OAuthCredentialsStoreMode,
last_credentials: Mutex<Option<StoredOAuthTokens>>,
credentials_state: Mutex<OAuthCredentialsState>,
refresh_gate: Semaphore,
}
struct OAuthCredentialsState {
current: Option<StoredOAuthTokens>,
last_persisted: Option<StoredOAuthTokens>,
}
impl OAuthPersistor {
pub(crate) fn new(
server_name: String,
url: String,
authorization_manager: Arc<Mutex<AuthorizationManager>>,
oauth_http: OAuthHttpClient,
store_mode: OAuthCredentialsStoreMode,
initial_credentials: Option<StoredOAuthTokens>,
) -> Self {
Self {
inner: Arc::new(OAuthPersistorInner {
server_name,
url,
authorization_manager,
oauth_http,
store_mode,
last_credentials: Mutex::new(initial_credentials),
credentials_state: Mutex::new(OAuthCredentialsState {
current: initial_credentials.clone(),
last_persisted: initial_credentials,
}),
refresh_gate: Semaphore::new(1),
}),
}
}
/// Persists the latest stored credentials if they have changed.
/// Deletes the credentials if they are no longer present.
#[expect(
clippy::await_holding_invalid_type,
reason = "AuthorizationManager async access must be serialized through its mutex"
)]
pub(crate) async fn persist_if_needed(&self) -> Result<()> {
let (client_id, maybe_credentials) = {
let manager = self.inner.authorization_manager.clone();
let guard = manager.lock().await;
guard.get_credentials().await
}?;
match maybe_credentials {
Some(credentials) => {
let mut last_credentials = self.inner.last_credentials.lock().await;
let new_token_response = WrappedOAuthTokenResponse(credentials.clone());
let same_token = last_credentials
.as_ref()
.map(|prev| prev.token_response == new_token_response)
.unwrap_or(false);
let expires_at = if same_token {
last_credentials.as_ref().and_then(|prev| prev.expires_at)
} else {
compute_expires_at_millis(&credentials)
let mut credentials_state = self.inner.credentials_state.lock().await;
let Some(stored) = credentials_state.current.clone() else {
return Ok(());
};
match load_oauth_tokens_for_persist(
&self.inner.server_name,
&stored.url,
self.inner.store_mode,
)? {
OAuthTokensStorageState::Missing => {
credentials_state.current = None;
credentials_state.last_persisted = None;
}
OAuthTokensStorageState::Unavailable => {
let should_save = match credentials_state.last_persisted.as_ref() {
Some(last_persisted) => {
!stored_tokens_match_without_expires_in(&stored, last_persisted)
}
None => true,
};
let stored = StoredOAuthTokens {
server_name: self.inner.server_name.clone(),
url: self.inner.url.clone(),
client_id,
token_response: new_token_response,
expires_at,
};
if last_credentials.as_ref() != Some(&stored) {
if should_save {
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
*last_credentials = Some(stored);
credentials_state.last_persisted = Some(stored);
}
}
None => {
let mut last_serialized = self.inner.last_credentials.lock().await;
if last_serialized.take().is_some()
&& let Err(error) = delete_oauth_tokens(
&self.inner.server_name,
&self.inner.url,
self.inner.store_mode,
)
{
warn!(
"failed to remove OAuth tokens for server {}: {error}",
self.inner.server_name
);
OAuthTokensStorageState::Found(current_store)
if stored_tokens_match_without_expires_in(&current_store, &stored) =>
{
credentials_state.last_persisted = Some(current_store);
}
OAuthTokensStorageState::Found(current_store) => {
let store_matches_last_persisted = credentials_state
.last_persisted
.as_ref()
.is_some_and(|last_persisted| {
stored_tokens_match_without_expires_in(&current_store, last_persisted)
});
if store_matches_last_persisted {
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
credentials_state.last_persisted = Some(stored);
} else {
credentials_state.current = Some(current_store.clone());
credentials_state.last_persisted = Some(current_store);
}
}
}
Ok(())
}
#[expect(
clippy::await_holding_invalid_type,
reason = "AuthorizationManager async access must be serialized through its mutex"
)]
pub(crate) async fn refresh_if_needed(&self) -> Result<()> {
let expires_at = {
let guard = self.inner.last_credentials.lock().await;
guard.as_ref().and_then(|tokens| tokens.expires_at)
let _permit = self
.inner
.refresh_gate
.acquire()
.await
.context("OAuth refresh gate was closed")?;
let tokens = {
let state = self.inner.credentials_state.lock().await;
let Some(tokens) = state.current.as_ref() else {
return Ok(());
};
if !token_needs_refresh(tokens.expires_at) {
return Ok(());
}
tokens.clone()
};
if !token_needs_refresh(expires_at) {
return Ok(());
}
{
let manager = self.inner.authorization_manager.clone();
let guard = manager.lock().await;
guard.refresh_token().await.with_context(|| {
if let Some(refreshed) = self
.inner
.oauth_http
.refresh_token(&tokens)
.await
.with_context(|| {
format!(
"failed to refresh OAuth tokens for server {}",
self.inner.server_name
)
})?;
})?
{
let mut state = self.inner.credentials_state.lock().await;
if state.current.as_ref() == Some(&tokens) {
state.current = Some(refreshed);
}
}
self.persist_if_needed().await
}
pub(crate) async fn access_token(&self) -> Result<Option<String>> {
let cached = {
let state = self.inner.credentials_state.lock().await;
state.current.as_ref().map(|tokens| {
(
tokens.token_response.0.access_token().secret().to_string(),
tokens.expires_at,
)
})
};
if let Err(err) = self.refresh_if_needed().await {
if let Some((access_token, expires_at)) = cached
&& expires_in_from_timestamp(expires_at.unwrap_or(u64::MAX)).is_some()
{
warn!(
"failed to refresh OAuth tokens for server {}; using cached access token: {err:#}",
self.inner.server_name
);
return Ok(Some(access_token));
}
return Err(err);
}
let state = self.inner.credentials_state.lock().await;
let Some(tokens) = state.current.as_ref() else {
return Ok(None);
};
if expires_in_from_timestamp(tokens.expires_at.unwrap_or(u64::MAX)).is_none() {
return Ok(None);
}
Ok(Some(
tokens.token_response.0.access_token().secret().to_string(),
))
}
}
const FALLBACK_FILENAME: &str = ".credentials.json";
@@ -595,6 +720,12 @@ fn sha_256_prefix(value: &Value) -> Result<String> {
mod tests {
use super::*;
use anyhow::Result;
use codex_exec_server::ExecServerError;
use codex_exec_server::HttpClient;
use codex_exec_server::HttpRequestParams;
use codex_exec_server::HttpRequestResponse;
use codex_exec_server::HttpResponseBodyStream;
use futures::future::BoxFuture;
use keyring::Error as KeyringError;
use pretty_assertions::assert_eq;
use std::sync::Mutex;
@@ -693,6 +824,24 @@ mod tests {
Ok(())
}
#[test]
fn load_oauth_tokens_for_persist_keeps_keyring_errors_distinct() -> Result<()> {
let _env = TempCodexHome::new();
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
let state = super::load_oauth_tokens_for_persist_from_keyring_with_fallback_to_file(
&store,
&tokens.server_name,
&tokens.url,
)?;
assert!(matches!(state, super::OAuthTokensStorageState::Unavailable));
Ok(())
}
#[test]
fn save_oauth_tokens_prefers_keyring_when_available() -> Result<()> {
let _env = TempCodexHome::new();
@@ -846,6 +995,197 @@ mod tests {
assert!(tokens.token_response.0.expires_in().is_none());
}
#[tokio::test]
async fn persistor_does_not_recreate_deleted_file_credentials() -> Result<()> {
let _env = TempCodexHome::new();
let tokens = sample_tokens();
super::save_oauth_tokens_to_file(&tokens)?;
assert!(super::fallback_file_path()?.exists());
super::delete_oauth_tokens(
&tokens.server_name,
&tokens.url,
OAuthCredentialsStoreMode::File,
)?;
let persistor = OAuthPersistor::new(
tokens.server_name.clone(),
OAuthHttpClient::new(
Arc::new(FailingHttpClient),
/*http_headers*/ None,
/*env_http_headers*/ None,
)?,
OAuthCredentialsStoreMode::File,
Some(tokens.clone()),
);
persistor.persist_if_needed().await?;
assert!(
super::load_oauth_tokens(
&tokens.server_name,
&tokens.url,
OAuthCredentialsStoreMode::File
)?
.is_none()
);
Ok(())
}
#[tokio::test]
async fn persistor_does_not_clobber_newer_file_credentials() -> Result<()> {
let _env = TempCodexHome::new();
let tokens = sample_tokens();
super::save_oauth_tokens_to_file(&tokens)?;
let persistor = OAuthPersistor::new(
tokens.server_name.clone(),
OAuthHttpClient::new(
Arc::new(FailingHttpClient),
/*http_headers*/ None,
/*env_http_headers*/ None,
)?,
OAuthCredentialsStoreMode::File,
Some(tokens.clone()),
);
let newer_tokens = sample_tokens_with_access_token("newer-access-token");
super::save_oauth_tokens_to_file(&newer_tokens)?;
persistor.persist_if_needed().await?;
let loaded = super::load_oauth_tokens(
&tokens.server_name,
&tokens.url,
OAuthCredentialsStoreMode::File,
)?
.expect("newer tokens should remain stored");
assert_tokens_match_without_expiry(&loaded, &newer_tokens);
Ok(())
}
#[tokio::test]
async fn persistor_saves_newer_memory_credentials_despite_expires_in_drift() -> Result<()> {
let _env = TempCodexHome::new();
let mut stored_tokens = sample_tokens();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0));
stored_tokens.expires_at = Some(now.as_millis() as u64 + 10_000);
let expires_in = Duration::from_secs(3600);
stored_tokens
.token_response
.0
.set_expires_in(Some(&expires_in));
super::save_oauth_tokens_to_file(&stored_tokens)?;
let persistor = OAuthPersistor::new(
stored_tokens.server_name.clone(),
OAuthHttpClient::new(
Arc::new(FailingHttpClient),
/*http_headers*/ None,
/*env_http_headers*/ None,
)?,
OAuthCredentialsStoreMode::File,
Some(stored_tokens.clone()),
);
let newer_tokens = sample_tokens_with_access_token("newer-access-token");
{
let mut state = persistor.inner.credentials_state.lock().await;
state.current = Some(newer_tokens.clone());
state.last_persisted = Some(stored_tokens);
}
persistor.persist_if_needed().await?;
let loaded = super::load_oauth_tokens(
&newer_tokens.server_name,
&newer_tokens.url,
OAuthCredentialsStoreMode::File,
)?
.expect("newer tokens should be stored");
assert_tokens_match_without_expiry(&loaded, &newer_tokens);
Ok(())
}
#[tokio::test]
async fn access_token_uses_cached_token_when_refresh_fails_before_expiry() -> Result<()> {
let mut tokens = sample_tokens();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0));
tokens.expires_at = Some(now.as_millis() as u64 + 10_000);
let persistor = OAuthPersistor::new(
tokens.server_name.clone(),
OAuthHttpClient::new(
Arc::new(FailingHttpClient),
/*http_headers*/ None,
/*env_http_headers*/ None,
)?,
OAuthCredentialsStoreMode::File,
Some(tokens),
);
assert_eq!(
persistor.access_token().await?,
Some("access-token".to_string())
);
Ok(())
}
#[tokio::test]
async fn access_token_rejects_expired_token_without_refresh_token() -> Result<()> {
let mut tokens = sample_tokens();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0));
tokens.expires_at = Some((now.as_millis() as u64).saturating_sub(1000));
tokens.token_response.0.set_refresh_token(None);
let persistor = OAuthPersistor::new(
tokens.server_name.clone(),
OAuthHttpClient::new(
Arc::new(FailingHttpClient),
/*http_headers*/ None,
/*env_http_headers*/ None,
)?,
OAuthCredentialsStoreMode::File,
Some(tokens),
);
assert_eq!(persistor.access_token().await?, None);
Ok(())
}
struct FailingHttpClient;
impl HttpClient for FailingHttpClient {
fn http_request(
&self,
_params: HttpRequestParams,
) -> BoxFuture<'_, std::result::Result<HttpRequestResponse, ExecServerError>> {
Box::pin(async {
Err(ExecServerError::Disconnected(
"forced refresh failure".to_string(),
))
})
}
fn http_request_stream(
&self,
_params: HttpRequestParams,
) -> BoxFuture<
'_,
std::result::Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>,
> {
Box::pin(async {
Err(ExecServerError::Disconnected(
"forced refresh failure".to_string(),
))
})
}
}
fn assert_tokens_match_without_expiry(
actual: &StoredOAuthTokens,
expected: &StoredOAuthTokens,
@@ -888,8 +1228,12 @@ mod tests {
}
fn sample_tokens() -> StoredOAuthTokens {
sample_tokens_with_access_token("access-token")
}
fn sample_tokens_with_access_token(access_token: &str) -> StoredOAuthTokens {
let mut response = OAuthTokenResponse::new(
AccessToken::new("access-token".to_string()),
AccessToken::new(access_token.to_string()),
BasicTokenType::Bearer,
EmptyExtraTokenFields {},
);

View File

@@ -7,9 +7,8 @@ use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use reqwest::ClientBuilder;
use codex_exec_server::HttpClient;
use reqwest::Url;
use rmcp::transport::auth::OAuthState;
use tiny_http::Response;
use tiny_http::Server;
use tokio::sync::oneshot;
@@ -18,10 +17,11 @@ use urlencoding::decode;
use crate::StoredOAuthTokens;
use crate::WrappedOAuthTokenResponse;
use crate::auth_status::no_proxy_http_client;
use crate::mcp_oauth_http::OAuthAuthorizationSession;
use crate::mcp_oauth_http::OAuthHttpClient;
use crate::oauth::compute_expires_at_millis;
use crate::save_oauth_tokens;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use codex_config::types::OAuthCredentialsStoreMode;
struct OauthHeaders {
@@ -80,6 +80,34 @@ pub async fn perform_oauth_login(
oauth_resource: Option<&str>,
callback_port: Option<u16>,
callback_url: Option<&str>,
) -> Result<()> {
perform_oauth_login_with_client(
server_name,
server_url,
store_mode,
http_headers,
env_http_headers,
scopes,
oauth_resource,
callback_port,
callback_url,
no_proxy_http_client(),
)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn perform_oauth_login_with_client(
server_name: &str,
server_url: &str,
store_mode: OAuthCredentialsStoreMode,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
scopes: &[String],
oauth_resource: Option<&str>,
callback_port: Option<u16>,
callback_url: Option<&str>,
http_client: Arc<dyn HttpClient>,
) -> Result<()> {
perform_oauth_login_with_browser_output(
server_name,
@@ -92,6 +120,7 @@ pub async fn perform_oauth_login(
callback_port,
callback_url,
/*emit_browser_url*/ true,
http_client,
)
.await
}
@@ -107,6 +136,34 @@ pub async fn perform_oauth_login_silent(
oauth_resource: Option<&str>,
callback_port: Option<u16>,
callback_url: Option<&str>,
) -> Result<()> {
perform_oauth_login_silent_with_client(
server_name,
server_url,
store_mode,
http_headers,
env_http_headers,
scopes,
oauth_resource,
callback_port,
callback_url,
no_proxy_http_client(),
)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn perform_oauth_login_silent_with_client(
server_name: &str,
server_url: &str,
store_mode: OAuthCredentialsStoreMode,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
scopes: &[String],
oauth_resource: Option<&str>,
callback_port: Option<u16>,
callback_url: Option<&str>,
http_client: Arc<dyn HttpClient>,
) -> Result<()> {
perform_oauth_login_with_browser_output(
server_name,
@@ -119,6 +176,7 @@ pub async fn perform_oauth_login_silent(
callback_port,
callback_url,
/*emit_browser_url*/ false,
http_client,
)
.await
}
@@ -135,6 +193,7 @@ async fn perform_oauth_login_with_browser_output(
callback_port: Option<u16>,
callback_url: Option<&str>,
emit_browser_url: bool,
http_client: Arc<dyn HttpClient>,
) -> Result<()> {
let headers = OauthHeaders {
http_headers,
@@ -151,6 +210,7 @@ async fn perform_oauth_login_with_browser_output(
callback_port,
callback_url,
/*timeout_secs*/ None,
http_client,
)
.await?
.finish(emit_browser_url)
@@ -169,6 +229,36 @@ pub async fn perform_oauth_login_return_url(
timeout_secs: Option<i64>,
callback_port: Option<u16>,
callback_url: Option<&str>,
) -> Result<OauthLoginHandle> {
perform_oauth_login_return_url_with_client(
server_name,
server_url,
store_mode,
http_headers,
env_http_headers,
scopes,
oauth_resource,
timeout_secs,
callback_port,
callback_url,
no_proxy_http_client(),
)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn perform_oauth_login_return_url_with_client(
server_name: &str,
server_url: &str,
store_mode: OAuthCredentialsStoreMode,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
scopes: &[String],
oauth_resource: Option<&str>,
timeout_secs: Option<i64>,
callback_port: Option<u16>,
callback_url: Option<&str>,
http_client: Arc<dyn HttpClient>,
) -> Result<OauthLoginHandle> {
let headers = OauthHeaders {
http_headers,
@@ -185,6 +275,7 @@ pub async fn perform_oauth_login_return_url(
callback_port,
callback_url,
timeout_secs,
http_client,
)
.await?;
@@ -329,7 +420,8 @@ impl OauthLoginHandle {
struct OauthLoginFlow {
auth_url: String,
oauth_state: OAuthState,
oauth_http: OAuthHttpClient,
oauth_session: Option<OAuthAuthorizationSession>,
rx: oneshot::Receiver<CallbackResult>,
guard: CallbackServerGuard,
server_name: String,
@@ -412,6 +504,7 @@ impl OauthLoginFlow {
callback_port: Option<u16>,
callback_url: Option<&str>,
timeout_secs: Option<i64>,
http_client: Arc<dyn HttpClient>,
) -> Result<Self> {
const DEFAULT_OAUTH_TIMEOUT_SECS: i64 = 300;
@@ -437,25 +530,19 @@ impl OauthLoginFlow {
http_headers,
env_http_headers,
} = headers;
let default_headers = build_default_headers(http_headers, env_http_headers)?;
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
oauth_state
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
let oauth_http = OAuthHttpClient::new(http_client, http_headers, env_http_headers)?;
let oauth_session = oauth_http
.start_authorization(server_url, scopes, &redirect_uri, "Codex")
.await?;
let auth_url = append_query_param(
&oauth_state.get_authorization_url().await?,
"resource",
oauth_resource,
);
let auth_url =
append_query_param(&oauth_session.authorization_url, "resource", oauth_resource);
let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1);
let timeout = Duration::from_secs(timeout_secs as u64);
Ok(Self {
auth_url,
oauth_state,
oauth_http,
oauth_session: Some(oauth_session),
rx,
guard,
server_name: server_name.to_string(),
@@ -503,19 +590,16 @@ impl OauthLoginFlow {
CallbackResult::Error(error) => return Err(anyhow!(error)),
};
self.oauth_state
.handle_callback(&code, &csrf_state)
let oauth_session = self
.oauth_session
.take()
.ok_or_else(|| anyhow!("OAuth login flow was already completed"))?;
let (client_id, credentials) = self
.oauth_http
.exchange_code(oauth_session, &code, &csrf_state)
.await
.context("failed to handle OAuth callback")?;
let (client_id, credentials_opt) = self
.oauth_state
.get_credentials()
.await
.context("failed to retrieve OAuth credentials")?;
let credentials = credentials_opt
.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
let expires_at = compute_expires_at_millis(&credentials);
let stored = StoredOAuthTokens {
server_name: self.server_name.clone(),

View File

@@ -12,7 +12,6 @@ use std::time::Instant;
use anyhow::Result;
use anyhow::anyhow;
use codex_api::SharedAuthProvider;
use codex_client::build_reqwest_client_with_custom_ca;
use codex_config::types::McpServerEnvVar;
use codex_exec_server::HttpClient;
use futures::FutureExt;
@@ -45,9 +44,6 @@ use rmcp::service::RoleClient;
use rmcp::service::RunningService;
use rmcp::service::{self};
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::auth::AuthClient;
use rmcp::transport::auth::AuthError;
use rmcp::transport::auth::OAuthState;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use rmcp::transport::streamable_http_client::StreamableHttpError;
use serde::Deserialize;
@@ -63,12 +59,12 @@ use crate::elicitation_client_service::ElicitationClientService;
use crate::http_client_adapter::StreamableHttpClientAdapter;
use crate::http_client_adapter::StreamableHttpClientAdapterError;
use crate::load_oauth_tokens;
use crate::mcp_oauth_http::OAuthHttpClient;
use crate::oauth::OAuthPersistor;
use crate::oauth::StoredOAuthTokens;
use crate::stdio_server_launcher::StdioServerCommand;
use crate::stdio_server_launcher::StdioServerLauncher;
use crate::stdio_server_launcher::StdioServerTransport;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use codex_config::types::OAuthCredentialsStoreMode;
@@ -80,7 +76,7 @@ enum PendingTransport {
transport: StreamableHttpClientTransport<StreamableHttpClientAdapter>,
},
StreamableHttpWithOAuth {
transport: StreamableHttpClientTransport<AuthClient<StreamableHttpClientAdapter>>,
transport: StreamableHttpClientTransport<StreamableHttpClientAdapter>,
oauth_persistor: OAuthPersistor,
},
}
@@ -708,11 +704,7 @@ impl RmcpClient {
oauth_persistor,
})
}
Err(err)
if err.downcast_ref::<AuthError>().is_some_and(|auth_err| {
matches!(auth_err, AuthError::NoAuthorizationSupport)
}) =>
{
Err(err) if err.to_string().contains("No authorization support") => {
let access_token = initial_tokens
.token_response
.0
@@ -730,6 +722,7 @@ impl RmcpClient {
Arc::clone(http_client),
default_headers,
/*auth_provider*/ None,
/*oauth_persistor*/ None,
),
http_config,
);
@@ -749,6 +742,7 @@ impl RmcpClient {
Arc::clone(http_client),
default_headers,
auth_provider.clone(),
/*oauth_persistor*/ None,
),
http_config,
);
@@ -943,52 +937,32 @@ async fn create_oauth_transport_and_runtime(
default_headers: HeaderMap,
http_client: Arc<dyn HttpClient>,
) -> Result<(
StreamableHttpClientTransport<AuthClient<StreamableHttpClientAdapter>>,
StreamableHttpClientTransport<StreamableHttpClientAdapter>,
OAuthPersistor,
)> {
let builder = apply_default_headers(reqwest::Client::builder(), &default_headers);
let oauth_metadata_client = build_reqwest_client_with_custom_ca(builder)?;
// TODO(aibrahim): teach OAuth bootstrap and refresh to use the same
// shared HTTP client abstraction instead of always creating the local
// reqwest metadata client here.
let mut oauth_state =
OAuthState::new(url.to_string(), Some(oauth_metadata_client.clone())).await?;
let oauth_http =
OAuthHttpClient::from_default_headers(Arc::clone(&http_client), default_headers.clone());
if oauth_http.discover(url).await?.is_none() {
anyhow::bail!("No authorization support detected");
}
oauth_state
.set_credentials(
&initial_tokens.client_id,
initial_tokens.token_response.0.clone(),
)
.await?;
let manager = match oauth_state {
OAuthState::Authorized(manager) => manager,
OAuthState::Unauthorized(manager) => manager,
OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => {
return Err(anyhow!("unexpected OAuth state during client setup"));
}
};
let auth_client = AuthClient::new(
StreamableHttpClientAdapter::new(http_client, default_headers, /*auth_provider*/ None),
manager,
);
let auth_manager = auth_client.auth_manager.clone();
let transport = StreamableHttpClientTransport::with_client(
auth_client,
StreamableHttpClientTransportConfig::with_uri(url.to_string()),
);
let runtime = OAuthPersistor::new(
let oauth_persistor = OAuthPersistor::new(
server_name.to_string(),
url.to_string(),
auth_manager,
oauth_http,
credentials_store,
Some(initial_tokens),
);
Ok((transport, runtime))
let transport = StreamableHttpClientTransport::with_client(
StreamableHttpClientAdapter::new(
http_client,
default_headers,
/*auth_provider*/ None,
Some(oauth_persistor.clone()),
),
StreamableHttpClientTransportConfig::with_uri(url.to_string()),
);
Ok((transport, oauth_persistor))
}
#[cfg(test)]

View File

@@ -1,7 +1,6 @@
use anyhow::Result;
use anyhow::anyhow;
use codex_config::types::McpServerEnvVar;
use reqwest::ClientBuilder;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderName;
use reqwest::header::HeaderValue;
@@ -115,17 +114,6 @@ pub(crate) fn build_default_headers(
Ok(headers)
}
pub(crate) fn apply_default_headers(
builder: ClientBuilder,
default_headers: &HeaderMap,
) -> ClientBuilder {
if default_headers.is_empty() {
builder
} else {
builder.default_headers(default_headers.clone())
}
}
#[cfg(unix)]
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
"HOME",

View File

@@ -0,0 +1,309 @@
//! Regression coverage for remote-environment Streamable HTTP OAuth.
//!
//! The OAuth issuer in this test uses an unresolvable hostname. If any
//! discovery, registration, or token-exchange request bypasses the injected
//! `HttpClient`, the test fails before the callback can complete.
mod streamable_http_test_support;
use std::ffi::OsString;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::MutexGuard;
use std::sync::OnceLock;
use std::sync::PoisonError;
use codex_config::types::OAuthCredentialsStoreMode;
use codex_exec_server::ExecServerError;
use codex_exec_server::HttpClient;
use codex_exec_server::HttpHeader;
use codex_exec_server::HttpRequestParams;
use codex_exec_server::HttpRequestResponse;
use codex_exec_server::HttpResponseBodyStream;
use codex_rmcp_client::StoredOAuthTokens;
use codex_rmcp_client::WrappedOAuthTokenResponse;
use codex_rmcp_client::perform_oauth_login_return_url_with_client;
use codex_rmcp_client::save_oauth_tokens;
use futures::FutureExt as _;
use futures::future::BoxFuture;
use oauth2::AccessToken;
use oauth2::EmptyExtraTokenFields;
use oauth2::RefreshToken;
use oauth2::basic::BasicTokenType;
use pretty_assertions::assert_eq;
use rmcp::transport::auth::OAuthTokenResponse;
use serde_json::Value;
use serde_json::json;
use serial_test::serial;
use tempfile::TempDir;
use streamable_http_test_support::call_echo_tool;
use streamable_http_test_support::create_remote_oauth_client;
use streamable_http_test_support::expected_echo_result;
use streamable_http_test_support::spawn_exec_server;
use streamable_http_test_support::spawn_streamable_http_server_with_oauth_bearer;
#[derive(Clone, Default)]
struct RecordingHttpClient {
requests: Arc<Mutex<Vec<HttpRequestParams>>>,
}
impl RecordingHttpClient {
fn recorded_requests(&self) -> Vec<HttpRequestParams> {
self.requests
.lock()
.unwrap_or_else(PoisonError::into_inner)
.clone()
}
}
impl HttpClient for RecordingHttpClient {
fn http_request(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
let requests = Arc::clone(&self.requests);
async move {
requests
.lock()
.unwrap_or_else(PoisonError::into_inner)
.push(params.clone());
let response = match (params.method.as_str(), params.url.as_str()) {
("GET", "http://oauth.test/.well-known/oauth-authorization-server/mcp") => {
json_response(json!({
"authorization_endpoint": "http://oauth.test/authorize",
"token_endpoint": "http://oauth.test/token",
"registration_endpoint": "http://oauth.test/register",
"scopes_supported": ["tools.read", "tools.write"],
"response_types_supported": ["code"],
}))?
}
("POST", "http://oauth.test/register") => json_response(json!({
"client_id": "registered-client",
}))?,
("POST", "http://oauth.test/token") => json_response(json!({
"access_token": "remote-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "remote-refresh-token",
}))?,
_ => {
return Err(ExecServerError::HttpRequest(format!(
"unexpected HTTP request: {} {}",
params.method, params.url
)));
}
};
Ok(response)
}
.boxed()
}
fn http_request_stream(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
async move {
Err(ExecServerError::HttpRequest(format!(
"unexpected streaming HTTP request: {} {}",
params.method, params.url
)))
}
.boxed()
}
}
fn json_response(body: Value) -> Result<HttpRequestResponse, ExecServerError> {
let body = serde_json::to_vec(&body)
.map_err(|err| ExecServerError::HttpRequest(format!("serialize JSON response: {err}")))?;
Ok(HttpRequestResponse {
status: 200,
headers: vec![HttpHeader {
name: "content-type".to_string(),
value: "application/json".to_string(),
}],
body: body.into(),
})
}
struct TempCodexHome {
_guard: MutexGuard<'static, ()>,
previous: Option<OsString>,
_dir: TempDir,
}
impl TempCodexHome {
fn new() -> anyhow::Result<Self> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
let guard = LOCK
.get_or_init(Mutex::default)
.lock()
.unwrap_or_else(PoisonError::into_inner);
let previous = std::env::var_os("CODEX_HOME");
let dir = TempDir::new()?;
unsafe {
std::env::set_var("CODEX_HOME", dir.path());
}
Ok(Self {
_guard: guard,
previous,
_dir: dir,
})
}
}
impl Drop for TempCodexHome {
fn drop(&mut self) {
unsafe {
match self.previous.as_ref() {
Some(value) => std::env::set_var("CODEX_HOME", value),
None => std::env::remove_var("CODEX_HOME"),
}
}
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn browser_callback_flow_uses_injected_http_client_for_oauth_requests() -> anyhow::Result<()>
{
let _codex_home = TempCodexHome::new()?;
let http_client = RecordingHttpClient::default();
let handle = perform_oauth_login_return_url_with_client(
"remote-oauth-test",
"http://oauth.test/mcp",
OAuthCredentialsStoreMode::File,
/*http_headers*/ None,
/*env_http_headers*/ None,
&["tools.read".to_string()],
/*oauth_resource*/ None,
Some(10),
/*callback_port*/ None,
/*callback_url*/ None,
Arc::new(http_client.clone()),
)
.await?;
let authorization_url = reqwest::Url::parse(handle.authorization_url())?;
let mut state = None;
let mut redirect_uri = None;
for (name, value) in authorization_url.query_pairs() {
match name.as_ref() {
"state" => state = Some(value.into_owned()),
"redirect_uri" => redirect_uri = Some(value.into_owned()),
_ => {}
}
}
let state = state.expect("authorization URL includes state");
let redirect_uri = redirect_uri.expect("authorization URL includes redirect_uri");
let callback_url = format!("{redirect_uri}?code=provider-code&state={state}");
let callback_response = reqwest::get(callback_url).await?;
assert_eq!(callback_response.status(), reqwest::StatusCode::OK);
handle.wait().await?;
let requests = http_client.recorded_requests();
assert_eq!(
requests
.iter()
.map(|request| (request.method.as_str(), request.url.as_str()))
.collect::<Vec<_>>(),
vec![
(
"GET",
"http://oauth.test/.well-known/oauth-authorization-server/mcp"
),
("POST", "http://oauth.test/register"),
("POST", "http://oauth.test/token"),
]
);
let registration_body: Value = serde_json::from_slice(
&requests[1]
.body
.as_ref()
.expect("registration request has a body")
.0,
)?;
assert_eq!(
registration_body,
json!({
"client_name": "Codex",
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code", "refresh_token"],
"token_endpoint_auth_method": "none",
"response_types": ["code"],
})
);
let token_body = String::from_utf8(
requests[2]
.body
.as_ref()
.expect("token request has a body")
.0
.clone(),
)?;
assert!(token_body.contains("grant_type=authorization_code"));
assert!(token_body.contains("code=provider-code"));
assert!(token_body.contains("client_id=registered-client"));
assert!(token_body.contains("code_verifier="));
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn stored_oauth_refreshes_and_authenticates_remote_streamable_http() -> anyhow::Result<()> {
let _codex_home = TempCodexHome::new()?;
let refresh_token = "remote-refresh-token";
let refreshed_access_token = "refreshed-remote-access-token";
let (_server, base_url) =
spawn_streamable_http_server_with_oauth_bearer(refreshed_access_token, refresh_token)
.await?;
let exec_server = spawn_exec_server().await?;
save_oauth_tokens(
"test-streamable-http-remote-oauth",
&expired_oauth_tokens(
"test-streamable-http-remote-oauth",
&format!("{base_url}/mcp"),
"expired-access-token",
refresh_token,
),
OAuthCredentialsStoreMode::File,
)?;
let client = create_remote_oauth_client(&base_url, exec_server.client.clone()).await?;
let result = call_echo_tool(&client, "remote-oauth").await?;
assert_eq!(result, expected_echo_result("remote-oauth"));
Ok(())
}
fn expired_oauth_tokens(
server_name: &str,
url: &str,
access_token: &str,
refresh_token: &str,
) -> StoredOAuthTokens {
let mut token_response = OAuthTokenResponse::new(
AccessToken::new(access_token.to_string()),
BasicTokenType::Bearer,
EmptyExtraTokenFields {},
);
token_response.set_refresh_token(Some(RefreshToken::new(refresh_token.to_string())));
StoredOAuthTokens {
server_name: server_name.to_string(),
url: url.to_string(),
client_id: "stored-client".to_string(),
token_response: WrappedOAuthTokenResponse(token_response),
expires_at: Some(0),
}
}

View File

@@ -128,10 +128,40 @@ pub(crate) async fn create_remote_client(
base_url: &str,
http_client: ExecServerClient,
) -> anyhow::Result<RmcpClient> {
let client = RmcpClient::new_streamable_http_client(
create_remote_client_with_bearer(
"test-streamable-http-remote",
&format!("{base_url}/mcp"),
base_url,
Some("test-bearer".to_string()),
http_client,
)
.await
}
/// Creates a Streamable HTTP RMCP client that authenticates using stored OAuth
/// credentials through the remote runtime HTTP API.
pub(crate) async fn create_remote_oauth_client(
base_url: &str,
http_client: ExecServerClient,
) -> anyhow::Result<RmcpClient> {
create_remote_client_with_bearer(
"test-streamable-http-remote-oauth",
base_url,
/*bearer_token*/ None,
http_client,
)
.await
}
async fn create_remote_client_with_bearer(
server_name: &str,
base_url: &str,
bearer_token: Option<String>,
http_client: ExecServerClient,
) -> anyhow::Result<RmcpClient> {
let client = RmcpClient::new_streamable_http_client(
server_name,
&format!("{base_url}/mcp"),
bearer_token,
/*http_headers*/ None,
/*env_http_headers*/ None,
OAuthCredentialsStoreMode::File,
@@ -193,16 +223,38 @@ pub(crate) async fn arm_session_post_failure(
}
pub(crate) async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> {
spawn_streamable_http_server_with_env(&[]).await
}
pub(crate) async fn spawn_streamable_http_server_with_oauth_bearer(
bearer_token: &str,
refresh_token: &str,
) -> anyhow::Result<(Child, String)> {
spawn_streamable_http_server_with_env(&[
("MCP_EXPECT_BEARER", bearer_token),
("MCP_OAUTH_ACCESS_TOKEN", bearer_token),
("MCP_OAUTH_REFRESH_TOKEN", refresh_token),
])
.await
}
async fn spawn_streamable_http_server_with_env(
env: &[(&str, &str)],
) -> anyhow::Result<(Child, String)> {
let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();
drop(listener);
let bind_addr = format!("127.0.0.1:{port}");
let base_url = format!("http://{bind_addr}");
let mut child = Command::new(streamable_http_server_bin()?)
let mut command = Command::new(streamable_http_server_bin()?);
command
.kill_on_drop(true)
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
.spawn()?;
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr);
for (name, value) in env {
command.env(name, value);
}
let mut child = command.spawn()?;
wait_for_streamable_http_server(&mut child, &bind_addr, Duration::from_secs(5)).await?;
Ok((child, base_url))