Compare commits

...

4 Commits

Author SHA1 Message Date
shijie-openai
f4e3b2f945 wip: adding the ability to refresh rmcp client per thread after changes 2025-12-09 15:37:18 -08:00
shijie-openai
bf53b8c3c7 Uses the latest config after MCP update 2025-12-09 15:34:19 -08:00
shijie-openai
3d57b24de0 Added timeout, clean up auth server and emit event on failure and timeout 2025-12-09 15:34:19 -08:00
shijie-openai
11783eaeef feat: support rmcp login in the same session 2025-12-09 15:34:19 -08:00
14 changed files with 545 additions and 65 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -887,6 +887,7 @@ dependencies = [
"codex-file-search",
"codex-login",
"codex-protocol",
"codex-rmcp-client",
"codex-utils-json-to-toml",
"core_test_support",
"mcp-types",

View File

@@ -139,6 +139,11 @@ client_request_definitions! {
response: v2::ModelListResponse,
},
McpServerOauthLogin => "mcpServer/oauth/login" {
params: v2::McpServerOauthLoginParams,
response: v2::McpServerOauthLoginResponse,
},
McpServersList => "mcpServers/list" {
params: v2::ListMcpServersParams,
response: v2::ListMcpServersResponse,
@@ -524,6 +529,7 @@ server_notification_definitions! {
CommandExecutionOutputDelta => "item/commandExecution/outputDelta" (v2::CommandExecutionOutputDeltaNotification),
FileChangeOutputDelta => "item/fileChange/outputDelta" (v2::FileChangeOutputDeltaNotification),
McpToolCallProgress => "item/mcpToolCall/progress" (v2::McpToolCallProgressNotification),
McpServerOauthLoginCompleted => "mcpServer/oauthLogin/completed" (v2::McpServerOauthLoginCompletedNotification),
AccountUpdated => "account/updated" (v2::AccountUpdatedNotification),
AccountRateLimitsUpdated => "account/rateLimits/updated" (v2::AccountRateLimitsUpdatedNotification),
ReasoningSummaryTextDelta => "item/reasoning/summaryTextDelta" (v2::ReasoningSummaryTextDeltaNotification),

View File

@@ -688,6 +688,26 @@ pub struct ListMcpServersResponse {
pub next_cursor: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct McpServerOauthLoginParams {
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[ts(optional)]
pub scopes: Option<Vec<String>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[ts(optional)]
pub timeout_secs: Option<i64>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct McpServerOauthLoginResponse {
pub authorization_url: String,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
@@ -1467,6 +1487,17 @@ pub struct McpToolCallProgressNotification {
pub message: String,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct McpServerOauthLoginCompletedNotification {
pub name: String,
pub success: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[ts(optional)]
pub error: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]

View File

@@ -26,6 +26,7 @@ codex-login = { workspace = true }
codex-protocol = { workspace = true }
codex-app-server-protocol = { workspace = true }
codex-feedback = { workspace = true }
codex-rmcp-client = { workspace = true }
codex-utils-json-to-toml = { workspace = true }
chrono = { workspace = true }
serde = { workspace = true, features = ["derive"] }

View File

@@ -55,6 +55,9 @@ use codex_app_server_protocol::LoginChatGptResponse;
use codex_app_server_protocol::LogoutAccountResponse;
use codex_app_server_protocol::LogoutChatGptResponse;
use codex_app_server_protocol::McpServer;
use codex_app_server_protocol::McpServerOauthLoginCompletedNotification;
use codex_app_server_protocol::McpServerOauthLoginParams;
use codex_app_server_protocol::McpServerOauthLoginResponse;
use codex_app_server_protocol::ModelListParams;
use codex_app_server_protocol::ModelListResponse;
use codex_app_server_protocol::NewConversationParams;
@@ -115,6 +118,7 @@ use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
use codex_core::config::ConfigToml;
use codex_core::config::edit::ConfigEditsBuilder;
use codex_core::config::types::McpServerTransportConfig;
use codex_core::config_loader::load_config_as_toml;
use codex_core::default_client::get_codex_user_agent;
use codex_core::exec::ExecParams;
@@ -147,6 +151,7 @@ use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SessionMetaLine;
use codex_protocol::protocol::USER_MESSAGE_BEGIN;
use codex_protocol::user_input::UserInput as CoreInputItem;
use codex_rmcp_client::perform_oauth_login_return_url;
use codex_utils_json_to_toml::json_to_toml;
use std::collections::HashMap;
use std::collections::HashSet;
@@ -161,6 +166,7 @@ use std::time::Duration;
use tokio::select;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use toml::Value as TomlValue;
use tracing::error;
use tracing::info;
use tracing::warn;
@@ -198,6 +204,7 @@ pub(crate) struct CodexMessageProcessor {
outgoing: Arc<OutgoingMessageSender>,
codex_linux_sandbox_exe: Option<PathBuf>,
config: Arc<Config>,
cli_overrides: Vec<(String, TomlValue)>,
conversation_listeners: HashMap<Uuid, oneshot::Sender<()>>,
active_login: Arc<Mutex<Option<ActiveLogin>>>,
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
@@ -244,6 +251,7 @@ impl CodexMessageProcessor {
outgoing: Arc<OutgoingMessageSender>,
codex_linux_sandbox_exe: Option<PathBuf>,
config: Arc<Config>,
cli_overrides: Vec<(String, TomlValue)>,
feedback: CodexFeedback,
) -> Self {
Self {
@@ -252,6 +260,7 @@ impl CodexMessageProcessor {
outgoing,
codex_linux_sandbox_exe,
config,
cli_overrides,
conversation_listeners: HashMap::new(),
active_login: Arc::new(Mutex::new(None)),
pending_interrupts: Arc::new(Mutex::new(HashMap::new())),
@@ -261,6 +270,16 @@ impl CodexMessageProcessor {
}
}
async fn load_latest_config(&self) -> Result<Config, JSONRPCErrorError> {
Config::load_with_cli_overrides(self.cli_overrides.clone(), ConfigOverrides::default())
.await
.map_err(|err| JSONRPCErrorError {
code: INTERNAL_ERROR_CODE,
message: format!("failed to reload config: {err}"),
data: None,
})
}
fn review_request_from_target(
target: ApiReviewTarget,
) -> Result<(ReviewRequest, String), JSONRPCErrorError> {
@@ -369,6 +388,9 @@ impl CodexMessageProcessor {
ClientRequest::ModelList { request_id, params } => {
self.list_models(request_id, params).await;
}
ClientRequest::McpServerOauthLogin { request_id, params } => {
self.mcp_server_oauth_login(request_id, params).await;
}
ClientRequest::McpServersList { request_id, params } => {
self.list_mcp_servers(request_id, params).await;
}
@@ -1916,6 +1938,115 @@ impl CodexMessageProcessor {
self.outgoing.send_response(request_id, response).await;
}
async fn mcp_server_oauth_login(
&self,
request_id: RequestId,
params: McpServerOauthLoginParams,
) {
let config = match self.load_latest_config().await {
Ok(config) => config,
Err(error) => {
self.outgoing.send_error(request_id, error).await;
return;
}
};
if !config.features.enabled(Feature::RmcpClient) {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: "OAuth login is only supported when [features].rmcp_client is true in config.toml".to_string(),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
let McpServerOauthLoginParams {
name,
scopes,
timeout_secs,
} = params;
let Some(server) = config.mcp_servers.get(&name) else {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("No MCP server named '{name}' found."),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
};
let (url, http_headers, env_http_headers) = match &server.transport {
McpServerTransportConfig::StreamableHttp {
url,
http_headers,
env_http_headers,
..
} => (url.clone(), http_headers.clone(), env_http_headers.clone()),
_ => {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: "OAuth login is only supported for streamable HTTP servers."
.to_string(),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
};
match perform_oauth_login_return_url(
&name,
&url,
config.mcp_oauth_credentials_store_mode,
http_headers,
env_http_headers,
scopes.as_deref().unwrap_or_default(),
timeout_secs,
)
.await
{
Ok(handle) => {
let authorization_url = handle.authorization_url().to_string();
let notification_name = name.clone();
let outgoing = Arc::clone(&self.outgoing);
let conversation_manager = Arc::clone(&self.conversation_manager);
tokio::spawn(async move {
let (success, error) = match handle.wait().await {
Ok(()) => (true, None),
Err(err) => (false, Some(err.to_string())),
};
if success {
conversation_manager.mark_mcp_oauth_success(Utc::now().timestamp());
}
let notification = ServerNotification::McpServerOauthLoginCompleted(
McpServerOauthLoginCompletedNotification {
name: notification_name,
success,
error,
},
);
outgoing.send_server_notification(notification).await;
});
let response = McpServerOauthLoginResponse { authorization_url };
self.outgoing.send_response(request_id, response).await;
}
Err(err) => {
let error = JSONRPCErrorError {
code: INTERNAL_ERROR_CODE,
message: format!("failed to login to MCP server '{name}': {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
}
}
}
async fn list_mcp_servers(&self, request_id: RequestId, params: ListMcpServersParams) {
let snapshot = collect_mcp_snapshot(self.config.as_ref()).await;

View File

@@ -59,6 +59,7 @@ impl MessageProcessor {
outgoing.clone(),
codex_linux_sandbox_exe,
Arc::clone(&config),
cli_overrides.clone(),
feedback,
);
let config_api = ConfigApi::new(config.codex_home.clone(), cli_overrides);

View File

@@ -55,6 +55,8 @@ use mcp_types::ReadResourceResult;
use mcp_types::RequestId;
use serde_json;
use serde_json::Value;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::oneshot;
@@ -170,6 +172,7 @@ impl Codex {
models_manager: Arc<ModelsManager>,
conversation_history: InitialHistory,
session_source: SessionSource,
mcp_oauth_refresh_clock: Arc<AtomicI64>,
) -> CodexResult<CodexSpawnOk> {
let (tx_sub, rx_sub) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
let (tx_event, rx_event) = async_channel::unbounded();
@@ -210,6 +213,7 @@ impl Codex {
tx_event.clone(),
conversation_history,
session_source_clone,
mcp_oauth_refresh_clock.clone(),
)
.await
.map_err(|e| {
@@ -466,6 +470,7 @@ impl Session {
}
}
#[allow(clippy::too_many_arguments)]
async fn new(
session_configuration: SessionConfiguration,
config: Arc<Config>,
@@ -474,6 +479,7 @@ impl Session {
tx_event: Sender<Event>,
initial_history: InitialHistory,
session_source: SessionSource,
mcp_oauth_refresh_clock: Arc<AtomicI64>,
) -> anyhow::Result<Arc<Self>> {
debug!(
"Configuring session: model={}; provider={:?}",
@@ -583,8 +589,11 @@ impl Session {
let state = SessionState::new(session_configuration.clone());
let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new(
mcp_oauth_refresh_clock.clone(),
))),
mcp_startup_cancellation_token: CancellationToken::new(),
mcp_oauth_refresh_clock,
unified_exec_manager: UnifiedExecSessionManager::default(),
notifier: UserNotifier::new(config.notify.clone()),
rollout: Mutex::new(Some(rollout_recorder)),
@@ -1386,6 +1395,7 @@ impl Session {
server: &str,
params: Option<ListResourcesRequestParams>,
) -> anyhow::Result<ListResourcesResult> {
self.refresh_mcp_clients_if_needed().await?;
self.services
.mcp_connection_manager
.read()
@@ -1399,6 +1409,7 @@ impl Session {
server: &str,
params: Option<ListResourceTemplatesRequestParams>,
) -> anyhow::Result<ListResourceTemplatesResult> {
self.refresh_mcp_clients_if_needed().await?;
self.services
.mcp_connection_manager
.read()
@@ -1412,6 +1423,7 @@ impl Session {
server: &str,
params: ReadResourceRequestParams,
) -> anyhow::Result<ReadResourceResult> {
self.refresh_mcp_clients_if_needed().await?;
self.services
.mcp_connection_manager
.read()
@@ -1426,6 +1438,7 @@ impl Session {
tool: &str,
arguments: Option<serde_json::Value>,
) -> anyhow::Result<CallToolResult> {
self.refresh_mcp_clients_if_needed().await?;
self.services
.mcp_connection_manager
.read()
@@ -1435,6 +1448,7 @@ impl Session {
}
pub(crate) async fn parse_mcp_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
self.refresh_mcp_clients_if_needed().await.ok()?;
self.services
.mcp_connection_manager
.read()
@@ -1443,6 +1457,42 @@ impl Session {
.await
}
async fn refresh_mcp_clients_if_needed(&self) -> anyhow::Result<()> {
let current_clock = self.services.mcp_oauth_refresh_clock.load(Ordering::SeqCst);
let last_seen = {
let manager = self.services.mcp_connection_manager.read().await;
manager.last_refresh_seen()
};
if current_clock <= last_seen {
return Ok(());
}
let config = {
let state = self.state.lock().await;
state
.session_configuration
.original_config_do_not_use
.clone()
};
let store_mode = config.mcp_oauth_credentials_store_mode;
let auth_statuses = compute_auth_statuses(config.mcp_servers.iter(), store_mode).await;
{
let mut manager = self.services.mcp_connection_manager.write().await;
manager
.refresh_if_needed(
&config.mcp_servers,
store_mode,
auth_statuses,
self.tx_event.clone(),
self.services.mcp_startup_cancellation_token.clone(),
)
.await;
}
Ok(())
}
pub async fn interrupt_task(self: &Arc<Self>) {
info!("interrupt received: abort current task, if any");
let has_active_turn = { self.active_turn.lock().await.is_some() };
@@ -2882,9 +2932,13 @@ mod tests {
let state = SessionState::new(session_configuration.clone());
let mcp_oauth_refresh_clock = Arc::new(AtomicI64::new(0));
let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new(
mcp_oauth_refresh_clock.clone(),
))),
mcp_startup_cancellation_token: CancellationToken::new(),
mcp_oauth_refresh_clock,
unified_exec_manager: UnifiedExecSessionManager::default(),
notifier: UserNotifier::new(None),
rollout: Mutex::new(None),
@@ -2964,9 +3018,13 @@ mod tests {
let state = SessionState::new(session_configuration.clone());
let mcp_oauth_refresh_clock = Arc::new(AtomicI64::new(0));
let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new(
mcp_oauth_refresh_clock.clone(),
))),
mcp_startup_cancellation_token: CancellationToken::new(),
mcp_oauth_refresh_clock,
unified_exec_manager: UnifiedExecSessionManager::default(),
notifier: UserNotifier::new(None),
rollout: Mutex::new(None),

View File

@@ -51,6 +51,7 @@ pub(crate) async fn run_codex_conversation_interactive(
models_manager,
initial_history.unwrap_or(InitialHistory::New),
SessionSource::SubAgent(SubAgentSource::Review),
parent_session.services.mcp_oauth_refresh_clock.clone(),
)
.await?;
let codex = Arc::new(codex);

View File

@@ -22,6 +22,8 @@ use codex_protocol::protocol::SessionSource;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
use tokio::sync::RwLock;
/// Represents a newly created Codex conversation, including the first event
@@ -39,6 +41,7 @@ pub struct ConversationManager {
auth_manager: Arc<AuthManager>,
models_manager: Arc<ModelsManager>,
session_source: SessionSource,
mcp_oauth_refresh_clock: Arc<AtomicI64>,
}
impl ConversationManager {
@@ -48,6 +51,7 @@ impl ConversationManager {
auth_manager: auth_manager.clone(),
session_source,
models_manager: Arc::new(ModelsManager::new(auth_manager)),
mcp_oauth_refresh_clock: Arc::new(AtomicI64::new(0)),
}
}
@@ -65,6 +69,15 @@ impl ConversationManager {
self.session_source.clone()
}
pub fn mcp_oauth_refresh_clock(&self) -> Arc<AtomicI64> {
self.mcp_oauth_refresh_clock.clone()
}
pub fn mark_mcp_oauth_success(&self, timestamp_secs: i64) {
self.mcp_oauth_refresh_clock
.store(timestamp_secs, Ordering::SeqCst);
}
pub async fn new_conversation(&self, config: Config) -> CodexResult<NewConversation> {
self.spawn_conversation(
config,
@@ -89,6 +102,7 @@ impl ConversationManager {
models_manager,
InitialHistory::New,
self.session_source.clone(),
self.mcp_oauth_refresh_clock.clone(),
)
.await?;
self.finalize_spawn(codex, conversation_id).await
@@ -166,6 +180,7 @@ impl ConversationManager {
self.models_manager.clone(),
initial_history,
self.session_source.clone(),
self.mcp_oauth_refresh_clock.clone(),
)
.await?;
self.finalize_spawn(codex, conversation_id).await
@@ -207,6 +222,7 @@ impl ConversationManager {
self.models_manager.clone(),
history,
self.session_source.clone(),
self.mcp_oauth_refresh_clock.clone(),
)
.await?;

View File

@@ -1,5 +1,7 @@
pub mod auth;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicI64;
use async_channel::unbounded;
use codex_protocol::protocol::McpListToolsResponseEvent;
@@ -29,7 +31,8 @@ pub async fn collect_mcp_snapshot(config: &Config) -> McpListToolsResponseEvent
)
.await;
let mut mcp_connection_manager = McpConnectionManager::default();
let mcp_oauth_refresh_clock = Arc::new(AtomicI64::new(0));
let mut mcp_connection_manager = McpConnectionManager::new(mcp_oauth_refresh_clock);
let (tx_event, rx_event) = unbounded();
drop(rx_event);
let cancel_token = CancellationToken::new();

View File

@@ -12,6 +12,8 @@ use std::env;
use std::ffi::OsString;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
use std::time::Duration;
use crate::mcp::auth::McpAuthStatusEntry;
@@ -260,13 +262,70 @@ pub struct SandboxState {
}
/// A thin wrapper around a set of running [`RmcpClient`] instances.
#[derive(Default)]
pub(crate) struct McpConnectionManager {
clients: HashMap<String, AsyncManagedClient>,
elicitation_requests: ElicitationRequestManager,
mcp_oauth_refresh_clock: Arc<AtomicI64>,
last_refresh_seen: AtomicI64,
config_snapshot: HashMap<String, McpServerConfig>,
store_mode_snapshot: Option<OAuthCredentialsStoreMode>,
auth_entries_snapshot: HashMap<String, McpAuthStatusEntry>,
}
impl McpConnectionManager {
pub(crate) fn new(mcp_oauth_refresh_clock: Arc<AtomicI64>) -> Self {
Self {
clients: HashMap::new(),
elicitation_requests: ElicitationRequestManager::default(),
mcp_oauth_refresh_clock,
last_refresh_seen: AtomicI64::new(0),
config_snapshot: HashMap::new(),
store_mode_snapshot: None,
auth_entries_snapshot: HashMap::new(),
}
}
fn update_snapshots(
&mut self,
mcp_servers: &HashMap<String, McpServerConfig>,
store_mode: OAuthCredentialsStoreMode,
auth_entries: &HashMap<String, McpAuthStatusEntry>,
) {
self.config_snapshot = mcp_servers.clone();
self.store_mode_snapshot = Some(store_mode);
self.auth_entries_snapshot = auth_entries.clone();
let now = self.mcp_oauth_refresh_clock.load(Ordering::SeqCst);
self.last_refresh_seen.store(now, Ordering::SeqCst);
}
pub(crate) fn last_refresh_seen(&self) -> i64 {
self.last_refresh_seen.load(Ordering::SeqCst)
}
pub(crate) async fn refresh_if_needed(
&mut self,
config: &HashMap<String, McpServerConfig>,
store_mode: OAuthCredentialsStoreMode,
auth_entries: HashMap<String, McpAuthStatusEntry>,
tx_event: Sender<Event>,
cancel_token: CancellationToken,
) {
let current = self.mcp_oauth_refresh_clock.load(Ordering::SeqCst);
if current <= self.last_refresh_seen() {
return;
}
self.initialize(
config.clone(),
store_mode,
auth_entries,
tx_event,
cancel_token,
)
.await;
self.last_refresh_seen.store(current, Ordering::SeqCst);
}
pub async fn initialize(
&mut self,
mcp_servers: HashMap<String, McpServerConfig>,
@@ -281,7 +340,9 @@ impl McpConnectionManager {
let mut clients = HashMap::new();
let mut join_set = JoinSet::new();
let elicitation_requests = ElicitationRequestManager::default();
for (server_name, cfg) in mcp_servers.into_iter().filter(|(_, cfg)| cfg.enabled) {
for (server_name, cfg) in mcp_servers.iter().filter(|(_, cfg)| cfg.enabled) {
let server_name = server_name.to_string();
let cfg = cfg.clone();
let cancel_token = cancel_token.child_token();
let _ = emit_update(
&tx_event,
@@ -333,6 +394,7 @@ impl McpConnectionManager {
}
self.clients = clients;
self.elicitation_requests = elicitation_requests.clone();
self.update_snapshots(&mcp_servers, store_mode, &auth_entries);
tokio::spawn(async move {
let outcomes = join_set.join_all().await;
let mut summary = McpStartupCompleteEvent::default();

View File

@@ -8,6 +8,7 @@ use crate::tools::sandboxing::ApprovalStore;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_notification::UserNotifier;
use codex_otel::otel_event_manager::OtelEventManager;
use std::sync::atomic::AtomicI64;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
@@ -15,6 +16,7 @@ use tokio_util::sync::CancellationToken;
pub(crate) struct SessionServices {
pub(crate) mcp_connection_manager: Arc<RwLock<McpConnectionManager>>,
pub(crate) mcp_startup_cancellation_token: CancellationToken,
pub(crate) mcp_oauth_refresh_clock: Arc<AtomicI64>,
pub(crate) unified_exec_manager: UnifiedExecSessionManager,
pub(crate) notifier: UserNotifier,
pub(crate) rollout: Mutex<Option<RolloutRecorder>>,

View File

@@ -16,7 +16,9 @@ pub use oauth::WrappedOAuthTokenResponse;
pub use oauth::delete_oauth_tokens;
pub(crate) use oauth::load_oauth_tokens;
pub use oauth::save_oauth_tokens;
pub use perform_oauth_login::OauthLoginHandle;
pub use perform_oauth_login::perform_oauth_login;
pub use perform_oauth_login::perform_oauth_login_return_url;
pub use rmcp::model::ElicitationAction;
pub use rmcp_client::Elicitation;
pub use rmcp_client::ElicitationResponse;

View File

@@ -22,6 +22,11 @@ use crate::save_oauth_tokens;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
struct OauthHeaders {
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
}
struct CallbackServerGuard {
server: Arc<Server>,
}
@@ -40,70 +45,52 @@ pub async fn perform_oauth_login(
env_http_headers: Option<HashMap<String, String>>,
scopes: &[String],
) -> Result<()> {
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
let guard = CallbackServerGuard {
server: Arc::clone(&server),
let headers = OauthHeaders {
http_headers,
env_http_headers,
};
OauthLoginFlow::new(
server_name,
server_url,
store_mode,
headers,
scopes,
true,
None,
)
.await?
.finish()
.await
}
let redirect_uri = match server.server_addr() {
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
format!("http://{}:{}/callback", addr.ip(), addr.port())
}
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
format!("http://[{}]:{}/callback", addr.ip(), addr.port())
}
#[cfg(not(target_os = "windows"))]
_ => return Err(anyhow!("unable to determine callback address")),
pub async fn perform_oauth_login_return_url(
server_name: &str,
server_url: &str,
store_mode: OAuthCredentialsStoreMode,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
scopes: &[String],
timeout_secs: Option<i64>,
) -> Result<OauthLoginHandle> {
let headers = OauthHeaders {
http_headers,
env_http_headers,
};
let flow = OauthLoginFlow::new(
server_name,
server_url,
store_mode,
headers,
scopes,
false,
timeout_secs,
)
.await?;
let (tx, rx) = oneshot::channel();
spawn_callback_server(server, tx);
let authorization_url = flow.authorization_url();
let completion = flow.spawn();
let default_headers = build_default_headers(http_headers, env_http_headers)?;
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
oauth_state
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
.await?;
let auth_url = oauth_state.get_authorization_url().await?;
println!("Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n");
if webbrowser::open(&auth_url).is_err() {
println!("(Browser launch failed; please copy the URL above manually.)");
}
let (code, csrf_state) = timeout(Duration::from_secs(300), rx)
.await
.context("timed out waiting for OAuth callback")?
.context("OAuth callback was cancelled")?;
oauth_state
.handle_callback(&code, &csrf_state)
.await
.context("failed to handle OAuth callback")?;
let (client_id, credentials_opt) = oauth_state
.get_credentials()
.await
.context("failed to retrieve OAuth credentials")?;
let credentials =
credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
let expires_at = compute_expires_at_millis(&credentials);
let stored = StoredOAuthTokens {
server_name: server_name.to_string(),
url: server_url.to_string(),
client_id,
token_response: WrappedOAuthTokenResponse(credentials),
expires_at,
};
save_oauth_tokens(server_name, &stored, store_mode)?;
drop(guard);
Ok(())
Ok(OauthLoginHandle::new(authorization_url, completion))
}
fn spawn_callback_server(server: Arc<Server>, tx: oneshot::Sender<(String, String)>) {
@@ -160,3 +147,181 @@ fn parse_oauth_callback(path: &str) -> Option<OauthCallbackResult> {
state: state?,
})
}
pub struct OauthLoginHandle {
authorization_url: String,
completion: oneshot::Receiver<Result<()>>,
}
impl OauthLoginHandle {
fn new(authorization_url: String, completion: oneshot::Receiver<Result<()>>) -> Self {
Self {
authorization_url,
completion,
}
}
pub fn authorization_url(&self) -> &str {
&self.authorization_url
}
pub fn into_parts(self) -> (String, oneshot::Receiver<Result<()>>) {
(self.authorization_url, self.completion)
}
pub async fn wait(self) -> Result<()> {
self.completion
.await
.map_err(|err| anyhow!("OAuth login task was cancelled: {err}"))?
}
}
struct OauthLoginFlow {
auth_url: String,
oauth_state: OAuthState,
rx: oneshot::Receiver<(String, String)>,
guard: CallbackServerGuard,
server_name: String,
server_url: String,
store_mode: OAuthCredentialsStoreMode,
launch_browser: bool,
timeout: Duration,
}
impl OauthLoginFlow {
async fn new(
server_name: &str,
server_url: &str,
store_mode: OAuthCredentialsStoreMode,
headers: OauthHeaders,
scopes: &[String],
launch_browser: bool,
timeout_secs: Option<i64>,
) -> Result<Self> {
const DEFAULT_OAUTH_TIMEOUT_SECS: i64 = 300;
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
let guard = CallbackServerGuard {
server: Arc::clone(&server),
};
let redirect_uri = match server.server_addr() {
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
let ip = addr.ip();
let port = addr.port();
format!("http://{ip}:{port}/callback")
}
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
let ip = addr.ip();
let port = addr.port();
format!("http://[{ip}]:{port}/callback")
}
#[cfg(not(target_os = "windows"))]
_ => return Err(anyhow!("unable to determine callback address")),
};
let (tx, rx) = oneshot::channel();
spawn_callback_server(server, tx);
let OauthHeaders {
http_headers,
env_http_headers,
} = headers;
let default_headers = build_default_headers(http_headers, env_http_headers)?;
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
oauth_state
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
.await?;
let auth_url = oauth_state.get_authorization_url().await?;
let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1);
let timeout = Duration::from_secs(timeout_secs as u64);
Ok(Self {
auth_url,
oauth_state,
rx,
guard,
server_name: server_name.to_string(),
server_url: server_url.to_string(),
store_mode,
launch_browser,
timeout,
})
}
fn authorization_url(&self) -> String {
self.auth_url.clone()
}
async fn finish(mut self) -> Result<()> {
if self.launch_browser {
let server_name = &self.server_name;
let auth_url = &self.auth_url;
println!(
"Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n"
);
if webbrowser::open(auth_url).is_err() {
println!("(Browser launch failed; please copy the URL above manually.)");
}
}
let result = async {
let (code, csrf_state) = timeout(self.timeout, &mut self.rx)
.await
.context("timed out waiting for OAuth callback")?
.context("OAuth callback was cancelled")?;
self.oauth_state
.handle_callback(&code, &csrf_state)
.await
.context("failed to handle OAuth callback")?;
let (client_id, credentials_opt) = self
.oauth_state
.get_credentials()
.await
.context("failed to retrieve OAuth credentials")?;
let credentials = credentials_opt
.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
let expires_at = compute_expires_at_millis(&credentials);
let stored = StoredOAuthTokens {
server_name: self.server_name.clone(),
url: self.server_url.clone(),
client_id,
token_response: WrappedOAuthTokenResponse(credentials),
expires_at,
};
save_oauth_tokens(&self.server_name, &stored, self.store_mode)?;
Ok(())
}
.await;
drop(self.guard);
result
}
fn spawn(self) -> oneshot::Receiver<Result<()>> {
let server_name_for_logging = self.server_name.clone();
let (tx, rx) = oneshot::channel();
tokio::spawn(async move {
let result = self.finish().await;
if let Err(err) = &result {
eprintln!(
"Failed to complete OAuth login for '{server_name_for_logging}': {err:#}"
);
}
let _ = tx.send(result);
});
rx
}
}