Compare commits

...

10 Commits

Author SHA1 Message Date
Ahmed Ibrahim
f128f30011 Merge main into dev/remote-mcp-http-oauth-full-ci
Co-authored-by: Codex <noreply@openai.com>
2026-04-25 13:41:29 +00:00
Ahmed Ibrahim
2596065878 Cover OAuth refresh in local and remote MCP tests
Co-authored-by: Codex <noreply@openai.com>
2026-04-22 19:58:08 -07:00
Ahmed Ibrahim
9713a34772 Use proxy bridge for OAuth HttpClient routing
Co-authored-by: Codex <noreply@openai.com>
2026-04-22 19:39:17 -07:00
Ahmed Ibrahim
cefc86e851 Merge branch 'dev/remote-mcp-http-wire-full-ci' into dev/remote-mcp-http-oauth-full-ci 2026-04-22 19:31:26 -07:00
Ahmed Ibrahim
fc8c0aa238 Merge branch 'main' into dev/remote-mcp-http-wire-full-ci 2026-04-22 19:31:12 -07:00
Ahmed Ibrahim
f8ebe83fc9 Use HttpClient directly for OAuth bootstrap
Co-authored-by: Codex <noreply@openai.com>
2026-04-22 19:18:23 -07:00
Ahmed Ibrahim
b8780250d6 Route OAuth bootstrap through HttpClient
Co-authored-by: Codex <noreply@openai.com>
2026-04-22 18:33:47 -07:00
Ahmed Ibrahim
e716329890 codex: fix CI failure on PR #18584
Co-authored-by: Codex <noreply@openai.com>
2026-04-22 17:58:49 -07:00
Ahmed Ibrahim
6da1978ff4 Keep Streamable HTTP tests in rmcp_client suite
Co-authored-by: Codex <noreply@openai.com>
2026-04-22 17:51:02 -07:00
Ahmed Ibrahim
a92dc1f1b5 Wire remote streamable HTTP MCP
Co-authored-by: Codex <noreply@openai.com>
2026-04-22 17:41:40 -07:00
8 changed files with 896 additions and 48 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -3163,6 +3163,7 @@ name = "codex-rmcp-client"
version = "0.0.0"
dependencies = [
"anyhow",
"async-trait",
"axum",
"bytes",
"codex-api",

View File

@@ -99,6 +99,7 @@ enum McpCallEvent {
}
const REMOTE_MCP_ENVIRONMENT: &str = "remote";
const LOCAL_MCP_ENVIRONMENT: &str = "local";
fn remote_aware_experimental_environment() -> Option<String> {
// These tests run locally in normal CI and against the Docker-backed
@@ -107,6 +108,10 @@ fn remote_aware_experimental_environment() -> Option<String> {
std::env::var_os(remote_env_env_var()).map(|_| REMOTE_MCP_ENVIRONMENT.to_string())
}
fn remote_only_experimental_environment() -> Option<String> {
std::env::var_os(remote_env_env_var()).map(|_| REMOTE_MCP_ENVIRONMENT.to_string())
}
/// Returns the stdio MCP test server command path for the active test placement.
///
/// Local test runs can execute the host-built test binary directly. Remote-aware
@@ -1851,6 +1856,53 @@ struct StreamableHttpTestServer {
process: StreamableHttpTestServerProcess,
}
#[derive(Clone, Copy)]
enum StreamableHttpTestServerBindMode {
HostVisible,
RemoteLoopbackOnly,
}
struct StreamableHttpTestServerOptions<'a> {
expected_env_value: &'a str,
expected_bearer: Option<&'a str>,
refreshed_access_token: Option<&'a str>,
bind_mode: StreamableHttpTestServerBindMode,
}
impl<'a> StreamableHttpTestServerOptions<'a> {
fn host_visible(expected_env_value: &'a str, expected_bearer: Option<&'a str>) -> Self {
Self {
expected_env_value,
expected_bearer,
refreshed_access_token: None,
bind_mode: StreamableHttpTestServerBindMode::HostVisible,
}
}
fn with_refreshed_access_token(mut self, refreshed_access_token: &'a str) -> Self {
self.refreshed_access_token = Some(refreshed_access_token);
self
}
fn remote_loopback_only(mut self) -> Self {
self.bind_mode = StreamableHttpTestServerBindMode::RemoteLoopbackOnly;
self
}
fn remote_loopback_only_if_remote(self, placement: OAuthRefreshPlacement) -> Self {
match placement {
OAuthRefreshPlacement::Local => self,
OAuthRefreshPlacement::Remote => self.remote_loopback_only(),
}
}
}
#[derive(Clone, Copy)]
enum OAuthTokenExpiry {
Valid,
Expired,
}
/// Tracks whether the Streamable HTTP test server runs on the host or remotely.
enum StreamableHttpTestServerProcess {
Local(Child),
@@ -1964,7 +2016,11 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
// it is a host process.
let expected_env_value = "propagated-env-http";
let Some(http_server) =
start_streamable_http_test_server(expected_env_value, /*expected_token*/ None).await?
start_streamable_http_test_server(StreamableHttpTestServerOptions::host_visible(
expected_env_value,
/*expected_bearer*/ None,
))
.await?
else {
return Ok(());
};
@@ -2150,8 +2206,10 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> {
let expected_token = "initial-access-token";
let client_id = "test-client-id";
let refresh_token = "initial-refresh-token";
let Some(http_server) =
start_streamable_http_test_server(expected_env_value, Some(expected_token)).await?
let Some(http_server) = start_streamable_http_test_server(
StreamableHttpTestServerOptions::host_visible(expected_env_value, Some(expected_token)),
)
.await?
else {
return Ok(());
};
@@ -2168,6 +2226,7 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> {
client_id,
expected_token,
refresh_token,
OAuthTokenExpiry::Valid,
)?;
// Phase 4: configure Codex with the OAuth-backed Streamable HTTP MCP
@@ -2287,10 +2346,236 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> {
Ok(())
}
#[test]
#[serial(codex_home)]
fn streamable_http_with_oauth_refresh_round_trip_local() -> anyhow::Result<()> {
run_streamable_http_with_oauth_refresh_round_trip(OAuthRefreshPlacement::Local)
}
#[test]
#[serial(codex_home)]
fn streamable_http_with_oauth_refresh_round_trip_remote() -> anyhow::Result<()> {
if remote_only_experimental_environment().is_none() {
return Ok(());
}
run_streamable_http_with_oauth_refresh_round_trip(OAuthRefreshPlacement::Remote)
}
#[derive(Clone, Copy)]
enum OAuthRefreshPlacement {
Local,
Remote,
}
fn run_streamable_http_with_oauth_refresh_round_trip(
placement: OAuthRefreshPlacement,
) -> anyhow::Result<()> {
const TEST_STACK_SIZE_BYTES: usize = 8 * 1024 * 1024;
let thread_name = match placement {
OAuthRefreshPlacement::Local => "streamable_http_with_oauth_refresh_round_trip_local",
OAuthRefreshPlacement::Remote => "streamable_http_with_oauth_refresh_round_trip_remote",
};
let handle = std::thread::Builder::new()
.name(thread_name.to_string())
.stack_size(TEST_STACK_SIZE_BYTES)
.spawn(move || -> anyhow::Result<()> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(1)
.enable_all()
.build()?;
runtime.block_on(streamable_http_with_oauth_refresh_round_trip_impl(
placement,
))
})?;
match handle.join() {
Ok(result) => result,
Err(_) => Err(anyhow::anyhow!("{thread_name} thread panicked")),
}
}
#[allow(clippy::expect_used)]
async fn streamable_http_with_oauth_refresh_round_trip_impl(
placement: OAuthRefreshPlacement,
) -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
let server = responses::start_mock_server().await;
let call_id = "call-790";
let server_name = match placement {
OAuthRefreshPlacement::Local => "rmcp_http_oauth_refresh_local",
OAuthRefreshPlacement::Remote => "rmcp_http_oauth_refresh_remote",
};
let tool_name = format!("mcp__{server_name}__echo");
let namespace = format!("mcp__{server_name}__");
mount_sse_once(
&server,
responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_function_call_with_namespace(
call_id,
&namespace,
"echo",
"{\"message\":\"ping\"}",
),
responses::ev_completed("resp-1"),
]),
)
.await;
mount_sse_once(
&server,
responses::sse(vec![
responses::ev_assistant_message(
"msg-1",
"rmcp streamable http oauth refresh echo tool completed successfully.",
),
responses::ev_completed("resp-2"),
]),
)
.await;
let expected_env_value = match placement {
OAuthRefreshPlacement::Local => "propagated-env-http-oauth-refresh-local",
OAuthRefreshPlacement::Remote => "propagated-env-http-oauth-refresh-remote",
};
let initial_access_token = "expired-access-token";
let refreshed_access_token = match placement {
OAuthRefreshPlacement::Local => "refreshed-access-token-local",
OAuthRefreshPlacement::Remote => "refreshed-access-token-remote",
};
let refresh_token = "initial-refresh-token";
// The remote case binds the test server to 127.0.0.1 inside the remote
// container so the orchestrator cannot reach metadata or token endpoints
// directly. If refresh still succeeds, it had to go through the selected
// remote `HttpClient`.
let Some(http_server) = start_streamable_http_test_server(
StreamableHttpTestServerOptions::host_visible(
expected_env_value,
Some(refreshed_access_token),
)
.with_refreshed_access_token(refreshed_access_token)
.remote_loopback_only_if_remote(placement),
)
.await?
else {
return Ok(());
};
let server_url = http_server.url().to_string();
let temp_home = Arc::new(tempdir()?);
let _codex_home_guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str());
write_fallback_oauth_tokens(
temp_home.path(),
server_name,
&server_url,
"test-client-id",
initial_access_token,
refresh_token,
OAuthTokenExpiry::Expired,
)?;
let experimental_environment = match placement {
OAuthRefreshPlacement::Local => Some(LOCAL_MCP_ENVIRONMENT.to_string()),
OAuthRefreshPlacement::Remote => Some(REMOTE_MCP_ENVIRONMENT.to_string()),
};
let fixture = test_codex()
.with_home(temp_home.clone())
.with_config(move |config| {
config.mcp_oauth_credentials_store_mode = serde_json::from_value(json!("file"))
.expect("`file` should deserialize as OAuthCredentialsStoreMode");
insert_mcp_server(
config,
server_name,
McpServerTransportConfig::StreamableHttp {
url: server_url,
bearer_token_env_var: None,
http_headers: None,
env_http_headers: None,
},
TestMcpServerOptions {
experimental_environment,
..Default::default()
},
);
})
.build_remote_aware(&server)
.await?;
let session_model = fixture.session_configured.model.clone();
wait_for_mcp_tool(&fixture, &tool_name).await?;
fixture
.codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: "call the rmcp streamable http oauth echo tool".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
cwd: fixture.cwd.path().to_path_buf(),
approval_policy: AskForApproval::Never,
approvals_reviewer: None,
sandbox_policy: SandboxPolicy::new_read_only_policy(),
permission_profile: None,
model: session_model,
effort: None,
summary: None,
service_tier: None,
collaboration_mode: None,
personality: None,
environments: None,
})
.await?;
let end_event = wait_for_event(&fixture.codex, |ev| {
matches!(ev, EventMsg::McpToolCallEnd(_))
})
.await;
let EventMsg::McpToolCallEnd(end) = end_event else {
unreachable!("event guard guarantees McpToolCallEnd");
};
let result = end
.result
.as_ref()
.expect("rmcp echo tool should return success");
assert_eq!(result.is_error, Some(false));
let structured = result
.structured_content
.as_ref()
.expect("structured content");
let Value::Object(map) = structured else {
panic!("structured content should be an object: {structured:?}");
};
assert_eq!(
map.get("echo").and_then(Value::as_str),
Some("ECHOING: ping")
);
assert_eq!(
map.get("env").and_then(Value::as_str),
Some(expected_env_value)
);
wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
server.verify().await;
assert_eq!(
read_fallback_oauth_access_token(temp_home.path(), server_name, http_server.url())?,
refreshed_access_token
);
http_server.shutdown().await;
Ok(())
}
/// Starts the Streamable HTTP MCP test server in the active test placement.
async fn start_streamable_http_test_server(
expected_env_value: &str,
expected_token: Option<&str>,
options: StreamableHttpTestServerOptions<'_>,
) -> anyhow::Result<Option<StreamableHttpTestServer>> {
let rmcp_http_server_bin = match cargo_bin("test_streamable_http_server") {
Ok(path) => path,
@@ -2305,8 +2590,7 @@ async fn start_streamable_http_test_server(
start_remote_streamable_http_test_server(
&container_name,
&rmcp_http_server_bin,
expected_env_value,
expected_token,
options,
)
.await?,
));
@@ -2322,10 +2606,13 @@ async fn start_streamable_http_test_server(
command
.kill_on_drop(true)
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
.env("MCP_TEST_VALUE", expected_env_value);
if let Some(expected_token) = expected_token {
.env("MCP_TEST_VALUE", options.expected_env_value);
if let Some(expected_token) = options.expected_bearer {
command.env("MCP_EXPECT_BEARER", expected_token);
}
if let Some(refreshed_access_token) = options.refreshed_access_token {
command.env("MCP_REFRESHED_ACCESS_TOKEN", refreshed_access_token);
}
let mut child = command.spawn()?;
wait_for_local_streamable_http_server(&mut child, &server_url, Duration::from_secs(5)).await?;
@@ -2339,8 +2626,7 @@ async fn start_streamable_http_test_server(
async fn start_remote_streamable_http_test_server(
container_name: &str,
rmcp_http_server_bin: &Path,
expected_env_value: &str,
expected_token: Option<&str>,
options: StreamableHttpTestServerOptions<'_>,
) -> anyhow::Result<StreamableHttpTestServer> {
let remote_path = copy_binary_to_remote_env(
container_name,
@@ -2352,20 +2638,32 @@ async fn start_remote_streamable_http_test_server(
let mut env_assignments = vec![
format!(
"MCP_STREAMABLE_HTTP_BIND_ADDR={}",
sh_single_quote("0.0.0.0:0")
sh_single_quote(match options.bind_mode {
StreamableHttpTestServerBindMode::HostVisible => "0.0.0.0:0",
StreamableHttpTestServerBindMode::RemoteLoopbackOnly => "127.0.0.1:0",
})
),
format!(
"MCP_STREAMABLE_HTTP_BOUND_ADDR_FILE={}",
sh_single_quote(&bound_addr_file)
),
format!("MCP_TEST_VALUE={}", sh_single_quote(expected_env_value)),
format!(
"MCP_TEST_VALUE={}",
sh_single_quote(options.expected_env_value)
),
];
if let Some(expected_token) = expected_token {
if let Some(expected_token) = options.expected_bearer {
env_assignments.push(format!(
"MCP_EXPECT_BEARER={}",
sh_single_quote(expected_token)
));
}
if let Some(refreshed_access_token) = options.refreshed_access_token {
env_assignments.push(format!(
"MCP_REFRESHED_ACCESS_TOKEN={}",
sh_single_quote(refreshed_access_token)
));
}
let script = format!(
"{} nohup {} > {} 2>&1 < /dev/null & echo $!",
@@ -2395,13 +2693,21 @@ async fn start_remote_streamable_http_test_server(
let remote_bind_addr =
wait_for_remote_bound_addr(container_name, &bound_addr_file, Duration::from_secs(5))
.await?;
let container_ip = remote_container_ip(container_name)?;
let server_url = format!("http://{}:{}/mcp", container_ip, remote_bind_addr.port());
let server_host = match options.bind_mode {
StreamableHttpTestServerBindMode::HostVisible => remote_container_ip(container_name)?,
StreamableHttpTestServerBindMode::RemoteLoopbackOnly => "127.0.0.1".to_string(),
};
let server_url = format!("http://{}:{}/mcp", server_host, remote_bind_addr.port());
// The orchestrator can see the Docker container IP, but the behavior under
// test is whether the remote-side MCP client can reach it. Probe through
// remote HTTP before handing the URL to the Codex fixture.
wait_for_remote_streamable_http_server(&server_url, Duration::from_secs(5)).await?;
if expected_token.is_some() {
if options.expected_bearer.is_some()
&& matches!(
options.bind_mode,
StreamableHttpTestServerBindMode::HostVisible
)
{
wait_for_streamable_http_metadata(&server_url, Duration::from_secs(5)).await?;
}
@@ -2643,12 +2949,18 @@ fn write_fallback_oauth_tokens(
client_id: &str,
access_token: &str,
refresh_token: &str,
expiry: OAuthTokenExpiry,
) -> anyhow::Result<()> {
let expires_at = SystemTime::now()
.checked_add(Duration::from_secs(3600))
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?
.duration_since(UNIX_EPOCH)?
.as_millis() as u64;
let expires_at = match expiry {
OAuthTokenExpiry::Valid => SystemTime::now()
.checked_add(Duration::from_secs(3600))
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?,
OAuthTokenExpiry::Expired => SystemTime::now()
.checked_sub(Duration::from_secs(3600))
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?,
}
.duration_since(UNIX_EPOCH)?
.as_millis() as u64;
let store = serde_json::json!({
"stub": {
@@ -2667,6 +2979,33 @@ fn write_fallback_oauth_tokens(
Ok(())
}
fn read_fallback_oauth_access_token(
home: &Path,
server_name: &str,
server_url: &str,
) -> anyhow::Result<String> {
let file_path = home.join(".credentials.json");
let store: Value = serde_json::from_slice(&fs::read(&file_path)?)?;
store
.as_object()
.into_iter()
.flat_map(|store| store.values())
.find(|entry| {
entry
.get("server_name")
.and_then(Value::as_str)
.is_some_and(|value| value == server_name)
&& entry
.get("server_url")
.and_then(Value::as_str)
.is_some_and(|value| value == server_url)
})
.and_then(|entry| entry.get("access_token"))
.and_then(Value::as_str)
.map(ToOwned::to_owned)
.ok_or_else(|| anyhow::anyhow!("missing fallback OAuth access token"))
}
struct EnvVarGuard {
key: &'static str,
original: Option<OsString>,

View File

@@ -9,6 +9,7 @@ workspace = true
[dependencies]
anyhow = "1"
async-trait = { workspace = true }
axum = { workspace = true, default-features = false, features = [
"http1",
"tokio",

View File

@@ -62,6 +62,7 @@ const MEMO_URI: &str = "memo://codex/example-note";
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
const OAUTH_TOKEN_PATH: &str = "/oauth/token";
#[derive(Clone, Default)]
struct SessionFailureState {
@@ -116,6 +117,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
fs::write(bound_addr_file, actual_bind_addr.to_string())?;
}
eprintln!("starting rmcp streamable http test server on http://{actual_bind_addr}/mcp");
let refreshed_access_token = std::env::var("MCP_REFRESHED_ACCESS_TOKEN")
.unwrap_or_else(|_| "refreshed-access-token".to_string());
let router = Router::new()
.route(
@@ -146,6 +149,28 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
}),
)
.route(
OAUTH_TOKEN_PATH,
post({
move || async move {
let refreshed_access_token = refreshed_access_token.clone();
#[expect(clippy::expect_used)]
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "application/json")
.body(Body::from(
serde_json::to_vec(&json!({
"access_token": refreshed_access_token,
"token_type": "Bearer",
"expires_in": 3600_u64,
"refresh_token": "refreshed-refresh-token",
}))
.expect("failed to serialize token response"),
))
.expect("valid token response")
}
}),
)
.nest_service(
"/mcp",
StreamableHttpService::new(
@@ -386,7 +411,8 @@ async fn require_bearer(
request: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
if request.uri().path().contains("/.well-known/") {
if request.uri().path().contains("/.well-known/") || request.uri().path().starts_with("/oauth/")
{
return Ok(next.run(request).await);
}
if request

View File

@@ -4,6 +4,7 @@ mod executor_process_transport;
mod http_client_adapter;
mod logging_client_handler;
mod oauth;
mod oauth_http_client;
mod perform_oauth_login;
mod program_resolver;
mod rmcp_client;

View File

@@ -48,6 +48,7 @@ use codex_keyring_store::KeyringStore;
use rmcp::transport::auth::AuthorizationManager;
use tokio::sync::Mutex;
use crate::oauth_http_client::OAuthHttpProxy;
use codex_utils_home_dir::find_codex_home;
const KEYRING_SERVICE: &str = "Codex MCP Credentials";
@@ -258,6 +259,7 @@ struct OAuthPersistorInner {
server_name: String,
url: String,
authorization_manager: Arc<Mutex<AuthorizationManager>>,
_oauth_http_proxy: Option<OAuthHttpProxy>,
store_mode: OAuthCredentialsStoreMode,
last_credentials: Mutex<Option<StoredOAuthTokens>>,
}
@@ -267,6 +269,7 @@ impl OAuthPersistor {
server_name: String,
url: String,
authorization_manager: Arc<Mutex<AuthorizationManager>>,
oauth_http_proxy: Option<OAuthHttpProxy>,
store_mode: OAuthCredentialsStoreMode,
initial_credentials: Option<StoredOAuthTokens>,
) -> Self {
@@ -275,6 +278,7 @@ impl OAuthPersistor {
server_name,
url,
authorization_manager,
_oauth_http_proxy: oauth_http_proxy,
store_mode,
last_credentials: Mutex::new(initial_credentials),
}),

View File

@@ -0,0 +1,483 @@
//! OAuth bootstrap helpers that route non-browser HTTP through the shared
//! `HttpClient` capability.
//!
//! The browser-facing part of MCP OAuth stays on the orchestrator, but the
//! HTTP requests used for discovery, registration, token exchange, and refresh
//! need to follow the selected MCP placement. This module discovers OAuth
//! metadata through `HttpClient` and exposes token/registration endpoints
//! through a small localhost proxy so RMCP's OAuth manager can keep using its
//! existing `reqwest`-based internals.
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::thread;
use anyhow::Result;
use anyhow::anyhow;
use codex_exec_server::HttpClient;
use codex_exec_server::HttpHeader;
use codex_exec_server::HttpRequestParams;
use reqwest::Method;
use reqwest::StatusCode;
use reqwest::Url;
use reqwest::header::HeaderMap;
use rmcp::transport::auth::AuthError;
use rmcp::transport::auth::AuthorizationManager;
use rmcp::transport::auth::AuthorizationMetadata;
use rmcp::transport::auth::CredentialStore;
use rmcp::transport::auth::StoredCredentials;
use tiny_http::Header;
use tiny_http::Response;
use tiny_http::Server;
const DISCOVERY_TIMEOUT_MS: u64 = 30_000;
const MCP_PROTOCOL_VERSION_HEADER: &str = "MCP-Protocol-Version";
const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
const TOKEN_PROXY_PATH: &str = "/oauth/token";
const REGISTRATION_PROXY_PATH: &str = "/oauth/register";
/// The OAuth manager plus any bridge state needed to keep its non-browser
/// HTTP requests on the selected `HttpClient`.
pub(crate) struct OAuthHttpSetup {
pub(crate) authorization_manager: AuthorizationManager,
pub(crate) proxy: Option<OAuthHttpProxy>,
}
/// Builds an OAuth manager whose metadata discovery and token flows follow the
/// selected `HttpClient`, while still leaving browser launch and callback
/// handling on the orchestrator.
pub(crate) async fn create_oauth_http_setup(
url: &str,
default_headers: &HeaderMap,
http_client: Arc<dyn HttpClient>,
initial_credentials: Option<StoredCredentials>,
) -> Result<OAuthHttpSetup> {
let metadata = discover_authorization_metadata(url, default_headers, Arc::clone(&http_client))
.await
.map_err(|error| anyhow!(error))?;
// RMCP's OAuth manager still issues token and registration requests through
// reqwest. Rewrite just those endpoints through a tiny local bridge so the
// actual network hop still follows the selected `HttpClient`.
let proxy = OAuthHttpProxy::new(Arc::clone(&http_client), default_headers, &metadata)?;
let mut authorization_manager = AuthorizationManager::new(url)
.await
.map_err(|error| anyhow!(error))?;
authorization_manager.set_metadata(proxy.rewrite_metadata(&metadata)?);
if let Some(initial_credentials) = initial_credentials {
let client_id = initial_credentials.client_id.clone();
authorization_manager.set_credential_store(StaticCredentialStore::new(initial_credentials));
authorization_manager
.configure_client_id(&client_id)
.map_err(|error| anyhow!(error))?;
}
Ok(OAuthHttpSetup {
authorization_manager,
proxy: Some(proxy),
})
}
async fn discover_authorization_metadata(
url: &str,
default_headers: &HeaderMap,
http_client: Arc<dyn HttpClient>,
) -> Result<AuthorizationMetadata, AuthError> {
let base_url = Url::parse(url).map_err(|error| AuthError::OAuthError(error.to_string()))?;
for discovery_url in generate_discovery_urls(&base_url) {
let response = http_client
.http_request(HttpRequestParams {
method: Method::GET.as_str().to_string(),
url: discovery_url.to_string(),
headers: discovery_headers(default_headers)?,
body: None,
timeout_ms: Some(DISCOVERY_TIMEOUT_MS),
request_id: next_request_id("oauth-discovery"),
stream_response: false,
})
.await
.map_err(|error| AuthError::OAuthError(error.to_string()))?;
if response.status != StatusCode::OK.as_u16() {
continue;
}
let metadata = serde_json::from_slice::<AuthorizationMetadata>(&response.body.into_inner())
.map_err(|error| AuthError::OAuthError(error.to_string()))?;
return Ok(metadata);
}
Err(AuthError::NoAuthorizationSupport)
}
fn generate_discovery_urls(base_url: &Url) -> Vec<Url> {
let mut candidates = Vec::new();
let trimmed = base_url
.path()
.trim_start_matches('/')
.trim_end_matches('/');
let mut push_candidate = |discovery_path: String| {
let mut discovery_url = base_url.clone();
discovery_url.set_query(None);
discovery_url.set_fragment(None);
discovery_url.set_path(&discovery_path);
candidates.push(discovery_url);
};
if trimmed.is_empty() {
push_candidate("/.well-known/oauth-authorization-server".to_string());
push_candidate("/.well-known/openid-configuration".to_string());
} else {
push_candidate(format!("/.well-known/oauth-authorization-server/{trimmed}"));
push_candidate(format!("/.well-known/openid-configuration/{trimmed}"));
push_candidate(format!("/{trimmed}/.well-known/openid-configuration"));
push_candidate("/.well-known/oauth-authorization-server".to_string());
}
candidates
}
fn discovery_headers(default_headers: &HeaderMap) -> Result<Vec<HttpHeader>, AuthError> {
let mut headers = protocol_headers(default_headers);
headers.push(HttpHeader {
name: MCP_PROTOCOL_VERSION_HEADER.to_string(),
value: MCP_PROTOCOL_VERSION.to_string(),
});
Ok(headers)
}
/// A tiny localhost bridge that lets RMCP keep using its internal reqwest
/// OAuth flows while the real network requests still go through `HttpClient`.
pub(crate) struct OAuthHttpProxy {
server: Arc<Server>,
base_url: String,
routes: Arc<HashMap<String, String>>,
}
impl OAuthHttpProxy {
fn new(
http_client: Arc<dyn HttpClient>,
default_headers: &HeaderMap,
metadata: &AuthorizationMetadata,
) -> Result<Self> {
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|error| anyhow!(error))?);
let address = match server.server_addr() {
tiny_http::ListenAddr::IP(address) => address,
#[cfg(not(target_os = "windows"))]
_ => return Err(anyhow!("unable to determine OAuth HTTP proxy bind address")),
};
let base_url = format!("http://{address}");
let mut routes = HashMap::new();
routes.insert(
TOKEN_PROXY_PATH.to_string(),
metadata.token_endpoint.clone(),
);
if let Some(registration_endpoint) = metadata.registration_endpoint.clone() {
routes.insert(REGISTRATION_PROXY_PATH.to_string(), registration_endpoint);
}
let routes = Arc::new(routes);
let server_for_thread = Arc::clone(&server);
let routes_for_thread = Arc::clone(&routes);
let default_headers = default_headers.clone();
let runtime = tokio::runtime::Handle::current();
thread::spawn(move || {
while let Ok(mut request) = server_for_thread.recv() {
let request_url = request.url().to_string();
let route_key = request_url
.split_once('?')
.map(|(path, _)| path)
.unwrap_or(request_url.as_str())
.to_string();
let Some(target_url) = routes_for_thread.get(&route_key).cloned() else {
let _ = request.respond(Response::empty(StatusCode::NOT_FOUND.as_u16()));
continue;
};
let mut target = target_url;
if let Some((_, query)) = request_url.split_once('?') {
target.push('?');
target.push_str(query);
}
let method = request.method().as_str().to_string();
let mut body = Vec::new();
if request.as_reader().read_to_end(&mut body).is_err() {
let _ = request.respond(Response::empty(StatusCode::BAD_REQUEST.as_u16()));
continue;
}
// Preserve the original request shape as much as we can, but
// let the selected `HttpClient` own the outbound HTTP call.
let mut headers = protocol_headers(&default_headers);
for header in request.headers() {
let field = header.field.as_str().to_string();
if field.eq_ignore_ascii_case("host")
|| field.eq_ignore_ascii_case("content-length")
{
continue;
}
headers.push(HttpHeader {
name: field,
value: header.value.to_string(),
});
}
let response = runtime.block_on(http_client.http_request(HttpRequestParams {
method,
url: target,
headers,
body: Some(body.into()),
timeout_ms: Some(DISCOVERY_TIMEOUT_MS),
request_id: next_request_id("oauth-proxy"),
stream_response: false,
}));
let response = match response {
Ok(response) => response,
Err(_) => {
let _ = request.respond(Response::empty(StatusCode::BAD_GATEWAY.as_u16()));
continue;
}
};
let mut proxy_response = Response::from_data(response.body.into_inner())
.with_status_code(response.status);
for header in response.headers {
if let Ok(header) = Header::from_bytes(header.name.as_bytes(), header.value) {
proxy_response.add_header(header);
}
}
let _ = request.respond(proxy_response);
}
});
Ok(Self {
server,
base_url,
routes,
})
}
fn rewrite_metadata(&self, metadata: &AuthorizationMetadata) -> Result<AuthorizationMetadata> {
let mut rewritten = metadata.clone();
rewritten.token_endpoint = self.proxied_url(TOKEN_PROXY_PATH)?;
rewritten.registration_endpoint = if self.routes.contains_key(REGISTRATION_PROXY_PATH) {
Some(self.proxied_url(REGISTRATION_PROXY_PATH)?)
} else {
None
};
Ok(rewritten)
}
fn proxied_url(&self, path: &str) -> Result<String> {
let mut url = Url::parse(&self.base_url).map_err(|error| anyhow!(error))?;
url.set_path(path);
Ok(url.to_string())
}
}
impl Drop for OAuthHttpProxy {
fn drop(&mut self) {
// Wake the background thread so it can notice shutdown and exit.
self.server.unblock();
}
}
#[derive(Clone)]
struct StaticCredentialStore {
stored: Arc<tokio::sync::Mutex<Option<StoredCredentials>>>,
}
impl StaticCredentialStore {
fn new(credentials: StoredCredentials) -> Self {
Self {
stored: Arc::new(tokio::sync::Mutex::new(Some(credentials))),
}
}
}
#[async_trait::async_trait]
impl CredentialStore for StaticCredentialStore {
async fn load(&self) -> Result<Option<StoredCredentials>, AuthError> {
Ok(self.stored.lock().await.clone())
}
async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError> {
*self.stored.lock().await = Some(credentials);
Ok(())
}
async fn clear(&self) -> Result<(), AuthError> {
*self.stored.lock().await = None;
Ok(())
}
}
fn protocol_headers(headers: &HeaderMap) -> Vec<HttpHeader> {
headers
.iter()
.filter_map(|(name, value)| {
Some(HttpHeader {
name: name.as_str().to_string(),
value: value.to_str().ok()?.to_string(),
})
})
.collect()
}
fn next_request_id(prefix: &str) -> String {
static NEXT_REQUEST_ID: AtomicU64 = AtomicU64::new(1);
let id = NEXT_REQUEST_ID.fetch_add(1, Ordering::Relaxed);
format!("{prefix}-{id}")
}
#[cfg(test)]
mod tests {
use std::sync::Mutex;
use anyhow::Result;
use codex_exec_server::ExecServerError;
use codex_exec_server::HttpRequestResponse;
use futures::FutureExt;
use futures::future::BoxFuture;
use pretty_assertions::assert_eq;
use super::*;
#[derive(Clone, Default)]
struct TestHttpClient {
requests: Arc<Mutex<Vec<HttpRequestParams>>>,
responses: Arc<Mutex<Vec<Result<HttpRequestResponse, ExecServerError>>>>,
}
impl TestHttpClient {
fn push_response(&self, response: Result<HttpRequestResponse, ExecServerError>) {
self.responses.lock().unwrap().push(response);
}
fn requests(&self) -> Vec<HttpRequestParams> {
self.requests.lock().unwrap().clone()
}
}
impl HttpClient for TestHttpClient {
fn http_request(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
let requests = Arc::clone(&self.requests);
let responses = Arc::clone(&self.responses);
async move {
requests.lock().unwrap().push(params);
responses.lock().unwrap().remove(0)
}
.boxed()
}
fn http_request_stream(
&self,
_params: HttpRequestParams,
) -> BoxFuture<
'_,
Result<
(
codex_exec_server::HttpRequestResponse,
codex_exec_server::HttpResponseBodyStream,
),
ExecServerError,
>,
> {
async move { Err(ExecServerError::Protocol("unused".to_string())) }.boxed()
}
}
#[tokio::test]
async fn discover_authorization_metadata_uses_http_client() -> Result<()> {
let http_client = TestHttpClient::default();
http_client.push_response(Ok(HttpRequestResponse {
status: StatusCode::OK.as_u16(),
headers: Vec::new(),
body: serde_json::to_vec(&AuthorizationMetadata {
authorization_endpoint: "https://example.com/authorize".to_string(),
token_endpoint: "https://example.com/token".to_string(),
registration_endpoint: None,
issuer: None,
jwks_uri: None,
scopes_supported: None,
response_types_supported: None,
additional_fields: HashMap::new(),
})?
.into(),
}));
let metadata = discover_authorization_metadata(
"https://example.com/mcp",
&HeaderMap::new(),
Arc::new(http_client.clone()),
)
.await?;
assert_eq!(
metadata.authorization_endpoint,
"https://example.com/authorize"
);
assert_eq!(metadata.token_endpoint, "https://example.com/token");
assert_eq!(
http_client.requests()[0].url,
"https://example.com/.well-known/oauth-authorization-server/mcp"
);
Ok(())
}
#[tokio::test]
async fn oauth_http_proxy_forwards_token_requests() -> Result<()> {
let http_client = TestHttpClient::default();
http_client.push_response(Ok(HttpRequestResponse {
status: StatusCode::OK.as_u16(),
headers: vec![HttpHeader {
name: "content-type".to_string(),
value: "application/json".to_string(),
}],
body: br#"{"access_token":"abc","token_type":"Bearer"}"#.to_vec().into(),
}));
let proxy = OAuthHttpProxy::new(
Arc::new(http_client.clone()),
&HeaderMap::new(),
&AuthorizationMetadata {
authorization_endpoint: "https://example.com/authorize".to_string(),
token_endpoint: "https://remote.example.com/token".to_string(),
registration_endpoint: None,
issuer: None,
jwks_uri: None,
scopes_supported: None,
response_types_supported: None,
additional_fields: HashMap::new(),
},
)?;
let response = reqwest::Client::new()
.post(proxy.proxied_url(TOKEN_PROXY_PATH)?)
.header("content-type", "application/x-www-form-urlencoded")
.body("grant_type=refresh_token")
.send()
.await?;
assert_eq!(response.status(), StatusCode::OK);
let requests = http_client.requests();
assert_eq!(requests.len(), 1);
assert_eq!(requests[0].url, "https://remote.example.com/token");
assert_eq!(
String::from_utf8(requests[0].body.clone().unwrap().into_inner())?,
"grant_type=refresh_token"
);
Ok(())
}
}

View File

@@ -12,7 +12,6 @@ use std::time::Instant;
use anyhow::Result;
use anyhow::anyhow;
use codex_api::SharedAuthProvider;
use codex_client::build_reqwest_client_with_custom_ca;
use codex_config::types::McpServerEnvVar;
use codex_exec_server::HttpClient;
use futures::FutureExt;
@@ -47,7 +46,7 @@ use rmcp::service::{self};
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::auth::AuthClient;
use rmcp::transport::auth::AuthError;
use rmcp::transport::auth::OAuthState;
use rmcp::transport::auth::StoredCredentials;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use rmcp::transport::streamable_http_client::StreamableHttpError;
use serde::Deserialize;
@@ -65,10 +64,10 @@ use crate::http_client_adapter::StreamableHttpClientAdapterError;
use crate::load_oauth_tokens;
use crate::oauth::OAuthPersistor;
use crate::oauth::StoredOAuthTokens;
use crate::oauth_http_client::create_oauth_http_setup;
use crate::stdio_server_launcher::StdioServerCommand;
use crate::stdio_server_launcher::StdioServerLauncher;
use crate::stdio_server_launcher::StdioServerTransport;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use codex_config::types::OAuthCredentialsStoreMode;
@@ -946,28 +945,21 @@ async fn create_oauth_transport_and_runtime(
StreamableHttpClientTransport<AuthClient<StreamableHttpClientAdapter>>,
OAuthPersistor,
)> {
let builder = apply_default_headers(reqwest::Client::builder(), &default_headers);
let oauth_metadata_client = build_reqwest_client_with_custom_ca(builder)?;
// TODO(aibrahim): teach OAuth bootstrap and refresh to use the same
// shared HTTP client abstraction instead of always creating the local
// reqwest metadata client here.
let mut oauth_state =
OAuthState::new(url.to_string(), Some(oauth_metadata_client.clone())).await?;
oauth_state
.set_credentials(
&initial_tokens.client_id,
initial_tokens.token_response.0.clone(),
)
.await?;
let manager = match oauth_state {
OAuthState::Authorized(manager) => manager,
OAuthState::Unauthorized(manager) => manager,
OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => {
return Err(anyhow!("unexpected OAuth state during client setup"));
}
};
// `create_oauth_http_setup` returns both the initialized RMCP
// `AuthorizationManager` and the proxy bridge that keeps later token
// refresh requests on the selected `HttpClient`. Keep both pieces together
// here so the transport and persistor share the same OAuth runtime.
let oauth_http_setup = create_oauth_http_setup(
url,
&default_headers,
http_client.clone(),
Some(StoredCredentials {
client_id: initial_tokens.client_id.clone(),
token_response: Some(initial_tokens.token_response.0.clone()),
}),
)
.await?;
let manager = oauth_http_setup.authorization_manager;
let auth_client = AuthClient::new(
StreamableHttpClientAdapter::new(http_client, default_headers, /*auth_provider*/ None),
@@ -984,6 +976,7 @@ async fn create_oauth_transport_and_runtime(
server_name.to_string(),
url.to_string(),
auth_manager,
oauth_http_setup.proxy,
credentials_store,
Some(initial_tokens),
);