Compare commits

...

1 Commits

Author SHA1 Message Date
Steven Lee
9a64fd84e5 Add callback ids to local MCP OAuth redirects 2026-04-29 22:39:58 -07:00
3 changed files with 97 additions and 0 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -3263,6 +3263,7 @@ version = "0.0.0"
dependencies = [
"anyhow",
"axum",
"base64 0.22.1",
"bytes",
"codex-api",
"codex-client",

View File

@@ -13,6 +13,7 @@ axum = { workspace = true, default-features = false, features = [
"http1",
"tokio",
] }
base64 = { workspace = true }
codex-api = { workspace = true }
codex-client = { workspace = true }
codex-config = { workspace = true }

View File

@@ -7,9 +7,13 @@ use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use reqwest::ClientBuilder;
use reqwest::Url;
use rmcp::transport::auth::OAuthState;
use sha2::Digest;
use sha2::Sha256;
use tiny_http::Response;
use tiny_http::Server;
use tokio::sync::oneshot;
@@ -378,6 +382,31 @@ fn resolve_redirect_uri(server: &Server, callback_url: Option<&str>) -> Result<S
Ok(callback_url.to_string())
}
fn callback_id_from_server_url(server_url: &str) -> Result<String> {
let mut parsed =
Url::parse(server_url).with_context(|| format!("invalid MCP server URL `{server_url}`"))?;
parsed
.host_str()
.ok_or_else(|| anyhow!("MCP server URL `{server_url}` must include a host"))?;
parsed.set_fragment(None);
let digest = Sha256::digest(parsed.as_str().as_bytes());
Ok(URL_SAFE_NO_PAD.encode(&digest[..9]))
}
fn append_callback_id_to_redirect_uri(redirect_uri: &str, callback_id: &str) -> Result<String> {
let mut parsed = Url::parse(redirect_uri)
.with_context(|| format!("invalid redirect URI `{redirect_uri}`"))?;
let path = parsed.path();
let new_path = if path.ends_with('/') {
format!("{path}{callback_id}")
} else {
format!("{path}/{callback_id}")
};
parsed.set_path(&new_path);
Ok(parsed.to_string())
}
fn callback_path_from_redirect_uri(redirect_uri: &str) -> Result<String> {
let parsed = Url::parse(redirect_uri)
.with_context(|| format!("invalid redirect URI `{redirect_uri}`"))?;
@@ -428,6 +457,8 @@ impl OauthLoginFlow {
};
let redirect_uri = resolve_redirect_uri(&server, callback_url)?;
let callback_id = callback_id_from_server_url(server_url)?;
let redirect_uri = append_callback_id_to_redirect_uri(&redirect_uri, &callback_id)?;
let callback_path = callback_path_from_redirect_uri(&redirect_uri)?;
let (tx, rx) = oneshot::channel();
@@ -577,7 +608,9 @@ mod tests {
use super::CallbackOutcome;
use super::OAuthProviderError;
use super::append_callback_id_to_redirect_uri;
use super::append_query_param;
use super::callback_id_from_server_url;
use super::callback_path_from_redirect_uri;
use super::parse_oauth_callback;
@@ -593,6 +626,19 @@ mod tests {
assert!(matches!(parsed, CallbackOutcome::Success(_)));
}
#[test]
fn parse_oauth_callback_accepts_callback_id_path() {
let parsed =
parse_oauth_callback("/callback/abc123?code=abc&state=xyz", "/callback/abc123");
assert!(matches!(parsed, CallbackOutcome::Success(_)));
}
#[test]
fn parse_oauth_callback_rejects_missing_callback_id_path() {
let parsed = parse_oauth_callback("/callback?code=abc&state=xyz", "/callback/abc123");
assert!(matches!(parsed, CallbackOutcome::Invalid));
}
#[test]
fn parse_oauth_callback_rejects_wrong_path() {
let parsed = parse_oauth_callback("/callback?code=abc&state=xyz", "/oauth/callback");
@@ -622,6 +668,55 @@ mod tests {
assert_eq!(path, "/oauth/callback");
}
#[test]
fn callback_id_is_bound_to_server_url() {
let callback_id = callback_id_from_server_url("https://mcp.example.com/mcp?tenant=one")
.expect("server URL should parse");
let same_without_fragment =
callback_id_from_server_url("https://mcp.example.com/mcp?tenant=one#unused")
.expect("server URL should parse");
let different_path = callback_id_from_server_url("https://mcp.example.com/sse?tenant=one")
.expect("server URL should parse");
let different_query = callback_id_from_server_url("https://mcp.example.com/mcp?tenant=two")
.expect("server URL should parse");
let different_origin = callback_id_from_server_url("https://mcp.example.com:8443/mcp")
.expect("server URL should parse");
assert_eq!(callback_id, same_without_fragment);
assert_ne!(callback_id, different_path);
assert_ne!(callback_id, different_query);
assert_ne!(callback_id, different_origin);
assert_eq!(callback_id.len(), 12);
assert!(
callback_id
.chars()
.all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_')
);
}
#[test]
fn callback_id_is_appended_to_redirect_uri_path() {
let redirect_uri =
append_callback_id_to_redirect_uri("http://127.0.0.1:1234/callback", "abc123")
.expect("redirect URI should parse");
assert_eq!(redirect_uri, "http://127.0.0.1:1234/callback/abc123");
}
#[test]
fn callback_id_is_appended_before_redirect_uri_query() {
let redirect_uri = append_callback_id_to_redirect_uri(
"https://callbacks.example.com/oauth/callback?provider=github",
"abc123",
)
.expect("redirect URI should parse");
assert_eq!(
redirect_uri,
"https://callbacks.example.com/oauth/callback/abc123?provider=github"
);
}
#[test]
fn append_query_param_adds_resource_to_absolute_url() {
let url = append_query_param(