mirror of
https://github.com/openai/codex.git
synced 2026-05-02 02:17:22 +00:00
Compare commits
10 Commits
xli-codex/
...
dev/remote
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f128f30011 | ||
|
|
2596065878 | ||
|
|
9713a34772 | ||
|
|
cefc86e851 | ||
|
|
fc8c0aa238 | ||
|
|
f8ebe83fc9 | ||
|
|
b8780250d6 | ||
|
|
e716329890 | ||
|
|
6da1978ff4 | ||
|
|
a92dc1f1b5 |
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -3163,6 +3163,7 @@ name = "codex-rmcp-client"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"bytes",
|
||||
"codex-api",
|
||||
|
||||
@@ -99,6 +99,7 @@ enum McpCallEvent {
|
||||
}
|
||||
|
||||
const REMOTE_MCP_ENVIRONMENT: &str = "remote";
|
||||
const LOCAL_MCP_ENVIRONMENT: &str = "local";
|
||||
|
||||
fn remote_aware_experimental_environment() -> Option<String> {
|
||||
// These tests run locally in normal CI and against the Docker-backed
|
||||
@@ -107,6 +108,10 @@ fn remote_aware_experimental_environment() -> Option<String> {
|
||||
std::env::var_os(remote_env_env_var()).map(|_| REMOTE_MCP_ENVIRONMENT.to_string())
|
||||
}
|
||||
|
||||
fn remote_only_experimental_environment() -> Option<String> {
|
||||
std::env::var_os(remote_env_env_var()).map(|_| REMOTE_MCP_ENVIRONMENT.to_string())
|
||||
}
|
||||
|
||||
/// Returns the stdio MCP test server command path for the active test placement.
|
||||
///
|
||||
/// Local test runs can execute the host-built test binary directly. Remote-aware
|
||||
@@ -1851,6 +1856,53 @@ struct StreamableHttpTestServer {
|
||||
process: StreamableHttpTestServerProcess,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum StreamableHttpTestServerBindMode {
|
||||
HostVisible,
|
||||
RemoteLoopbackOnly,
|
||||
}
|
||||
|
||||
struct StreamableHttpTestServerOptions<'a> {
|
||||
expected_env_value: &'a str,
|
||||
expected_bearer: Option<&'a str>,
|
||||
refreshed_access_token: Option<&'a str>,
|
||||
bind_mode: StreamableHttpTestServerBindMode,
|
||||
}
|
||||
|
||||
impl<'a> StreamableHttpTestServerOptions<'a> {
|
||||
fn host_visible(expected_env_value: &'a str, expected_bearer: Option<&'a str>) -> Self {
|
||||
Self {
|
||||
expected_env_value,
|
||||
expected_bearer,
|
||||
refreshed_access_token: None,
|
||||
bind_mode: StreamableHttpTestServerBindMode::HostVisible,
|
||||
}
|
||||
}
|
||||
|
||||
fn with_refreshed_access_token(mut self, refreshed_access_token: &'a str) -> Self {
|
||||
self.refreshed_access_token = Some(refreshed_access_token);
|
||||
self
|
||||
}
|
||||
|
||||
fn remote_loopback_only(mut self) -> Self {
|
||||
self.bind_mode = StreamableHttpTestServerBindMode::RemoteLoopbackOnly;
|
||||
self
|
||||
}
|
||||
|
||||
fn remote_loopback_only_if_remote(self, placement: OAuthRefreshPlacement) -> Self {
|
||||
match placement {
|
||||
OAuthRefreshPlacement::Local => self,
|
||||
OAuthRefreshPlacement::Remote => self.remote_loopback_only(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum OAuthTokenExpiry {
|
||||
Valid,
|
||||
Expired,
|
||||
}
|
||||
|
||||
/// Tracks whether the Streamable HTTP test server runs on the host or remotely.
|
||||
enum StreamableHttpTestServerProcess {
|
||||
Local(Child),
|
||||
@@ -1964,7 +2016,11 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
// it is a host process.
|
||||
let expected_env_value = "propagated-env-http";
|
||||
let Some(http_server) =
|
||||
start_streamable_http_test_server(expected_env_value, /*expected_token*/ None).await?
|
||||
start_streamable_http_test_server(StreamableHttpTestServerOptions::host_visible(
|
||||
expected_env_value,
|
||||
/*expected_bearer*/ None,
|
||||
))
|
||||
.await?
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
@@ -2150,8 +2206,10 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> {
|
||||
let expected_token = "initial-access-token";
|
||||
let client_id = "test-client-id";
|
||||
let refresh_token = "initial-refresh-token";
|
||||
let Some(http_server) =
|
||||
start_streamable_http_test_server(expected_env_value, Some(expected_token)).await?
|
||||
let Some(http_server) = start_streamable_http_test_server(
|
||||
StreamableHttpTestServerOptions::host_visible(expected_env_value, Some(expected_token)),
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
@@ -2168,6 +2226,7 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> {
|
||||
client_id,
|
||||
expected_token,
|
||||
refresh_token,
|
||||
OAuthTokenExpiry::Valid,
|
||||
)?;
|
||||
|
||||
// Phase 4: configure Codex with the OAuth-backed Streamable HTTP MCP
|
||||
@@ -2287,10 +2346,236 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial(codex_home)]
|
||||
fn streamable_http_with_oauth_refresh_round_trip_local() -> anyhow::Result<()> {
|
||||
run_streamable_http_with_oauth_refresh_round_trip(OAuthRefreshPlacement::Local)
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial(codex_home)]
|
||||
fn streamable_http_with_oauth_refresh_round_trip_remote() -> anyhow::Result<()> {
|
||||
if remote_only_experimental_environment().is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
run_streamable_http_with_oauth_refresh_round_trip(OAuthRefreshPlacement::Remote)
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum OAuthRefreshPlacement {
|
||||
Local,
|
||||
Remote,
|
||||
}
|
||||
|
||||
fn run_streamable_http_with_oauth_refresh_round_trip(
|
||||
placement: OAuthRefreshPlacement,
|
||||
) -> anyhow::Result<()> {
|
||||
const TEST_STACK_SIZE_BYTES: usize = 8 * 1024 * 1024;
|
||||
|
||||
let thread_name = match placement {
|
||||
OAuthRefreshPlacement::Local => "streamable_http_with_oauth_refresh_round_trip_local",
|
||||
OAuthRefreshPlacement::Remote => "streamable_http_with_oauth_refresh_round_trip_remote",
|
||||
};
|
||||
|
||||
let handle = std::thread::Builder::new()
|
||||
.name(thread_name.to_string())
|
||||
.stack_size(TEST_STACK_SIZE_BYTES)
|
||||
.spawn(move || -> anyhow::Result<()> {
|
||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(1)
|
||||
.enable_all()
|
||||
.build()?;
|
||||
runtime.block_on(streamable_http_with_oauth_refresh_round_trip_impl(
|
||||
placement,
|
||||
))
|
||||
})?;
|
||||
|
||||
match handle.join() {
|
||||
Ok(result) => result,
|
||||
Err(_) => Err(anyhow::anyhow!("{thread_name} thread panicked")),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
async fn streamable_http_with_oauth_refresh_round_trip_impl(
|
||||
placement: OAuthRefreshPlacement,
|
||||
) -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
|
||||
let call_id = "call-790";
|
||||
let server_name = match placement {
|
||||
OAuthRefreshPlacement::Local => "rmcp_http_oauth_refresh_local",
|
||||
OAuthRefreshPlacement::Remote => "rmcp_http_oauth_refresh_remote",
|
||||
};
|
||||
let tool_name = format!("mcp__{server_name}__echo");
|
||||
let namespace = format!("mcp__{server_name}__");
|
||||
|
||||
mount_sse_once(
|
||||
&server,
|
||||
responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_function_call_with_namespace(
|
||||
call_id,
|
||||
&namespace,
|
||||
"echo",
|
||||
"{\"message\":\"ping\"}",
|
||||
),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
mount_sse_once(
|
||||
&server,
|
||||
responses::sse(vec![
|
||||
responses::ev_assistant_message(
|
||||
"msg-1",
|
||||
"rmcp streamable http oauth refresh echo tool completed successfully.",
|
||||
),
|
||||
responses::ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let expected_env_value = match placement {
|
||||
OAuthRefreshPlacement::Local => "propagated-env-http-oauth-refresh-local",
|
||||
OAuthRefreshPlacement::Remote => "propagated-env-http-oauth-refresh-remote",
|
||||
};
|
||||
let initial_access_token = "expired-access-token";
|
||||
let refreshed_access_token = match placement {
|
||||
OAuthRefreshPlacement::Local => "refreshed-access-token-local",
|
||||
OAuthRefreshPlacement::Remote => "refreshed-access-token-remote",
|
||||
};
|
||||
let refresh_token = "initial-refresh-token";
|
||||
// The remote case binds the test server to 127.0.0.1 inside the remote
|
||||
// container so the orchestrator cannot reach metadata or token endpoints
|
||||
// directly. If refresh still succeeds, it had to go through the selected
|
||||
// remote `HttpClient`.
|
||||
let Some(http_server) = start_streamable_http_test_server(
|
||||
StreamableHttpTestServerOptions::host_visible(
|
||||
expected_env_value,
|
||||
Some(refreshed_access_token),
|
||||
)
|
||||
.with_refreshed_access_token(refreshed_access_token)
|
||||
.remote_loopback_only_if_remote(placement),
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
let server_url = http_server.url().to_string();
|
||||
|
||||
let temp_home = Arc::new(tempdir()?);
|
||||
let _codex_home_guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str());
|
||||
write_fallback_oauth_tokens(
|
||||
temp_home.path(),
|
||||
server_name,
|
||||
&server_url,
|
||||
"test-client-id",
|
||||
initial_access_token,
|
||||
refresh_token,
|
||||
OAuthTokenExpiry::Expired,
|
||||
)?;
|
||||
|
||||
let experimental_environment = match placement {
|
||||
OAuthRefreshPlacement::Local => Some(LOCAL_MCP_ENVIRONMENT.to_string()),
|
||||
OAuthRefreshPlacement::Remote => Some(REMOTE_MCP_ENVIRONMENT.to_string()),
|
||||
};
|
||||
let fixture = test_codex()
|
||||
.with_home(temp_home.clone())
|
||||
.with_config(move |config| {
|
||||
config.mcp_oauth_credentials_store_mode = serde_json::from_value(json!("file"))
|
||||
.expect("`file` should deserialize as OAuthCredentialsStoreMode");
|
||||
insert_mcp_server(
|
||||
config,
|
||||
server_name,
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url: server_url,
|
||||
bearer_token_env_var: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
},
|
||||
TestMcpServerOptions {
|
||||
experimental_environment,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
})
|
||||
.build_remote_aware(&server)
|
||||
.await?;
|
||||
let session_model = fixture.session_configured.model.clone();
|
||||
|
||||
wait_for_mcp_tool(&fixture, &tool_name).await?;
|
||||
|
||||
fixture
|
||||
.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![UserInput::Text {
|
||||
text: "call the rmcp streamable http oauth echo tool".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: fixture.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
approvals_reviewer: None,
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
permission_profile: None,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: None,
|
||||
service_tier: None,
|
||||
collaboration_mode: None,
|
||||
personality: None,
|
||||
environments: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let end_event = wait_for_event(&fixture.codex, |ev| {
|
||||
matches!(ev, EventMsg::McpToolCallEnd(_))
|
||||
})
|
||||
.await;
|
||||
let EventMsg::McpToolCallEnd(end) = end_event else {
|
||||
unreachable!("event guard guarantees McpToolCallEnd");
|
||||
};
|
||||
|
||||
let result = end
|
||||
.result
|
||||
.as_ref()
|
||||
.expect("rmcp echo tool should return success");
|
||||
assert_eq!(result.is_error, Some(false));
|
||||
let structured = result
|
||||
.structured_content
|
||||
.as_ref()
|
||||
.expect("structured content");
|
||||
let Value::Object(map) = structured else {
|
||||
panic!("structured content should be an object: {structured:?}");
|
||||
};
|
||||
assert_eq!(
|
||||
map.get("echo").and_then(Value::as_str),
|
||||
Some("ECHOING: ping")
|
||||
);
|
||||
assert_eq!(
|
||||
map.get("env").and_then(Value::as_str),
|
||||
Some(expected_env_value)
|
||||
);
|
||||
|
||||
wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
|
||||
server.verify().await;
|
||||
|
||||
assert_eq!(
|
||||
read_fallback_oauth_access_token(temp_home.path(), server_name, http_server.url())?,
|
||||
refreshed_access_token
|
||||
);
|
||||
|
||||
http_server.shutdown().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Starts the Streamable HTTP MCP test server in the active test placement.
|
||||
async fn start_streamable_http_test_server(
|
||||
expected_env_value: &str,
|
||||
expected_token: Option<&str>,
|
||||
options: StreamableHttpTestServerOptions<'_>,
|
||||
) -> anyhow::Result<Option<StreamableHttpTestServer>> {
|
||||
let rmcp_http_server_bin = match cargo_bin("test_streamable_http_server") {
|
||||
Ok(path) => path,
|
||||
@@ -2305,8 +2590,7 @@ async fn start_streamable_http_test_server(
|
||||
start_remote_streamable_http_test_server(
|
||||
&container_name,
|
||||
&rmcp_http_server_bin,
|
||||
expected_env_value,
|
||||
expected_token,
|
||||
options,
|
||||
)
|
||||
.await?,
|
||||
));
|
||||
@@ -2322,10 +2606,13 @@ async fn start_streamable_http_test_server(
|
||||
command
|
||||
.kill_on_drop(true)
|
||||
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
|
||||
.env("MCP_TEST_VALUE", expected_env_value);
|
||||
if let Some(expected_token) = expected_token {
|
||||
.env("MCP_TEST_VALUE", options.expected_env_value);
|
||||
if let Some(expected_token) = options.expected_bearer {
|
||||
command.env("MCP_EXPECT_BEARER", expected_token);
|
||||
}
|
||||
if let Some(refreshed_access_token) = options.refreshed_access_token {
|
||||
command.env("MCP_REFRESHED_ACCESS_TOKEN", refreshed_access_token);
|
||||
}
|
||||
let mut child = command.spawn()?;
|
||||
|
||||
wait_for_local_streamable_http_server(&mut child, &server_url, Duration::from_secs(5)).await?;
|
||||
@@ -2339,8 +2626,7 @@ async fn start_streamable_http_test_server(
|
||||
async fn start_remote_streamable_http_test_server(
|
||||
container_name: &str,
|
||||
rmcp_http_server_bin: &Path,
|
||||
expected_env_value: &str,
|
||||
expected_token: Option<&str>,
|
||||
options: StreamableHttpTestServerOptions<'_>,
|
||||
) -> anyhow::Result<StreamableHttpTestServer> {
|
||||
let remote_path = copy_binary_to_remote_env(
|
||||
container_name,
|
||||
@@ -2352,20 +2638,32 @@ async fn start_remote_streamable_http_test_server(
|
||||
let mut env_assignments = vec![
|
||||
format!(
|
||||
"MCP_STREAMABLE_HTTP_BIND_ADDR={}",
|
||||
sh_single_quote("0.0.0.0:0")
|
||||
sh_single_quote(match options.bind_mode {
|
||||
StreamableHttpTestServerBindMode::HostVisible => "0.0.0.0:0",
|
||||
StreamableHttpTestServerBindMode::RemoteLoopbackOnly => "127.0.0.1:0",
|
||||
})
|
||||
),
|
||||
format!(
|
||||
"MCP_STREAMABLE_HTTP_BOUND_ADDR_FILE={}",
|
||||
sh_single_quote(&bound_addr_file)
|
||||
),
|
||||
format!("MCP_TEST_VALUE={}", sh_single_quote(expected_env_value)),
|
||||
format!(
|
||||
"MCP_TEST_VALUE={}",
|
||||
sh_single_quote(options.expected_env_value)
|
||||
),
|
||||
];
|
||||
if let Some(expected_token) = expected_token {
|
||||
if let Some(expected_token) = options.expected_bearer {
|
||||
env_assignments.push(format!(
|
||||
"MCP_EXPECT_BEARER={}",
|
||||
sh_single_quote(expected_token)
|
||||
));
|
||||
}
|
||||
if let Some(refreshed_access_token) = options.refreshed_access_token {
|
||||
env_assignments.push(format!(
|
||||
"MCP_REFRESHED_ACCESS_TOKEN={}",
|
||||
sh_single_quote(refreshed_access_token)
|
||||
));
|
||||
}
|
||||
|
||||
let script = format!(
|
||||
"{} nohup {} > {} 2>&1 < /dev/null & echo $!",
|
||||
@@ -2395,13 +2693,21 @@ async fn start_remote_streamable_http_test_server(
|
||||
let remote_bind_addr =
|
||||
wait_for_remote_bound_addr(container_name, &bound_addr_file, Duration::from_secs(5))
|
||||
.await?;
|
||||
let container_ip = remote_container_ip(container_name)?;
|
||||
let server_url = format!("http://{}:{}/mcp", container_ip, remote_bind_addr.port());
|
||||
let server_host = match options.bind_mode {
|
||||
StreamableHttpTestServerBindMode::HostVisible => remote_container_ip(container_name)?,
|
||||
StreamableHttpTestServerBindMode::RemoteLoopbackOnly => "127.0.0.1".to_string(),
|
||||
};
|
||||
let server_url = format!("http://{}:{}/mcp", server_host, remote_bind_addr.port());
|
||||
// The orchestrator can see the Docker container IP, but the behavior under
|
||||
// test is whether the remote-side MCP client can reach it. Probe through
|
||||
// remote HTTP before handing the URL to the Codex fixture.
|
||||
wait_for_remote_streamable_http_server(&server_url, Duration::from_secs(5)).await?;
|
||||
if expected_token.is_some() {
|
||||
if options.expected_bearer.is_some()
|
||||
&& matches!(
|
||||
options.bind_mode,
|
||||
StreamableHttpTestServerBindMode::HostVisible
|
||||
)
|
||||
{
|
||||
wait_for_streamable_http_metadata(&server_url, Duration::from_secs(5)).await?;
|
||||
}
|
||||
|
||||
@@ -2643,12 +2949,18 @@ fn write_fallback_oauth_tokens(
|
||||
client_id: &str,
|
||||
access_token: &str,
|
||||
refresh_token: &str,
|
||||
expiry: OAuthTokenExpiry,
|
||||
) -> anyhow::Result<()> {
|
||||
let expires_at = SystemTime::now()
|
||||
.checked_add(Duration::from_secs(3600))
|
||||
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?
|
||||
.duration_since(UNIX_EPOCH)?
|
||||
.as_millis() as u64;
|
||||
let expires_at = match expiry {
|
||||
OAuthTokenExpiry::Valid => SystemTime::now()
|
||||
.checked_add(Duration::from_secs(3600))
|
||||
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?,
|
||||
OAuthTokenExpiry::Expired => SystemTime::now()
|
||||
.checked_sub(Duration::from_secs(3600))
|
||||
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?,
|
||||
}
|
||||
.duration_since(UNIX_EPOCH)?
|
||||
.as_millis() as u64;
|
||||
|
||||
let store = serde_json::json!({
|
||||
"stub": {
|
||||
@@ -2667,6 +2979,33 @@ fn write_fallback_oauth_tokens(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_fallback_oauth_access_token(
|
||||
home: &Path,
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let file_path = home.join(".credentials.json");
|
||||
let store: Value = serde_json::from_slice(&fs::read(&file_path)?)?;
|
||||
store
|
||||
.as_object()
|
||||
.into_iter()
|
||||
.flat_map(|store| store.values())
|
||||
.find(|entry| {
|
||||
entry
|
||||
.get("server_name")
|
||||
.and_then(Value::as_str)
|
||||
.is_some_and(|value| value == server_name)
|
||||
&& entry
|
||||
.get("server_url")
|
||||
.and_then(Value::as_str)
|
||||
.is_some_and(|value| value == server_url)
|
||||
})
|
||||
.and_then(|entry| entry.get("access_token"))
|
||||
.and_then(Value::as_str)
|
||||
.map(ToOwned::to_owned)
|
||||
.ok_or_else(|| anyhow::anyhow!("missing fallback OAuth access token"))
|
||||
}
|
||||
|
||||
struct EnvVarGuard {
|
||||
key: &'static str,
|
||||
original: Option<OsString>,
|
||||
|
||||
@@ -9,6 +9,7 @@ workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
async-trait = { workspace = true }
|
||||
axum = { workspace = true, default-features = false, features = [
|
||||
"http1",
|
||||
"tokio",
|
||||
|
||||
@@ -62,6 +62,7 @@ const MEMO_URI: &str = "memo://codex/example-note";
|
||||
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
|
||||
const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
|
||||
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
|
||||
const OAUTH_TOKEN_PATH: &str = "/oauth/token";
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct SessionFailureState {
|
||||
@@ -116,6 +117,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
fs::write(bound_addr_file, actual_bind_addr.to_string())?;
|
||||
}
|
||||
eprintln!("starting rmcp streamable http test server on http://{actual_bind_addr}/mcp");
|
||||
let refreshed_access_token = std::env::var("MCP_REFRESHED_ACCESS_TOKEN")
|
||||
.unwrap_or_else(|_| "refreshed-access-token".to_string());
|
||||
|
||||
let router = Router::new()
|
||||
.route(
|
||||
@@ -146,6 +149,28 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
OAUTH_TOKEN_PATH,
|
||||
post({
|
||||
move || async move {
|
||||
let refreshed_access_token = refreshed_access_token.clone();
|
||||
#[expect(clippy::expect_used)]
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_vec(&json!({
|
||||
"access_token": refreshed_access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600_u64,
|
||||
"refresh_token": "refreshed-refresh-token",
|
||||
}))
|
||||
.expect("failed to serialize token response"),
|
||||
))
|
||||
.expect("valid token response")
|
||||
}
|
||||
}),
|
||||
)
|
||||
.nest_service(
|
||||
"/mcp",
|
||||
StreamableHttpService::new(
|
||||
@@ -386,7 +411,8 @@ async fn require_bearer(
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if request.uri().path().contains("/.well-known/") {
|
||||
if request.uri().path().contains("/.well-known/") || request.uri().path().starts_with("/oauth/")
|
||||
{
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
if request
|
||||
|
||||
@@ -4,6 +4,7 @@ mod executor_process_transport;
|
||||
mod http_client_adapter;
|
||||
mod logging_client_handler;
|
||||
mod oauth;
|
||||
mod oauth_http_client;
|
||||
mod perform_oauth_login;
|
||||
mod program_resolver;
|
||||
mod rmcp_client;
|
||||
|
||||
@@ -48,6 +48,7 @@ use codex_keyring_store::KeyringStore;
|
||||
use rmcp::transport::auth::AuthorizationManager;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::oauth_http_client::OAuthHttpProxy;
|
||||
use codex_utils_home_dir::find_codex_home;
|
||||
|
||||
const KEYRING_SERVICE: &str = "Codex MCP Credentials";
|
||||
@@ -258,6 +259,7 @@ struct OAuthPersistorInner {
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
_oauth_http_proxy: Option<OAuthHttpProxy>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
last_credentials: Mutex<Option<StoredOAuthTokens>>,
|
||||
}
|
||||
@@ -267,6 +269,7 @@ impl OAuthPersistor {
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
oauth_http_proxy: Option<OAuthHttpProxy>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
initial_credentials: Option<StoredOAuthTokens>,
|
||||
) -> Self {
|
||||
@@ -275,6 +278,7 @@ impl OAuthPersistor {
|
||||
server_name,
|
||||
url,
|
||||
authorization_manager,
|
||||
_oauth_http_proxy: oauth_http_proxy,
|
||||
store_mode,
|
||||
last_credentials: Mutex::new(initial_credentials),
|
||||
}),
|
||||
|
||||
483
codex-rs/rmcp-client/src/oauth_http_client.rs
Normal file
483
codex-rs/rmcp-client/src/oauth_http_client.rs
Normal file
@@ -0,0 +1,483 @@
|
||||
//! OAuth bootstrap helpers that route non-browser HTTP through the shared
|
||||
//! `HttpClient` capability.
|
||||
//!
|
||||
//! The browser-facing part of MCP OAuth stays on the orchestrator, but the
|
||||
//! HTTP requests used for discovery, registration, token exchange, and refresh
|
||||
//! need to follow the selected MCP placement. This module discovers OAuth
|
||||
//! metadata through `HttpClient` and exposes token/registration endpoints
|
||||
//! through a small localhost proxy so RMCP's OAuth manager can keep using its
|
||||
//! existing `reqwest`-based internals.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::thread;
|
||||
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_exec_server::HttpClient;
|
||||
use codex_exec_server::HttpHeader;
|
||||
use codex_exec_server::HttpRequestParams;
|
||||
use reqwest::Method;
|
||||
use reqwest::StatusCode;
|
||||
use reqwest::Url;
|
||||
use reqwest::header::HeaderMap;
|
||||
use rmcp::transport::auth::AuthError;
|
||||
use rmcp::transport::auth::AuthorizationManager;
|
||||
use rmcp::transport::auth::AuthorizationMetadata;
|
||||
use rmcp::transport::auth::CredentialStore;
|
||||
use rmcp::transport::auth::StoredCredentials;
|
||||
use tiny_http::Header;
|
||||
use tiny_http::Response;
|
||||
use tiny_http::Server;
|
||||
|
||||
const DISCOVERY_TIMEOUT_MS: u64 = 30_000;
|
||||
const MCP_PROTOCOL_VERSION_HEADER: &str = "MCP-Protocol-Version";
|
||||
const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
|
||||
const TOKEN_PROXY_PATH: &str = "/oauth/token";
|
||||
const REGISTRATION_PROXY_PATH: &str = "/oauth/register";
|
||||
|
||||
/// The OAuth manager plus any bridge state needed to keep its non-browser
|
||||
/// HTTP requests on the selected `HttpClient`.
|
||||
pub(crate) struct OAuthHttpSetup {
|
||||
pub(crate) authorization_manager: AuthorizationManager,
|
||||
pub(crate) proxy: Option<OAuthHttpProxy>,
|
||||
}
|
||||
|
||||
/// Builds an OAuth manager whose metadata discovery and token flows follow the
|
||||
/// selected `HttpClient`, while still leaving browser launch and callback
|
||||
/// handling on the orchestrator.
|
||||
pub(crate) async fn create_oauth_http_setup(
|
||||
url: &str,
|
||||
default_headers: &HeaderMap,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
initial_credentials: Option<StoredCredentials>,
|
||||
) -> Result<OAuthHttpSetup> {
|
||||
let metadata = discover_authorization_metadata(url, default_headers, Arc::clone(&http_client))
|
||||
.await
|
||||
.map_err(|error| anyhow!(error))?;
|
||||
// RMCP's OAuth manager still issues token and registration requests through
|
||||
// reqwest. Rewrite just those endpoints through a tiny local bridge so the
|
||||
// actual network hop still follows the selected `HttpClient`.
|
||||
let proxy = OAuthHttpProxy::new(Arc::clone(&http_client), default_headers, &metadata)?;
|
||||
|
||||
let mut authorization_manager = AuthorizationManager::new(url)
|
||||
.await
|
||||
.map_err(|error| anyhow!(error))?;
|
||||
authorization_manager.set_metadata(proxy.rewrite_metadata(&metadata)?);
|
||||
|
||||
if let Some(initial_credentials) = initial_credentials {
|
||||
let client_id = initial_credentials.client_id.clone();
|
||||
authorization_manager.set_credential_store(StaticCredentialStore::new(initial_credentials));
|
||||
authorization_manager
|
||||
.configure_client_id(&client_id)
|
||||
.map_err(|error| anyhow!(error))?;
|
||||
}
|
||||
|
||||
Ok(OAuthHttpSetup {
|
||||
authorization_manager,
|
||||
proxy: Some(proxy),
|
||||
})
|
||||
}
|
||||
|
||||
async fn discover_authorization_metadata(
|
||||
url: &str,
|
||||
default_headers: &HeaderMap,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> Result<AuthorizationMetadata, AuthError> {
|
||||
let base_url = Url::parse(url).map_err(|error| AuthError::OAuthError(error.to_string()))?;
|
||||
|
||||
for discovery_url in generate_discovery_urls(&base_url) {
|
||||
let response = http_client
|
||||
.http_request(HttpRequestParams {
|
||||
method: Method::GET.as_str().to_string(),
|
||||
url: discovery_url.to_string(),
|
||||
headers: discovery_headers(default_headers)?,
|
||||
body: None,
|
||||
timeout_ms: Some(DISCOVERY_TIMEOUT_MS),
|
||||
request_id: next_request_id("oauth-discovery"),
|
||||
stream_response: false,
|
||||
})
|
||||
.await
|
||||
.map_err(|error| AuthError::OAuthError(error.to_string()))?;
|
||||
|
||||
if response.status != StatusCode::OK.as_u16() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let metadata = serde_json::from_slice::<AuthorizationMetadata>(&response.body.into_inner())
|
||||
.map_err(|error| AuthError::OAuthError(error.to_string()))?;
|
||||
return Ok(metadata);
|
||||
}
|
||||
|
||||
Err(AuthError::NoAuthorizationSupport)
|
||||
}
|
||||
|
||||
fn generate_discovery_urls(base_url: &Url) -> Vec<Url> {
|
||||
let mut candidates = Vec::new();
|
||||
let trimmed = base_url
|
||||
.path()
|
||||
.trim_start_matches('/')
|
||||
.trim_end_matches('/');
|
||||
let mut push_candidate = |discovery_path: String| {
|
||||
let mut discovery_url = base_url.clone();
|
||||
discovery_url.set_query(None);
|
||||
discovery_url.set_fragment(None);
|
||||
discovery_url.set_path(&discovery_path);
|
||||
candidates.push(discovery_url);
|
||||
};
|
||||
|
||||
if trimmed.is_empty() {
|
||||
push_candidate("/.well-known/oauth-authorization-server".to_string());
|
||||
push_candidate("/.well-known/openid-configuration".to_string());
|
||||
} else {
|
||||
push_candidate(format!("/.well-known/oauth-authorization-server/{trimmed}"));
|
||||
push_candidate(format!("/.well-known/openid-configuration/{trimmed}"));
|
||||
push_candidate(format!("/{trimmed}/.well-known/openid-configuration"));
|
||||
push_candidate("/.well-known/oauth-authorization-server".to_string());
|
||||
}
|
||||
|
||||
candidates
|
||||
}
|
||||
|
||||
fn discovery_headers(default_headers: &HeaderMap) -> Result<Vec<HttpHeader>, AuthError> {
|
||||
let mut headers = protocol_headers(default_headers);
|
||||
headers.push(HttpHeader {
|
||||
name: MCP_PROTOCOL_VERSION_HEADER.to_string(),
|
||||
value: MCP_PROTOCOL_VERSION.to_string(),
|
||||
});
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// A tiny localhost bridge that lets RMCP keep using its internal reqwest
|
||||
/// OAuth flows while the real network requests still go through `HttpClient`.
|
||||
pub(crate) struct OAuthHttpProxy {
|
||||
server: Arc<Server>,
|
||||
base_url: String,
|
||||
routes: Arc<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl OAuthHttpProxy {
|
||||
fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
default_headers: &HeaderMap,
|
||||
metadata: &AuthorizationMetadata,
|
||||
) -> Result<Self> {
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|error| anyhow!(error))?);
|
||||
let address = match server.server_addr() {
|
||||
tiny_http::ListenAddr::IP(address) => address,
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
_ => return Err(anyhow!("unable to determine OAuth HTTP proxy bind address")),
|
||||
};
|
||||
let base_url = format!("http://{address}");
|
||||
|
||||
let mut routes = HashMap::new();
|
||||
routes.insert(
|
||||
TOKEN_PROXY_PATH.to_string(),
|
||||
metadata.token_endpoint.clone(),
|
||||
);
|
||||
if let Some(registration_endpoint) = metadata.registration_endpoint.clone() {
|
||||
routes.insert(REGISTRATION_PROXY_PATH.to_string(), registration_endpoint);
|
||||
}
|
||||
let routes = Arc::new(routes);
|
||||
|
||||
let server_for_thread = Arc::clone(&server);
|
||||
let routes_for_thread = Arc::clone(&routes);
|
||||
let default_headers = default_headers.clone();
|
||||
let runtime = tokio::runtime::Handle::current();
|
||||
thread::spawn(move || {
|
||||
while let Ok(mut request) = server_for_thread.recv() {
|
||||
let request_url = request.url().to_string();
|
||||
let route_key = request_url
|
||||
.split_once('?')
|
||||
.map(|(path, _)| path)
|
||||
.unwrap_or(request_url.as_str())
|
||||
.to_string();
|
||||
let Some(target_url) = routes_for_thread.get(&route_key).cloned() else {
|
||||
let _ = request.respond(Response::empty(StatusCode::NOT_FOUND.as_u16()));
|
||||
continue;
|
||||
};
|
||||
|
||||
let mut target = target_url;
|
||||
if let Some((_, query)) = request_url.split_once('?') {
|
||||
target.push('?');
|
||||
target.push_str(query);
|
||||
}
|
||||
|
||||
let method = request.method().as_str().to_string();
|
||||
let mut body = Vec::new();
|
||||
if request.as_reader().read_to_end(&mut body).is_err() {
|
||||
let _ = request.respond(Response::empty(StatusCode::BAD_REQUEST.as_u16()));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Preserve the original request shape as much as we can, but
|
||||
// let the selected `HttpClient` own the outbound HTTP call.
|
||||
let mut headers = protocol_headers(&default_headers);
|
||||
for header in request.headers() {
|
||||
let field = header.field.as_str().to_string();
|
||||
if field.eq_ignore_ascii_case("host")
|
||||
|| field.eq_ignore_ascii_case("content-length")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
headers.push(HttpHeader {
|
||||
name: field,
|
||||
value: header.value.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let response = runtime.block_on(http_client.http_request(HttpRequestParams {
|
||||
method,
|
||||
url: target,
|
||||
headers,
|
||||
body: Some(body.into()),
|
||||
timeout_ms: Some(DISCOVERY_TIMEOUT_MS),
|
||||
request_id: next_request_id("oauth-proxy"),
|
||||
stream_response: false,
|
||||
}));
|
||||
|
||||
let response = match response {
|
||||
Ok(response) => response,
|
||||
Err(_) => {
|
||||
let _ = request.respond(Response::empty(StatusCode::BAD_GATEWAY.as_u16()));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let mut proxy_response = Response::from_data(response.body.into_inner())
|
||||
.with_status_code(response.status);
|
||||
for header in response.headers {
|
||||
if let Ok(header) = Header::from_bytes(header.name.as_bytes(), header.value) {
|
||||
proxy_response.add_header(header);
|
||||
}
|
||||
}
|
||||
let _ = request.respond(proxy_response);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
server,
|
||||
base_url,
|
||||
routes,
|
||||
})
|
||||
}
|
||||
|
||||
fn rewrite_metadata(&self, metadata: &AuthorizationMetadata) -> Result<AuthorizationMetadata> {
|
||||
let mut rewritten = metadata.clone();
|
||||
rewritten.token_endpoint = self.proxied_url(TOKEN_PROXY_PATH)?;
|
||||
rewritten.registration_endpoint = if self.routes.contains_key(REGISTRATION_PROXY_PATH) {
|
||||
Some(self.proxied_url(REGISTRATION_PROXY_PATH)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(rewritten)
|
||||
}
|
||||
|
||||
fn proxied_url(&self, path: &str) -> Result<String> {
|
||||
let mut url = Url::parse(&self.base_url).map_err(|error| anyhow!(error))?;
|
||||
url.set_path(path);
|
||||
Ok(url.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for OAuthHttpProxy {
|
||||
fn drop(&mut self) {
|
||||
// Wake the background thread so it can notice shutdown and exit.
|
||||
self.server.unblock();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StaticCredentialStore {
|
||||
stored: Arc<tokio::sync::Mutex<Option<StoredCredentials>>>,
|
||||
}
|
||||
|
||||
impl StaticCredentialStore {
|
||||
fn new(credentials: StoredCredentials) -> Self {
|
||||
Self {
|
||||
stored: Arc::new(tokio::sync::Mutex::new(Some(credentials))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl CredentialStore for StaticCredentialStore {
|
||||
async fn load(&self) -> Result<Option<StoredCredentials>, AuthError> {
|
||||
Ok(self.stored.lock().await.clone())
|
||||
}
|
||||
|
||||
async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError> {
|
||||
*self.stored.lock().await = Some(credentials);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn clear(&self) -> Result<(), AuthError> {
|
||||
*self.stored.lock().await = None;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn protocol_headers(headers: &HeaderMap) -> Vec<HttpHeader> {
|
||||
headers
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
Some(HttpHeader {
|
||||
name: name.as_str().to_string(),
|
||||
value: value.to_str().ok()?.to_string(),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn next_request_id(prefix: &str) -> String {
|
||||
static NEXT_REQUEST_ID: AtomicU64 = AtomicU64::new(1);
|
||||
let id = NEXT_REQUEST_ID.fetch_add(1, Ordering::Relaxed);
|
||||
format!("{prefix}-{id}")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Mutex;
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_exec_server::ExecServerError;
|
||||
use codex_exec_server::HttpRequestResponse;
|
||||
use futures::FutureExt;
|
||||
use futures::future::BoxFuture;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct TestHttpClient {
|
||||
requests: Arc<Mutex<Vec<HttpRequestParams>>>,
|
||||
responses: Arc<Mutex<Vec<Result<HttpRequestResponse, ExecServerError>>>>,
|
||||
}
|
||||
|
||||
impl TestHttpClient {
|
||||
fn push_response(&self, response: Result<HttpRequestResponse, ExecServerError>) {
|
||||
self.responses.lock().unwrap().push(response);
|
||||
}
|
||||
|
||||
fn requests(&self) -> Vec<HttpRequestParams> {
|
||||
self.requests.lock().unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient for TestHttpClient {
|
||||
fn http_request(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
|
||||
let requests = Arc::clone(&self.requests);
|
||||
let responses = Arc::clone(&self.responses);
|
||||
async move {
|
||||
requests.lock().unwrap().push(params);
|
||||
responses.lock().unwrap().remove(0)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn http_request_stream(
|
||||
&self,
|
||||
_params: HttpRequestParams,
|
||||
) -> BoxFuture<
|
||||
'_,
|
||||
Result<
|
||||
(
|
||||
codex_exec_server::HttpRequestResponse,
|
||||
codex_exec_server::HttpResponseBodyStream,
|
||||
),
|
||||
ExecServerError,
|
||||
>,
|
||||
> {
|
||||
async move { Err(ExecServerError::Protocol("unused".to_string())) }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn discover_authorization_metadata_uses_http_client() -> Result<()> {
|
||||
let http_client = TestHttpClient::default();
|
||||
http_client.push_response(Ok(HttpRequestResponse {
|
||||
status: StatusCode::OK.as_u16(),
|
||||
headers: Vec::new(),
|
||||
body: serde_json::to_vec(&AuthorizationMetadata {
|
||||
authorization_endpoint: "https://example.com/authorize".to_string(),
|
||||
token_endpoint: "https://example.com/token".to_string(),
|
||||
registration_endpoint: None,
|
||||
issuer: None,
|
||||
jwks_uri: None,
|
||||
scopes_supported: None,
|
||||
response_types_supported: None,
|
||||
additional_fields: HashMap::new(),
|
||||
})?
|
||||
.into(),
|
||||
}));
|
||||
|
||||
let metadata = discover_authorization_metadata(
|
||||
"https://example.com/mcp",
|
||||
&HeaderMap::new(),
|
||||
Arc::new(http_client.clone()),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(
|
||||
metadata.authorization_endpoint,
|
||||
"https://example.com/authorize"
|
||||
);
|
||||
assert_eq!(metadata.token_endpoint, "https://example.com/token");
|
||||
assert_eq!(
|
||||
http_client.requests()[0].url,
|
||||
"https://example.com/.well-known/oauth-authorization-server/mcp"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn oauth_http_proxy_forwards_token_requests() -> Result<()> {
|
||||
let http_client = TestHttpClient::default();
|
||||
http_client.push_response(Ok(HttpRequestResponse {
|
||||
status: StatusCode::OK.as_u16(),
|
||||
headers: vec![HttpHeader {
|
||||
name: "content-type".to_string(),
|
||||
value: "application/json".to_string(),
|
||||
}],
|
||||
body: br#"{"access_token":"abc","token_type":"Bearer"}"#.to_vec().into(),
|
||||
}));
|
||||
|
||||
let proxy = OAuthHttpProxy::new(
|
||||
Arc::new(http_client.clone()),
|
||||
&HeaderMap::new(),
|
||||
&AuthorizationMetadata {
|
||||
authorization_endpoint: "https://example.com/authorize".to_string(),
|
||||
token_endpoint: "https://remote.example.com/token".to_string(),
|
||||
registration_endpoint: None,
|
||||
issuer: None,
|
||||
jwks_uri: None,
|
||||
scopes_supported: None,
|
||||
response_types_supported: None,
|
||||
additional_fields: HashMap::new(),
|
||||
},
|
||||
)?;
|
||||
|
||||
let response = reqwest::Client::new()
|
||||
.post(proxy.proxied_url(TOKEN_PROXY_PATH)?)
|
||||
.header("content-type", "application/x-www-form-urlencoded")
|
||||
.body("grant_type=refresh_token")
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let requests = http_client.requests();
|
||||
assert_eq!(requests.len(), 1);
|
||||
assert_eq!(requests[0].url, "https://remote.example.com/token");
|
||||
assert_eq!(
|
||||
String::from_utf8(requests[0].body.clone().unwrap().into_inner())?,
|
||||
"grant_type=refresh_token"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,6 @@ use std::time::Instant;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_api::SharedAuthProvider;
|
||||
use codex_client::build_reqwest_client_with_custom_ca;
|
||||
use codex_config::types::McpServerEnvVar;
|
||||
use codex_exec_server::HttpClient;
|
||||
use futures::FutureExt;
|
||||
@@ -47,7 +46,7 @@ use rmcp::service::{self};
|
||||
use rmcp::transport::StreamableHttpClientTransport;
|
||||
use rmcp::transport::auth::AuthClient;
|
||||
use rmcp::transport::auth::AuthError;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use rmcp::transport::auth::StoredCredentials;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpError;
|
||||
use serde::Deserialize;
|
||||
@@ -65,10 +64,10 @@ use crate::http_client_adapter::StreamableHttpClientAdapterError;
|
||||
use crate::load_oauth_tokens;
|
||||
use crate::oauth::OAuthPersistor;
|
||||
use crate::oauth::StoredOAuthTokens;
|
||||
use crate::oauth_http_client::create_oauth_http_setup;
|
||||
use crate::stdio_server_launcher::StdioServerCommand;
|
||||
use crate::stdio_server_launcher::StdioServerLauncher;
|
||||
use crate::stdio_server_launcher::StdioServerTransport;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
use codex_config::types::OAuthCredentialsStoreMode;
|
||||
|
||||
@@ -946,28 +945,21 @@ async fn create_oauth_transport_and_runtime(
|
||||
StreamableHttpClientTransport<AuthClient<StreamableHttpClientAdapter>>,
|
||||
OAuthPersistor,
|
||||
)> {
|
||||
let builder = apply_default_headers(reqwest::Client::builder(), &default_headers);
|
||||
let oauth_metadata_client = build_reqwest_client_with_custom_ca(builder)?;
|
||||
// TODO(aibrahim): teach OAuth bootstrap and refresh to use the same
|
||||
// shared HTTP client abstraction instead of always creating the local
|
||||
// reqwest metadata client here.
|
||||
let mut oauth_state =
|
||||
OAuthState::new(url.to_string(), Some(oauth_metadata_client.clone())).await?;
|
||||
|
||||
oauth_state
|
||||
.set_credentials(
|
||||
&initial_tokens.client_id,
|
||||
initial_tokens.token_response.0.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let manager = match oauth_state {
|
||||
OAuthState::Authorized(manager) => manager,
|
||||
OAuthState::Unauthorized(manager) => manager,
|
||||
OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => {
|
||||
return Err(anyhow!("unexpected OAuth state during client setup"));
|
||||
}
|
||||
};
|
||||
// `create_oauth_http_setup` returns both the initialized RMCP
|
||||
// `AuthorizationManager` and the proxy bridge that keeps later token
|
||||
// refresh requests on the selected `HttpClient`. Keep both pieces together
|
||||
// here so the transport and persistor share the same OAuth runtime.
|
||||
let oauth_http_setup = create_oauth_http_setup(
|
||||
url,
|
||||
&default_headers,
|
||||
http_client.clone(),
|
||||
Some(StoredCredentials {
|
||||
client_id: initial_tokens.client_id.clone(),
|
||||
token_response: Some(initial_tokens.token_response.0.clone()),
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let manager = oauth_http_setup.authorization_manager;
|
||||
|
||||
let auth_client = AuthClient::new(
|
||||
StreamableHttpClientAdapter::new(http_client, default_headers, /*auth_provider*/ None),
|
||||
@@ -984,6 +976,7 @@ async fn create_oauth_transport_and_runtime(
|
||||
server_name.to_string(),
|
||||
url.to_string(),
|
||||
auth_manager,
|
||||
oauth_http_setup.proxy,
|
||||
credentials_store,
|
||||
Some(initial_tokens),
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user