Added timeout, clean up auth server and emit event on failure and timeout

This commit is contained in:
shijie-openai
2025-12-08 14:43:34 -08:00
parent 11783eaeef
commit 3d57b24de0
5 changed files with 98 additions and 10 deletions

View File

@@ -529,6 +529,7 @@ server_notification_definitions! {
CommandExecutionOutputDelta => "item/commandExecution/outputDelta" (v2::CommandExecutionOutputDeltaNotification),
FileChangeOutputDelta => "item/fileChange/outputDelta" (v2::FileChangeOutputDeltaNotification),
McpToolCallProgress => "item/mcpToolCall/progress" (v2::McpToolCallProgressNotification),
McpServerOauthLoginCompleted => "mcpServer/oauthLogin/completed" (v2::McpServerOauthLoginCompletedNotification),
AccountUpdated => "account/updated" (v2::AccountUpdatedNotification),
AccountRateLimitsUpdated => "account/rateLimits/updated" (v2::AccountRateLimitsUpdatedNotification),
ReasoningSummaryTextDelta => "item/reasoning/summaryTextDelta" (v2::ReasoningSummaryTextDeltaNotification),

View File

@@ -696,6 +696,9 @@ pub struct McpServerOauthLoginParams {
#[serde(default, skip_serializing_if = "Option::is_none")]
#[ts(optional)]
pub scopes: Option<Vec<String>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[ts(optional)]
pub timeout_secs: Option<i64>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
@@ -1484,6 +1487,17 @@ pub struct McpToolCallProgressNotification {
pub message: String,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct McpServerOauthLoginCompletedNotification {
pub name: String,
pub success: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[ts(optional)]
pub error: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]

View File

@@ -55,6 +55,7 @@ use codex_app_server_protocol::LoginChatGptResponse;
use codex_app_server_protocol::LogoutAccountResponse;
use codex_app_server_protocol::LogoutChatGptResponse;
use codex_app_server_protocol::McpServer;
use codex_app_server_protocol::McpServerOauthLoginCompletedNotification;
use codex_app_server_protocol::McpServerOauthLoginParams;
use codex_app_server_protocol::McpServerOauthLoginResponse;
use codex_app_server_protocol::ModelListParams;
@@ -1938,7 +1939,11 @@ impl CodexMessageProcessor {
return;
}
let McpServerOauthLoginParams { name, scopes } = params;
let McpServerOauthLoginParams {
name,
scopes,
timeout_secs,
} = params;
let Some(server) = self.config.mcp_servers.get(&name) else {
let error = JSONRPCErrorError {
@@ -1976,10 +1981,31 @@ impl CodexMessageProcessor {
http_headers,
env_http_headers,
scopes.as_deref().unwrap_or_default(),
timeout_secs,
)
.await
{
Ok(authorization_url) => {
Ok(handle) => {
let authorization_url = handle.authorization_url().to_string();
let notification_name = name.clone();
let outgoing = Arc::clone(&self.outgoing);
tokio::spawn(async move {
let (success, error) = match handle.wait().await {
Ok(()) => (true, None),
Err(err) => (false, Some(err.to_string())),
};
let notification = ServerNotification::McpServerOauthLoginCompleted(
McpServerOauthLoginCompletedNotification {
name: notification_name,
success,
error,
},
);
outgoing.send_server_notification(notification).await;
});
let response = McpServerOauthLoginResponse { authorization_url };
self.outgoing.send_response(request_id, response).await;
}

View File

@@ -16,6 +16,7 @@ pub use oauth::WrappedOAuthTokenResponse;
pub use oauth::delete_oauth_tokens;
pub(crate) use oauth::load_oauth_tokens;
pub use oauth::save_oauth_tokens;
pub use perform_oauth_login::OauthLoginHandle;
pub use perform_oauth_login::perform_oauth_login;
pub use perform_oauth_login::perform_oauth_login_return_url;
pub use rmcp::model::ElicitationAction;

View File

@@ -48,6 +48,7 @@ pub async fn perform_oauth_login(
env_http_headers,
scopes,
true,
None,
)
.await?
.finish()
@@ -61,7 +62,8 @@ pub async fn perform_oauth_login_return_url(
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
scopes: &[String],
) -> Result<String> {
timeout_secs: Option<i64>,
) -> Result<OauthLoginHandle> {
let flow = OauthLoginFlow::new(
server_name,
server_url,
@@ -70,14 +72,14 @@ pub async fn perform_oauth_login_return_url(
env_http_headers,
scopes,
false,
timeout_secs,
)
.await?;
let auth_url = flow.authorization_url();
let server_name_for_logging = flow.server_name.clone();
flow.spawn(server_name_for_logging);
let authorization_url = flow.authorization_url();
let completion = flow.spawn();
Ok(auth_url)
Ok(OauthLoginHandle::new(authorization_url, completion))
}
fn spawn_callback_server(server: Arc<Server>, tx: oneshot::Sender<(String, String)>) {
@@ -135,6 +137,34 @@ fn parse_oauth_callback(path: &str) -> Option<OauthCallbackResult> {
})
}
pub struct OauthLoginHandle {
authorization_url: String,
completion: oneshot::Receiver<Result<()>>,
}
impl OauthLoginHandle {
fn new(authorization_url: String, completion: oneshot::Receiver<Result<()>>) -> Self {
Self {
authorization_url,
completion,
}
}
pub fn authorization_url(&self) -> &str {
&self.authorization_url
}
pub fn into_parts(self) -> (String, oneshot::Receiver<Result<()>>) {
(self.authorization_url, self.completion)
}
pub async fn wait(self) -> Result<()> {
self.completion
.await
.map_err(|err| anyhow!("OAuth login task was cancelled: {err}"))?
}
}
struct OauthLoginFlow {
auth_url: String,
oauth_state: OAuthState,
@@ -144,6 +174,7 @@ struct OauthLoginFlow {
server_url: String,
store_mode: OAuthCredentialsStoreMode,
launch_browser: bool,
timeout: Duration,
}
impl OauthLoginFlow {
@@ -155,7 +186,10 @@ impl OauthLoginFlow {
env_http_headers: Option<HashMap<String, String>>,
scopes: &[String],
launch_browser: bool,
timeout_secs: Option<i64>,
) -> Result<Self> {
const DEFAULT_OAUTH_TIMEOUT_SECS: i64 = 300;
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
let guard = CallbackServerGuard {
server: Arc::clone(&server),
@@ -188,6 +222,8 @@ impl OauthLoginFlow {
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
.await?;
let auth_url = oauth_state.get_authorization_url().await?;
let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1);
let timeout = Duration::from_secs(timeout_secs as u64);
Ok(Self {
auth_url,
@@ -198,6 +234,7 @@ impl OauthLoginFlow {
server_url: server_url.to_string(),
store_mode,
launch_browser,
timeout,
})
}
@@ -219,7 +256,7 @@ impl OauthLoginFlow {
}
let result = async {
let (code, csrf_state) = timeout(Duration::from_secs(300), &mut self.rx)
let (code, csrf_state) = timeout(self.timeout, &mut self.rx)
.await
.context("timed out waiting for OAuth callback")?
.context("OAuth callback was cancelled")?;
@@ -255,13 +292,22 @@ impl OauthLoginFlow {
result
}
fn spawn(self, server_name_for_logging: String) {
fn spawn(self) -> oneshot::Receiver<Result<()>> {
let server_name_for_logging = self.server_name.clone();
let (tx, rx) = oneshot::channel();
tokio::spawn(async move {
if let Err(err) = self.finish().await {
let result = self.finish().await;
if let Err(err) = &result {
eprintln!(
"Failed to complete OAuth login for '{server_name_for_logging}': {err:#}"
);
}
let _ = tx.send(result);
});
rx
}
}