mirror of
https://github.com/openai/codex.git
synced 2026-02-01 22:47:52 +00:00
Compare commits
4 Commits
1271d450b1
...
shijie/sup
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f4e3b2f945 | ||
|
|
bf53b8c3c7 | ||
|
|
3d57b24de0 | ||
|
|
11783eaeef |
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -887,6 +887,7 @@ dependencies = [
|
||||
"codex-file-search",
|
||||
"codex-login",
|
||||
"codex-protocol",
|
||||
"codex-rmcp-client",
|
||||
"codex-utils-json-to-toml",
|
||||
"core_test_support",
|
||||
"mcp-types",
|
||||
|
||||
@@ -139,6 +139,11 @@ client_request_definitions! {
|
||||
response: v2::ModelListResponse,
|
||||
},
|
||||
|
||||
McpServerOauthLogin => "mcpServer/oauth/login" {
|
||||
params: v2::McpServerOauthLoginParams,
|
||||
response: v2::McpServerOauthLoginResponse,
|
||||
},
|
||||
|
||||
McpServersList => "mcpServers/list" {
|
||||
params: v2::ListMcpServersParams,
|
||||
response: v2::ListMcpServersResponse,
|
||||
@@ -524,6 +529,7 @@ server_notification_definitions! {
|
||||
CommandExecutionOutputDelta => "item/commandExecution/outputDelta" (v2::CommandExecutionOutputDeltaNotification),
|
||||
FileChangeOutputDelta => "item/fileChange/outputDelta" (v2::FileChangeOutputDeltaNotification),
|
||||
McpToolCallProgress => "item/mcpToolCall/progress" (v2::McpToolCallProgressNotification),
|
||||
McpServerOauthLoginCompleted => "mcpServer/oauthLogin/completed" (v2::McpServerOauthLoginCompletedNotification),
|
||||
AccountUpdated => "account/updated" (v2::AccountUpdatedNotification),
|
||||
AccountRateLimitsUpdated => "account/rateLimits/updated" (v2::AccountRateLimitsUpdatedNotification),
|
||||
ReasoningSummaryTextDelta => "item/reasoning/summaryTextDelta" (v2::ReasoningSummaryTextDeltaNotification),
|
||||
|
||||
@@ -688,6 +688,26 @@ pub struct ListMcpServersResponse {
|
||||
pub next_cursor: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct McpServerOauthLoginParams {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[ts(optional)]
|
||||
pub scopes: Option<Vec<String>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[ts(optional)]
|
||||
pub timeout_secs: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct McpServerOauthLoginResponse {
|
||||
pub authorization_url: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
@@ -1467,6 +1487,17 @@ pub struct McpToolCallProgressNotification {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct McpServerOauthLoginCompletedNotification {
|
||||
pub name: String,
|
||||
pub success: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[ts(optional)]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
|
||||
@@ -26,6 +26,7 @@ codex-login = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-feedback = { workspace = true }
|
||||
codex-rmcp-client = { workspace = true }
|
||||
codex-utils-json-to-toml = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
|
||||
@@ -55,6 +55,9 @@ use codex_app_server_protocol::LoginChatGptResponse;
|
||||
use codex_app_server_protocol::LogoutAccountResponse;
|
||||
use codex_app_server_protocol::LogoutChatGptResponse;
|
||||
use codex_app_server_protocol::McpServer;
|
||||
use codex_app_server_protocol::McpServerOauthLoginCompletedNotification;
|
||||
use codex_app_server_protocol::McpServerOauthLoginParams;
|
||||
use codex_app_server_protocol::McpServerOauthLoginResponse;
|
||||
use codex_app_server_protocol::ModelListParams;
|
||||
use codex_app_server_protocol::ModelListResponse;
|
||||
use codex_app_server_protocol::NewConversationParams;
|
||||
@@ -115,6 +118,7 @@ use codex_core::config::Config;
|
||||
use codex_core::config::ConfigOverrides;
|
||||
use codex_core::config::ConfigToml;
|
||||
use codex_core::config::edit::ConfigEditsBuilder;
|
||||
use codex_core::config::types::McpServerTransportConfig;
|
||||
use codex_core::config_loader::load_config_as_toml;
|
||||
use codex_core::default_client::get_codex_user_agent;
|
||||
use codex_core::exec::ExecParams;
|
||||
@@ -147,6 +151,7 @@ use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::SessionMetaLine;
|
||||
use codex_protocol::protocol::USER_MESSAGE_BEGIN;
|
||||
use codex_protocol::user_input::UserInput as CoreInputItem;
|
||||
use codex_rmcp_client::perform_oauth_login_return_url;
|
||||
use codex_utils_json_to_toml::json_to_toml;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
@@ -161,6 +166,7 @@ use std::time::Duration;
|
||||
use tokio::select;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use toml::Value as TomlValue;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
@@ -198,6 +204,7 @@ pub(crate) struct CodexMessageProcessor {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
config: Arc<Config>,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
conversation_listeners: HashMap<Uuid, oneshot::Sender<()>>,
|
||||
active_login: Arc<Mutex<Option<ActiveLogin>>>,
|
||||
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
|
||||
@@ -244,6 +251,7 @@ impl CodexMessageProcessor {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
config: Arc<Config>,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
feedback: CodexFeedback,
|
||||
) -> Self {
|
||||
Self {
|
||||
@@ -252,6 +260,7 @@ impl CodexMessageProcessor {
|
||||
outgoing,
|
||||
codex_linux_sandbox_exe,
|
||||
config,
|
||||
cli_overrides,
|
||||
conversation_listeners: HashMap::new(),
|
||||
active_login: Arc::new(Mutex::new(None)),
|
||||
pending_interrupts: Arc::new(Mutex::new(HashMap::new())),
|
||||
@@ -261,6 +270,16 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
async fn load_latest_config(&self) -> Result<Config, JSONRPCErrorError> {
|
||||
Config::load_with_cli_overrides(self.cli_overrides.clone(), ConfigOverrides::default())
|
||||
.await
|
||||
.map_err(|err| JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to reload config: {err}"),
|
||||
data: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn review_request_from_target(
|
||||
target: ApiReviewTarget,
|
||||
) -> Result<(ReviewRequest, String), JSONRPCErrorError> {
|
||||
@@ -369,6 +388,9 @@ impl CodexMessageProcessor {
|
||||
ClientRequest::ModelList { request_id, params } => {
|
||||
self.list_models(request_id, params).await;
|
||||
}
|
||||
ClientRequest::McpServerOauthLogin { request_id, params } => {
|
||||
self.mcp_server_oauth_login(request_id, params).await;
|
||||
}
|
||||
ClientRequest::McpServersList { request_id, params } => {
|
||||
self.list_mcp_servers(request_id, params).await;
|
||||
}
|
||||
@@ -1916,6 +1938,115 @@ impl CodexMessageProcessor {
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
|
||||
async fn mcp_server_oauth_login(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
params: McpServerOauthLoginParams,
|
||||
) {
|
||||
let config = match self.load_latest_config().await {
|
||||
Ok(config) => config,
|
||||
Err(error) => {
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !config.features.enabled(Feature::RmcpClient) {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: "OAuth login is only supported when [features].rmcp_client is true in config.toml".to_string(),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
|
||||
let McpServerOauthLoginParams {
|
||||
name,
|
||||
scopes,
|
||||
timeout_secs,
|
||||
} = params;
|
||||
|
||||
let Some(server) = config.mcp_servers.get(&name) else {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!("No MCP server named '{name}' found."),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
};
|
||||
|
||||
let (url, http_headers, env_http_headers) = match &server.transport {
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
..
|
||||
} => (url.clone(), http_headers.clone(), env_http_headers.clone()),
|
||||
_ => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: "OAuth login is only supported for streamable HTTP servers."
|
||||
.to_string(),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match perform_oauth_login_return_url(
|
||||
&name,
|
||||
&url,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
scopes.as_deref().unwrap_or_default(),
|
||||
timeout_secs,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(handle) => {
|
||||
let authorization_url = handle.authorization_url().to_string();
|
||||
let notification_name = name.clone();
|
||||
let outgoing = Arc::clone(&self.outgoing);
|
||||
let conversation_manager = Arc::clone(&self.conversation_manager);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (success, error) = match handle.wait().await {
|
||||
Ok(()) => (true, None),
|
||||
Err(err) => (false, Some(err.to_string())),
|
||||
};
|
||||
|
||||
if success {
|
||||
conversation_manager.mark_mcp_oauth_success(Utc::now().timestamp());
|
||||
}
|
||||
|
||||
let notification = ServerNotification::McpServerOauthLoginCompleted(
|
||||
McpServerOauthLoginCompletedNotification {
|
||||
name: notification_name,
|
||||
success,
|
||||
error,
|
||||
},
|
||||
);
|
||||
outgoing.send_server_notification(notification).await;
|
||||
});
|
||||
|
||||
let response = McpServerOauthLoginResponse { authorization_url };
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to login to MCP server '{name}': {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_mcp_servers(&self, request_id: RequestId, params: ListMcpServersParams) {
|
||||
let snapshot = collect_mcp_snapshot(self.config.as_ref()).await;
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ impl MessageProcessor {
|
||||
outgoing.clone(),
|
||||
codex_linux_sandbox_exe,
|
||||
Arc::clone(&config),
|
||||
cli_overrides.clone(),
|
||||
feedback,
|
||||
);
|
||||
let config_api = ConfigApi::new(config.codex_home.clone(), cli_overrides);
|
||||
|
||||
@@ -55,6 +55,8 @@ use mcp_types::ReadResourceResult;
|
||||
use mcp_types::RequestId;
|
||||
use serde_json;
|
||||
use serde_json::Value;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::oneshot;
|
||||
@@ -170,6 +172,7 @@ impl Codex {
|
||||
models_manager: Arc<ModelsManager>,
|
||||
conversation_history: InitialHistory,
|
||||
session_source: SessionSource,
|
||||
mcp_oauth_refresh_clock: Arc<AtomicI64>,
|
||||
) -> CodexResult<CodexSpawnOk> {
|
||||
let (tx_sub, rx_sub) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
|
||||
let (tx_event, rx_event) = async_channel::unbounded();
|
||||
@@ -210,6 +213,7 @@ impl Codex {
|
||||
tx_event.clone(),
|
||||
conversation_history,
|
||||
session_source_clone,
|
||||
mcp_oauth_refresh_clock.clone(),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -466,6 +470,7 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn new(
|
||||
session_configuration: SessionConfiguration,
|
||||
config: Arc<Config>,
|
||||
@@ -474,6 +479,7 @@ impl Session {
|
||||
tx_event: Sender<Event>,
|
||||
initial_history: InitialHistory,
|
||||
session_source: SessionSource,
|
||||
mcp_oauth_refresh_clock: Arc<AtomicI64>,
|
||||
) -> anyhow::Result<Arc<Self>> {
|
||||
debug!(
|
||||
"Configuring session: model={}; provider={:?}",
|
||||
@@ -583,8 +589,11 @@ impl Session {
|
||||
let state = SessionState::new(session_configuration.clone());
|
||||
|
||||
let services = SessionServices {
|
||||
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
|
||||
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new(
|
||||
mcp_oauth_refresh_clock.clone(),
|
||||
))),
|
||||
mcp_startup_cancellation_token: CancellationToken::new(),
|
||||
mcp_oauth_refresh_clock,
|
||||
unified_exec_manager: UnifiedExecSessionManager::default(),
|
||||
notifier: UserNotifier::new(config.notify.clone()),
|
||||
rollout: Mutex::new(Some(rollout_recorder)),
|
||||
@@ -1386,6 +1395,7 @@ impl Session {
|
||||
server: &str,
|
||||
params: Option<ListResourcesRequestParams>,
|
||||
) -> anyhow::Result<ListResourcesResult> {
|
||||
self.refresh_mcp_clients_if_needed().await?;
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
@@ -1399,6 +1409,7 @@ impl Session {
|
||||
server: &str,
|
||||
params: Option<ListResourceTemplatesRequestParams>,
|
||||
) -> anyhow::Result<ListResourceTemplatesResult> {
|
||||
self.refresh_mcp_clients_if_needed().await?;
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
@@ -1412,6 +1423,7 @@ impl Session {
|
||||
server: &str,
|
||||
params: ReadResourceRequestParams,
|
||||
) -> anyhow::Result<ReadResourceResult> {
|
||||
self.refresh_mcp_clients_if_needed().await?;
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
@@ -1426,6 +1438,7 @@ impl Session {
|
||||
tool: &str,
|
||||
arguments: Option<serde_json::Value>,
|
||||
) -> anyhow::Result<CallToolResult> {
|
||||
self.refresh_mcp_clients_if_needed().await?;
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
@@ -1435,6 +1448,7 @@ impl Session {
|
||||
}
|
||||
|
||||
pub(crate) async fn parse_mcp_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
|
||||
self.refresh_mcp_clients_if_needed().await.ok()?;
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
@@ -1443,6 +1457,42 @@ impl Session {
|
||||
.await
|
||||
}
|
||||
|
||||
async fn refresh_mcp_clients_if_needed(&self) -> anyhow::Result<()> {
|
||||
let current_clock = self.services.mcp_oauth_refresh_clock.load(Ordering::SeqCst);
|
||||
let last_seen = {
|
||||
let manager = self.services.mcp_connection_manager.read().await;
|
||||
manager.last_refresh_seen()
|
||||
};
|
||||
if current_clock <= last_seen {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let config = {
|
||||
let state = self.state.lock().await;
|
||||
state
|
||||
.session_configuration
|
||||
.original_config_do_not_use
|
||||
.clone()
|
||||
};
|
||||
let store_mode = config.mcp_oauth_credentials_store_mode;
|
||||
let auth_statuses = compute_auth_statuses(config.mcp_servers.iter(), store_mode).await;
|
||||
|
||||
{
|
||||
let mut manager = self.services.mcp_connection_manager.write().await;
|
||||
manager
|
||||
.refresh_if_needed(
|
||||
&config.mcp_servers,
|
||||
store_mode,
|
||||
auth_statuses,
|
||||
self.tx_event.clone(),
|
||||
self.services.mcp_startup_cancellation_token.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn interrupt_task(self: &Arc<Self>) {
|
||||
info!("interrupt received: abort current task, if any");
|
||||
let has_active_turn = { self.active_turn.lock().await.is_some() };
|
||||
@@ -2882,9 +2932,13 @@ mod tests {
|
||||
|
||||
let state = SessionState::new(session_configuration.clone());
|
||||
|
||||
let mcp_oauth_refresh_clock = Arc::new(AtomicI64::new(0));
|
||||
let services = SessionServices {
|
||||
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
|
||||
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new(
|
||||
mcp_oauth_refresh_clock.clone(),
|
||||
))),
|
||||
mcp_startup_cancellation_token: CancellationToken::new(),
|
||||
mcp_oauth_refresh_clock,
|
||||
unified_exec_manager: UnifiedExecSessionManager::default(),
|
||||
notifier: UserNotifier::new(None),
|
||||
rollout: Mutex::new(None),
|
||||
@@ -2964,9 +3018,13 @@ mod tests {
|
||||
|
||||
let state = SessionState::new(session_configuration.clone());
|
||||
|
||||
let mcp_oauth_refresh_clock = Arc::new(AtomicI64::new(0));
|
||||
let services = SessionServices {
|
||||
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
|
||||
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new(
|
||||
mcp_oauth_refresh_clock.clone(),
|
||||
))),
|
||||
mcp_startup_cancellation_token: CancellationToken::new(),
|
||||
mcp_oauth_refresh_clock,
|
||||
unified_exec_manager: UnifiedExecSessionManager::default(),
|
||||
notifier: UserNotifier::new(None),
|
||||
rollout: Mutex::new(None),
|
||||
|
||||
@@ -51,6 +51,7 @@ pub(crate) async fn run_codex_conversation_interactive(
|
||||
models_manager,
|
||||
initial_history.unwrap_or(InitialHistory::New),
|
||||
SessionSource::SubAgent(SubAgentSource::Review),
|
||||
parent_session.services.mcp_oauth_refresh_clock.clone(),
|
||||
)
|
||||
.await?;
|
||||
let codex = Arc::new(codex);
|
||||
|
||||
@@ -22,6 +22,8 @@ use codex_protocol::protocol::SessionSource;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Represents a newly created Codex conversation, including the first event
|
||||
@@ -39,6 +41,7 @@ pub struct ConversationManager {
|
||||
auth_manager: Arc<AuthManager>,
|
||||
models_manager: Arc<ModelsManager>,
|
||||
session_source: SessionSource,
|
||||
mcp_oauth_refresh_clock: Arc<AtomicI64>,
|
||||
}
|
||||
|
||||
impl ConversationManager {
|
||||
@@ -48,6 +51,7 @@ impl ConversationManager {
|
||||
auth_manager: auth_manager.clone(),
|
||||
session_source,
|
||||
models_manager: Arc::new(ModelsManager::new(auth_manager)),
|
||||
mcp_oauth_refresh_clock: Arc::new(AtomicI64::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +69,15 @@ impl ConversationManager {
|
||||
self.session_source.clone()
|
||||
}
|
||||
|
||||
pub fn mcp_oauth_refresh_clock(&self) -> Arc<AtomicI64> {
|
||||
self.mcp_oauth_refresh_clock.clone()
|
||||
}
|
||||
|
||||
pub fn mark_mcp_oauth_success(&self, timestamp_secs: i64) {
|
||||
self.mcp_oauth_refresh_clock
|
||||
.store(timestamp_secs, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
pub async fn new_conversation(&self, config: Config) -> CodexResult<NewConversation> {
|
||||
self.spawn_conversation(
|
||||
config,
|
||||
@@ -89,6 +102,7 @@ impl ConversationManager {
|
||||
models_manager,
|
||||
InitialHistory::New,
|
||||
self.session_source.clone(),
|
||||
self.mcp_oauth_refresh_clock.clone(),
|
||||
)
|
||||
.await?;
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
@@ -166,6 +180,7 @@ impl ConversationManager {
|
||||
self.models_manager.clone(),
|
||||
initial_history,
|
||||
self.session_source.clone(),
|
||||
self.mcp_oauth_refresh_clock.clone(),
|
||||
)
|
||||
.await?;
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
@@ -207,6 +222,7 @@ impl ConversationManager {
|
||||
self.models_manager.clone(),
|
||||
history,
|
||||
self.session_source.clone(),
|
||||
self.mcp_oauth_refresh_clock.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
pub mod auth;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
|
||||
use async_channel::unbounded;
|
||||
use codex_protocol::protocol::McpListToolsResponseEvent;
|
||||
@@ -29,7 +31,8 @@ pub async fn collect_mcp_snapshot(config: &Config) -> McpListToolsResponseEvent
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut mcp_connection_manager = McpConnectionManager::default();
|
||||
let mcp_oauth_refresh_clock = Arc::new(AtomicI64::new(0));
|
||||
let mut mcp_connection_manager = McpConnectionManager::new(mcp_oauth_refresh_clock);
|
||||
let (tx_event, rx_event) = unbounded();
|
||||
drop(rx_event);
|
||||
let cancel_token = CancellationToken::new();
|
||||
|
||||
@@ -12,6 +12,8 @@ use std::env;
|
||||
use std::ffi::OsString;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::mcp::auth::McpAuthStatusEntry;
|
||||
@@ -260,13 +262,70 @@ pub struct SandboxState {
|
||||
}
|
||||
|
||||
/// A thin wrapper around a set of running [`RmcpClient`] instances.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct McpConnectionManager {
|
||||
clients: HashMap<String, AsyncManagedClient>,
|
||||
elicitation_requests: ElicitationRequestManager,
|
||||
mcp_oauth_refresh_clock: Arc<AtomicI64>,
|
||||
last_refresh_seen: AtomicI64,
|
||||
config_snapshot: HashMap<String, McpServerConfig>,
|
||||
store_mode_snapshot: Option<OAuthCredentialsStoreMode>,
|
||||
auth_entries_snapshot: HashMap<String, McpAuthStatusEntry>,
|
||||
}
|
||||
|
||||
impl McpConnectionManager {
|
||||
pub(crate) fn new(mcp_oauth_refresh_clock: Arc<AtomicI64>) -> Self {
|
||||
Self {
|
||||
clients: HashMap::new(),
|
||||
elicitation_requests: ElicitationRequestManager::default(),
|
||||
mcp_oauth_refresh_clock,
|
||||
last_refresh_seen: AtomicI64::new(0),
|
||||
config_snapshot: HashMap::new(),
|
||||
store_mode_snapshot: None,
|
||||
auth_entries_snapshot: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_snapshots(
|
||||
&mut self,
|
||||
mcp_servers: &HashMap<String, McpServerConfig>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
auth_entries: &HashMap<String, McpAuthStatusEntry>,
|
||||
) {
|
||||
self.config_snapshot = mcp_servers.clone();
|
||||
self.store_mode_snapshot = Some(store_mode);
|
||||
self.auth_entries_snapshot = auth_entries.clone();
|
||||
let now = self.mcp_oauth_refresh_clock.load(Ordering::SeqCst);
|
||||
self.last_refresh_seen.store(now, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
pub(crate) fn last_refresh_seen(&self) -> i64 {
|
||||
self.last_refresh_seen.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub(crate) async fn refresh_if_needed(
|
||||
&mut self,
|
||||
config: &HashMap<String, McpServerConfig>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
auth_entries: HashMap<String, McpAuthStatusEntry>,
|
||||
tx_event: Sender<Event>,
|
||||
cancel_token: CancellationToken,
|
||||
) {
|
||||
let current = self.mcp_oauth_refresh_clock.load(Ordering::SeqCst);
|
||||
if current <= self.last_refresh_seen() {
|
||||
return;
|
||||
}
|
||||
|
||||
self.initialize(
|
||||
config.clone(),
|
||||
store_mode,
|
||||
auth_entries,
|
||||
tx_event,
|
||||
cancel_token,
|
||||
)
|
||||
.await;
|
||||
self.last_refresh_seen.store(current, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
pub async fn initialize(
|
||||
&mut self,
|
||||
mcp_servers: HashMap<String, McpServerConfig>,
|
||||
@@ -281,7 +340,9 @@ impl McpConnectionManager {
|
||||
let mut clients = HashMap::new();
|
||||
let mut join_set = JoinSet::new();
|
||||
let elicitation_requests = ElicitationRequestManager::default();
|
||||
for (server_name, cfg) in mcp_servers.into_iter().filter(|(_, cfg)| cfg.enabled) {
|
||||
for (server_name, cfg) in mcp_servers.iter().filter(|(_, cfg)| cfg.enabled) {
|
||||
let server_name = server_name.to_string();
|
||||
let cfg = cfg.clone();
|
||||
let cancel_token = cancel_token.child_token();
|
||||
let _ = emit_update(
|
||||
&tx_event,
|
||||
@@ -333,6 +394,7 @@ impl McpConnectionManager {
|
||||
}
|
||||
self.clients = clients;
|
||||
self.elicitation_requests = elicitation_requests.clone();
|
||||
self.update_snapshots(&mcp_servers, store_mode, &auth_entries);
|
||||
tokio::spawn(async move {
|
||||
let outcomes = join_set.join_all().await;
|
||||
let mut summary = McpStartupCompleteEvent::default();
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::tools::sandboxing::ApprovalStore;
|
||||
use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::user_notification::UserNotifier;
|
||||
use codex_otel::otel_event_manager::OtelEventManager;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -15,6 +16,7 @@ use tokio_util::sync::CancellationToken;
|
||||
pub(crate) struct SessionServices {
|
||||
pub(crate) mcp_connection_manager: Arc<RwLock<McpConnectionManager>>,
|
||||
pub(crate) mcp_startup_cancellation_token: CancellationToken,
|
||||
pub(crate) mcp_oauth_refresh_clock: Arc<AtomicI64>,
|
||||
pub(crate) unified_exec_manager: UnifiedExecSessionManager,
|
||||
pub(crate) notifier: UserNotifier,
|
||||
pub(crate) rollout: Mutex<Option<RolloutRecorder>>,
|
||||
|
||||
@@ -16,7 +16,9 @@ pub use oauth::WrappedOAuthTokenResponse;
|
||||
pub use oauth::delete_oauth_tokens;
|
||||
pub(crate) use oauth::load_oauth_tokens;
|
||||
pub use oauth::save_oauth_tokens;
|
||||
pub use perform_oauth_login::OauthLoginHandle;
|
||||
pub use perform_oauth_login::perform_oauth_login;
|
||||
pub use perform_oauth_login::perform_oauth_login_return_url;
|
||||
pub use rmcp::model::ElicitationAction;
|
||||
pub use rmcp_client::Elicitation;
|
||||
pub use rmcp_client::ElicitationResponse;
|
||||
|
||||
@@ -22,6 +22,11 @@ use crate::save_oauth_tokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
|
||||
struct OauthHeaders {
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
struct CallbackServerGuard {
|
||||
server: Arc<Server>,
|
||||
}
|
||||
@@ -40,70 +45,52 @@ pub async fn perform_oauth_login(
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
scopes: &[String],
|
||||
) -> Result<()> {
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
|
||||
let guard = CallbackServerGuard {
|
||||
server: Arc::clone(&server),
|
||||
let headers = OauthHeaders {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
};
|
||||
OauthLoginFlow::new(
|
||||
server_name,
|
||||
server_url,
|
||||
store_mode,
|
||||
headers,
|
||||
scopes,
|
||||
true,
|
||||
None,
|
||||
)
|
||||
.await?
|
||||
.finish()
|
||||
.await
|
||||
}
|
||||
|
||||
let redirect_uri = match server.server_addr() {
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
|
||||
format!("http://{}:{}/callback", addr.ip(), addr.port())
|
||||
}
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
|
||||
format!("http://[{}]:{}/callback", addr.ip(), addr.port())
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
_ => return Err(anyhow!("unable to determine callback address")),
|
||||
pub async fn perform_oauth_login_return_url(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
scopes: &[String],
|
||||
timeout_secs: Option<i64>,
|
||||
) -> Result<OauthLoginHandle> {
|
||||
let headers = OauthHeaders {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
};
|
||||
let flow = OauthLoginFlow::new(
|
||||
server_name,
|
||||
server_url,
|
||||
store_mode,
|
||||
headers,
|
||||
scopes,
|
||||
false,
|
||||
timeout_secs,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
spawn_callback_server(server, tx);
|
||||
let authorization_url = flow.authorization_url();
|
||||
let completion = flow.spawn();
|
||||
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
|
||||
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
|
||||
oauth_state
|
||||
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
|
||||
.await?;
|
||||
let auth_url = oauth_state.get_authorization_url().await?;
|
||||
|
||||
println!("Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n");
|
||||
|
||||
if webbrowser::open(&auth_url).is_err() {
|
||||
println!("(Browser launch failed; please copy the URL above manually.)");
|
||||
}
|
||||
|
||||
let (code, csrf_state) = timeout(Duration::from_secs(300), rx)
|
||||
.await
|
||||
.context("timed out waiting for OAuth callback")?
|
||||
.context("OAuth callback was cancelled")?;
|
||||
|
||||
oauth_state
|
||||
.handle_callback(&code, &csrf_state)
|
||||
.await
|
||||
.context("failed to handle OAuth callback")?;
|
||||
|
||||
let (client_id, credentials_opt) = oauth_state
|
||||
.get_credentials()
|
||||
.await
|
||||
.context("failed to retrieve OAuth credentials")?;
|
||||
let credentials =
|
||||
credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
|
||||
|
||||
let expires_at = compute_expires_at_millis(&credentials);
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: server_name.to_string(),
|
||||
url: server_url.to_string(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
expires_at,
|
||||
};
|
||||
save_oauth_tokens(server_name, &stored, store_mode)?;
|
||||
|
||||
drop(guard);
|
||||
Ok(())
|
||||
Ok(OauthLoginHandle::new(authorization_url, completion))
|
||||
}
|
||||
|
||||
fn spawn_callback_server(server: Arc<Server>, tx: oneshot::Sender<(String, String)>) {
|
||||
@@ -160,3 +147,181 @@ fn parse_oauth_callback(path: &str) -> Option<OauthCallbackResult> {
|
||||
state: state?,
|
||||
})
|
||||
}
|
||||
|
||||
pub struct OauthLoginHandle {
|
||||
authorization_url: String,
|
||||
completion: oneshot::Receiver<Result<()>>,
|
||||
}
|
||||
|
||||
impl OauthLoginHandle {
|
||||
fn new(authorization_url: String, completion: oneshot::Receiver<Result<()>>) -> Self {
|
||||
Self {
|
||||
authorization_url,
|
||||
completion,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn authorization_url(&self) -> &str {
|
||||
&self.authorization_url
|
||||
}
|
||||
|
||||
pub fn into_parts(self) -> (String, oneshot::Receiver<Result<()>>) {
|
||||
(self.authorization_url, self.completion)
|
||||
}
|
||||
|
||||
pub async fn wait(self) -> Result<()> {
|
||||
self.completion
|
||||
.await
|
||||
.map_err(|err| anyhow!("OAuth login task was cancelled: {err}"))?
|
||||
}
|
||||
}
|
||||
|
||||
struct OauthLoginFlow {
|
||||
auth_url: String,
|
||||
oauth_state: OAuthState,
|
||||
rx: oneshot::Receiver<(String, String)>,
|
||||
guard: CallbackServerGuard,
|
||||
server_name: String,
|
||||
server_url: String,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
launch_browser: bool,
|
||||
timeout: Duration,
|
||||
}
|
||||
|
||||
impl OauthLoginFlow {
|
||||
async fn new(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
headers: OauthHeaders,
|
||||
scopes: &[String],
|
||||
launch_browser: bool,
|
||||
timeout_secs: Option<i64>,
|
||||
) -> Result<Self> {
|
||||
const DEFAULT_OAUTH_TIMEOUT_SECS: i64 = 300;
|
||||
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
|
||||
let guard = CallbackServerGuard {
|
||||
server: Arc::clone(&server),
|
||||
};
|
||||
|
||||
let redirect_uri = match server.server_addr() {
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
|
||||
let ip = addr.ip();
|
||||
let port = addr.port();
|
||||
format!("http://{ip}:{port}/callback")
|
||||
}
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
|
||||
let ip = addr.ip();
|
||||
let port = addr.port();
|
||||
format!("http://[{ip}]:{port}/callback")
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
_ => return Err(anyhow!("unable to determine callback address")),
|
||||
};
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
spawn_callback_server(server, tx);
|
||||
|
||||
let OauthHeaders {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} = headers;
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
|
||||
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
|
||||
oauth_state
|
||||
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
|
||||
.await?;
|
||||
let auth_url = oauth_state.get_authorization_url().await?;
|
||||
let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1);
|
||||
let timeout = Duration::from_secs(timeout_secs as u64);
|
||||
|
||||
Ok(Self {
|
||||
auth_url,
|
||||
oauth_state,
|
||||
rx,
|
||||
guard,
|
||||
server_name: server_name.to_string(),
|
||||
server_url: server_url.to_string(),
|
||||
store_mode,
|
||||
launch_browser,
|
||||
timeout,
|
||||
})
|
||||
}
|
||||
|
||||
fn authorization_url(&self) -> String {
|
||||
self.auth_url.clone()
|
||||
}
|
||||
|
||||
async fn finish(mut self) -> Result<()> {
|
||||
if self.launch_browser {
|
||||
let server_name = &self.server_name;
|
||||
let auth_url = &self.auth_url;
|
||||
println!(
|
||||
"Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n"
|
||||
);
|
||||
|
||||
if webbrowser::open(auth_url).is_err() {
|
||||
println!("(Browser launch failed; please copy the URL above manually.)");
|
||||
}
|
||||
}
|
||||
|
||||
let result = async {
|
||||
let (code, csrf_state) = timeout(self.timeout, &mut self.rx)
|
||||
.await
|
||||
.context("timed out waiting for OAuth callback")?
|
||||
.context("OAuth callback was cancelled")?;
|
||||
|
||||
self.oauth_state
|
||||
.handle_callback(&code, &csrf_state)
|
||||
.await
|
||||
.context("failed to handle OAuth callback")?;
|
||||
|
||||
let (client_id, credentials_opt) = self
|
||||
.oauth_state
|
||||
.get_credentials()
|
||||
.await
|
||||
.context("failed to retrieve OAuth credentials")?;
|
||||
let credentials = credentials_opt
|
||||
.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
|
||||
|
||||
let expires_at = compute_expires_at_millis(&credentials);
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: self.server_name.clone(),
|
||||
url: self.server_url.clone(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
expires_at,
|
||||
};
|
||||
save_oauth_tokens(&self.server_name, &stored, self.store_mode)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
drop(self.guard);
|
||||
result
|
||||
}
|
||||
|
||||
fn spawn(self) -> oneshot::Receiver<Result<()>> {
|
||||
let server_name_for_logging = self.server_name.clone();
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let result = self.finish().await;
|
||||
|
||||
if let Err(err) = &result {
|
||||
eprintln!(
|
||||
"Failed to complete OAuth login for '{server_name_for_logging}': {err:#}"
|
||||
);
|
||||
}
|
||||
|
||||
let _ = tx.send(result);
|
||||
});
|
||||
|
||||
rx
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user