Compare commits

...

1 Commits

Author SHA1 Message Date
pap
4c29f1da17 auto login on auth required for mcp calls 2026-04-16 19:02:05 +01:00
9 changed files with 315 additions and 36 deletions

View File

@@ -349,6 +349,8 @@ pub async fn collect_mcp_snapshot_with_detail(
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
&mcp_servers,
config.mcp_oauth_credentials_store_mode,
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.clone(),
auth_status_entries.clone(),
&config.approval_policy,
submit_id,
@@ -415,6 +417,8 @@ pub async fn collect_mcp_server_status_snapshot_with_detail(
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
&mcp_servers,
config.mcp_oauth_credentials_store_mode,
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.clone(),
auth_status_entries.clone(),
&config.approval_policy,
submit_id,

View File

@@ -486,6 +486,8 @@ impl AsyncManagedClient {
server_name: String,
config: McpServerConfig,
store_mode: OAuthCredentialsStoreMode,
mcp_oauth_callback_port: Option<u16>,
mcp_oauth_callback_url: Option<String>,
cancel_token: CancellationToken,
tx_event: Sender<Event>,
elicitation_requests: ElicitationRequestManager,
@@ -507,16 +509,24 @@ impl AsyncManagedClient {
return Err(error.into());
}
let client =
Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?);
let startup_timeout = config.startup_timeout_sec.or(Some(DEFAULT_STARTUP_TIMEOUT));
let tool_timeout = config.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);
let client = Arc::new(
make_rmcp_client(
&server_name,
config,
store_mode,
mcp_oauth_callback_port,
mcp_oauth_callback_url,
)
.await?,
);
match start_server_task(
server_name,
client,
StartServerTaskParams {
startup_timeout: config
.startup_timeout_sec
.or(Some(DEFAULT_STARTUP_TIMEOUT)),
tool_timeout: config.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT),
startup_timeout,
tool_timeout,
tool_filter: startup_tool_filter,
tx_event,
elicitation_requests,
@@ -703,6 +713,8 @@ impl McpConnectionManager {
pub async fn new(
mcp_servers: &HashMap<String, McpServerConfig>,
store_mode: OAuthCredentialsStoreMode,
mcp_oauth_callback_port: Option<u16>,
mcp_oauth_callback_url: Option<String>,
auth_entries: HashMap<String, McpAuthStatusEntry>,
approval_policy: &Constrained<AskForApproval>,
submit_id: String,
@@ -747,6 +759,8 @@ impl McpConnectionManager {
server_name.clone(),
cfg,
store_mode,
mcp_oauth_callback_port,
mcp_oauth_callback_url.clone(),
cancel_token.clone(),
tx_event.clone(),
elicitation_requests.clone(),
@@ -1482,9 +1496,16 @@ struct StartServerTaskParams {
async fn make_rmcp_client(
server_name: &str,
transport: McpServerTransportConfig,
config: McpServerConfig,
store_mode: OAuthCredentialsStoreMode,
mcp_oauth_callback_port: Option<u16>,
mcp_oauth_callback_url: Option<String>,
) -> Result<RmcpClient, StartupOutcomeError> {
let McpServerConfig {
transport,
oauth_resource,
..
} = config;
match transport {
McpServerTransportConfig::Stdio {
command,
@@ -1521,6 +1542,9 @@ async fn make_rmcp_client(
resolved_bearer_token,
http_headers,
env_http_headers,
oauth_resource,
mcp_oauth_callback_port,
mcp_oauth_callback_url,
store_mode,
)
.await

View File

@@ -2242,6 +2242,8 @@ impl Session {
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
&mcp_servers,
config.mcp_oauth_credentials_store_mode,
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.clone(),
auth_statuses.clone(),
&session_configuration.approval_policy,
INITIAL_SUBMIT_ID.to_owned(),
@@ -4578,6 +4580,8 @@ impl Session {
let (refreshed_manager, cancel_token) = McpConnectionManager::new(
&mcp_servers,
store_mode,
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.clone(),
auth_statuses,
&turn_context.config.permissions.approval_policy,
turn_context.sub_id.clone(),

View File

@@ -228,6 +228,8 @@ pub async fn list_accessible_connectors_from_mcp_tools_with_options_and_status(
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
&mcp_servers,
config.mcp_oauth_credentials_store_mode,
config.mcp_oauth_callback_port,
config.mcp_oauth_callback_url.clone(),
auth_status_entries,
&config.permissions.approval_policy,
INITIAL_SUBMIT_ID.to_owned(),

View File

@@ -76,6 +76,7 @@ use crate::elicitation_client_service::ElicitationClientService;
use crate::load_oauth_tokens;
use crate::oauth::OAuthPersistor;
use crate::oauth::StoredOAuthTokens;
use crate::perform_oauth_login;
use crate::program_resolver;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
@@ -405,6 +406,9 @@ enum TransportRecipe {
bearer_token: Option<String>,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
oauth_resource: Option<String>,
callback_port: Option<u16>,
callback_url: Option<String>,
store_mode: OAuthCredentialsStoreMode,
},
}
@@ -604,6 +608,9 @@ impl RmcpClient {
bearer_token: Option<String>,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
oauth_resource: Option<String>,
callback_port: Option<u16>,
callback_url: Option<String>,
store_mode: OAuthCredentialsStoreMode,
) -> Result<Self> {
let transport_recipe = TransportRecipe::StreamableHttp {
@@ -612,6 +619,9 @@ impl RmcpClient {
bearer_token,
http_headers,
env_http_headers,
oauth_resource,
callback_port,
callback_url,
store_mode,
};
let transport = Self::create_pending_transport(&transport_recipe).await?;
@@ -827,35 +837,72 @@ impl RmcpClient {
arguments,
task: None,
};
let result = self
.run_service_operation("tools/call", timeout, move |service| {
let rmcp_params = rmcp_params.clone();
let meta = meta.clone();
async move {
let result = service
.peer()
.send_request_with_option(
ClientRequest::CallToolRequest(rmcp::model::CallToolRequest {
method: Default::default(),
params: rmcp_params,
extensions: Default::default(),
}),
rmcp::service::PeerRequestOptions {
timeout: None,
meta,
},
)
.await?
.await_response()
.await?;
match result {
ServerResult::CallToolResult(result) => Ok(result),
_ => Err(rmcp::service::ServiceError::UnexpectedResponse),
}
let operation = |service: Arc<RunningService<RoleClient, ElicitationClientService>>| {
let rmcp_params = rmcp_params.clone();
let meta = meta.clone();
async move {
let result = service
.peer()
.send_request_with_option(
ClientRequest::CallToolRequest(rmcp::model::CallToolRequest {
method: Default::default(),
params: rmcp_params,
extensions: Default::default(),
}),
rmcp::service::PeerRequestOptions {
timeout: None,
meta,
},
)
.await?
.await_response()
.await?;
match result {
ServerResult::CallToolResult(result) => Ok(result),
_ => Err(rmcp::service::ServiceError::UnexpectedResponse),
}
.boxed()
})
.await?;
}
.boxed()
};
let service = self.service().await?;
let result = match Self::run_service_operation_once(
Arc::clone(&service),
"tools/call",
timeout,
self.elicitation_pause_state.clone(),
&operation,
)
.await
{
Ok(result) => result,
Err(error) if Self::is_session_expired_404(&error) => {
self.reinitialize_after_session_expiry(&service).await?;
let recovered_service = self.service().await?;
Self::run_service_operation_once(
recovered_service,
"tools/call",
timeout,
self.elicitation_pause_state.clone(),
&operation,
)
.await
.map_err(anyhow::Error::from)?
}
Err(error) if Self::is_auth_required_error(&error) => {
self.reinitialize_after_oauth_login(&service).await?;
let recovered_service = self.service().await?;
Self::run_service_operation_once(
recovered_service,
"tools/call",
timeout,
self.elicitation_pause_state.clone(),
&operation,
)
.await
.map_err(anyhow::Error::from)?
}
Err(error) => return Err(error.into()),
};
self.persist_oauth_tokens().await;
Ok(result)
}
@@ -1016,6 +1063,7 @@ impl RmcpClient {
http_headers,
env_http_headers,
store_mode,
..
} => {
let default_headers =
build_default_headers(http_headers.clone(), env_http_headers.clone())?;
@@ -1223,6 +1271,19 @@ impl RmcpClient {
})
}
fn is_auth_required_error(error: &ClientOperationError) -> bool {
let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) =
error
else {
return false;
};
error
.error
.downcast_ref::<StreamableHttpError<StreamableHttpResponseClientError>>()
.is_some_and(|error| matches!(error, StreamableHttpError::AuthRequired(_)))
}
async fn reinitialize_after_session_expiry(
&self,
failed_service: &Arc<RunningService<RoleClient, ElicitationClientService>>,
@@ -1273,6 +1334,106 @@ impl RmcpClient {
Ok(())
}
async fn reinitialize_after_oauth_login(
&self,
failed_service: &Arc<RunningService<RoleClient, ElicitationClientService>>,
) -> Result<()> {
let _recovery_guard = self.session_recovery_lock.lock().await;
{
let guard = self.state.lock().await;
match &*guard {
ClientState::Ready { service, .. } if !Arc::ptr_eq(service, failed_service) => {
return Ok(());
}
ClientState::Ready { .. } => {}
ClientState::Connecting { .. } => {
return Err(anyhow!("MCP client not initialized"));
}
}
}
let TransportRecipe::StreamableHttp {
server_name,
url,
bearer_token,
http_headers,
env_http_headers,
oauth_resource,
callback_port,
callback_url,
store_mode,
} = &self.transport_recipe
else {
return Err(anyhow!(
"MCP authentication recovery is only supported for streamable HTTP transports"
));
};
if bearer_token.is_some() {
return Err(anyhow!(
"MCP server `{server_name}` returned auth required while using a fixed bearer token"
));
}
let scopes = load_oauth_tokens(server_name, url, *store_mode)?
.and_then(|tokens| tokens.token_response.0.scopes().cloned())
.map(|scopes| {
scopes
.into_iter()
.map(|scope| scope.to_string())
.collect::<Vec<_>>()
})
.unwrap_or_default();
{
let _pause = self.elicitation_pause_state.enter();
perform_oauth_login(
server_name,
url,
*store_mode,
http_headers.clone(),
env_http_headers.clone(),
&scopes,
oauth_resource.as_deref(),
*callback_port,
callback_url.as_deref(),
)
.await?;
}
let initialize_context = self
.initialize_context
.lock()
.await
.clone()
.ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?;
let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?;
let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport(
pending_transport,
initialize_context.client_service,
initialize_context.timeout,
)
.await?;
{
let mut guard = self.state.lock().await;
*guard = ClientState::Ready {
_process_group_guard: process_group_guard,
service,
oauth: oauth_persistor.clone(),
};
}
if let Some(runtime) = oauth_persistor
&& let Err(error) = runtime.persist_if_needed().await
{
warn!("failed to persist OAuth tokens after authentication recovery: {error}");
}
Ok(())
}
}
async fn create_oauth_transport_and_runtime(

View File

@@ -76,6 +76,9 @@ async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
Some("test-bearer".to_string()),
/*http_headers*/ None,
/*env_http_headers*/ None,
/*oauth_resource*/ None,
/*callback_port*/ None,
/*callback_url*/ None,
OAuthCredentialsStoreMode::File,
)
.await?;

View File

@@ -165,7 +165,7 @@ impl App {
self.chat_widget.dismiss_app_server_request(&request);
}
}
ServerNotification::McpServerStatusUpdated(_) => {
ServerNotification::McpServerStatusUpdated(_notification) => {
self.refresh_mcp_startup_expected_servers_from_config();
}
ServerNotification::AccountRateLimitsUpdated(notification) => {

View File

@@ -4743,6 +4743,7 @@ impl ChatWidget {
pub(crate) fn handle_mcp_begin_now(&mut self, ev: McpToolCallBeginEvent) {
self.flush_answer_stream_with_separator();
self.flush_active_cell();
self.bottom_pane.hide_status_indicator();
self.active_cell = Some(Box::new(history_cell::new_active_mcp_tool_call(
ev.call_id,
ev.invocation,
@@ -4785,6 +4786,17 @@ impl ChatWidget {
if let Some(extra) = extra_cell {
self.add_boxed_history(extra);
}
if self.bottom_pane.is_task_running() {
self.bottom_pane.ensure_status_indicator();
self.bottom_pane
.set_interrupt_hint_visible(/*visible*/ true);
self.set_status(
self.current_status.header.clone(),
self.current_status.details.clone(),
StatusDetailsCapitalization::Preserve,
self.current_status.details_max_lines,
);
}
// Mark that actual work was done (MCP tool call)
self.had_work_activity = true;
}

View File

@@ -1,4 +1,6 @@
use super::*;
use codex_protocol::mcp::CallToolResult;
use codex_protocol::protocol::McpInvocation;
use pretty_assertions::assert_eq;
#[tokio::test]
@@ -147,6 +149,73 @@ async fn live_app_server_turn_completed_clears_working_status_after_answer_item(
assert!(chat.bottom_pane.status_widget().is_none());
}
#[tokio::test]
async fn mcp_tool_call_hides_working_status_until_completion() {
let (mut chat, _rx, _op_rx) = make_chatwidget_manual(/*model_override*/ None).await;
let invocation = McpInvocation {
server: "pap".to_string(),
tool: "get-info".to_string(),
arguments: Some(serde_json::json!({})),
};
chat.handle_server_notification(
ServerNotification::TurnStarted(TurnStartedNotification {
thread_id: "thread-1".to_string(),
turn: AppServerTurn {
id: "turn-1".to_string(),
items: Vec::new(),
status: AppServerTurnStatus::InProgress,
error: None,
started_at: Some(0),
completed_at: None,
duration_ms: None,
},
}),
/*replay_kind*/ None,
);
assert!(chat.bottom_pane.status_widget().is_some());
chat.handle_codex_event(Event {
id: "mcp-begin".into(),
msg: EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
call_id: "call-1".to_string(),
invocation: invocation.clone(),
mcp_app_resource_uri: None,
}),
});
assert!(
chat.bottom_pane.status_widget().is_none(),
"expected working status to hide while MCP call is active"
);
chat.handle_codex_event(Event {
id: "mcp-end".into(),
msg: EventMsg::McpToolCallEnd(McpToolCallEndEvent {
call_id: "call-1".to_string(),
invocation,
mcp_app_resource_uri: None,
duration: std::time::Duration::from_secs(1),
result: Ok(CallToolResult {
content: vec![serde_json::json!({
"type": "text",
"text": "aaa",
})],
structured_content: None,
is_error: None,
meta: None,
}),
}),
});
let status = chat
.bottom_pane
.status_widget()
.expect("status indicator should be restored after MCP call completion");
assert_eq!(status.header(), "Working");
}
#[tokio::test]
async fn live_app_server_turn_started_sets_feedback_turn_id() {
let (mut chat, mut rx, _op_rx) = make_chatwidget_manual(/*model_override*/ None).await;