Compare commits

...

2 Commits

Author SHA1 Message Date
jif-oai
f42c63a511 fix: treat missing taskkill process as terminated
Avoid retrying Windows MCP shutdown when taskkill reports the process has already exited. This keeps stale PID retries from targeting a reused process tree while preserving surfaced errors for other taskkill failures.

Co-authored-by: Codex <noreply@openai.com>
2026-04-30 17:33:19 +01:00
jif-oai
64faf26d8c fix: mcp leak on list 2026-04-30 16:38:16 +01:00
4 changed files with 303 additions and 54 deletions

View File

@@ -3604,7 +3604,7 @@ impl CodexMessageProcessor {
let (thread_id, thread) = self.load_thread(&thread_id).await?;
let request = request_id.clone();
let request = request_id;
let rollback_already_in_progress = {
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
@@ -5599,18 +5599,16 @@ impl CodexMessageProcessor {
),
};
tokio::spawn(async move {
Self::list_mcp_server_status_task(
outgoing,
request,
params,
config,
mcp_config,
auth,
runtime_environment,
)
.await;
});
Self::list_mcp_server_status_task(
outgoing,
request,
params,
config,
mcp_config,
auth,
runtime_environment,
)
.await;
}
async fn list_mcp_server_status_task(
@@ -5743,10 +5741,8 @@ impl CodexMessageProcessor {
}
};
tokio::spawn(async move {
let result = thread.read_mcp_resource(&server, &uri).await;
Self::send_mcp_resource_read_response(outgoing, request_id, result).await;
});
let result = thread.read_mcp_resource(&server, &uri).await;
Self::send_mcp_resource_read_response(outgoing, request_id, result).await;
return;
}
@@ -5771,21 +5767,19 @@ impl CodexMessageProcessor {
McpRuntimeEnvironment::new(environment, config.cwd.to_path_buf())
};
tokio::spawn(async move {
let result = match read_mcp_resource_without_thread(
&mcp_config,
auth.as_ref(),
runtime_environment,
&server,
&uri,
)
.await
{
Ok(result) => serde_json::to_value(result).map_err(anyhow::Error::from),
Err(error) => Err(error),
};
Self::send_mcp_resource_read_response(outgoing, request_id, result).await;
});
let result = match read_mcp_resource_without_thread(
&mcp_config,
auth.as_ref(),
runtime_environment,
&server,
&uri,
)
.await
{
Ok(result) => serde_json::to_value(result).map_err(anyhow::Error::from),
Err(error) => Err(error),
};
Self::send_mcp_resource_read_response(outgoing, request_id, result).await;
}
async fn send_mcp_resource_read_response(
@@ -5821,14 +5815,12 @@ impl CodexMessageProcessor {
};
let meta = with_mcp_tool_call_thread_id_meta(params.meta, &thread_id);
tokio::spawn(async move {
let result = thread
.call_mcp_tool(&params.server, &params.tool, params.arguments, meta)
.await
.map(McpServerToolCallResponse::from)
.map_err(|error| internal_error(format!("{error:#}")));
outgoing.send_result(request_id, result).await;
});
let result = thread
.call_mcp_tool(&params.server, &params.tool, params.arguments, meta)
.await
.map(McpServerToolCallResponse::from)
.map_err(|error| internal_error(format!("{error:#}")));
outgoing.send_result(request_id, result).await;
}
async fn send_optional_result<T>(

View File

@@ -2,9 +2,13 @@ use std::borrow::Cow;
use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::time::Duration;
use anyhow::Result;
use anyhow::bail;
use app_test_support::McpProcess;
use app_test_support::create_mock_responses_server_sequence_unchecked;
use app_test_support::to_response;
@@ -145,6 +149,20 @@ struct SlowInventoryServer {
tool_name: Arc<String>,
}
// Shared counters used to observe how `resources/list` calls overlap across
// concurrent status requests; all fields are atomics so the server handler
// task and the test body can update them without locks.
#[derive(Default)]
struct InventoryConcurrencyTracker {
// Number of `resources/list` calls currently inside the handler.
active_resource_calls: AtomicUsize,
// High-water mark ever observed for `active_resource_calls`.
max_resource_calls: AtomicUsize,
// Flag the test flips to let blocked `resources/list` handlers return.
release_resource_calls: AtomicBool,
// Total `resources/list` calls that have started, released or not.
started_resource_calls: AtomicUsize,
}
// MCP test server whose `resources/list` handler blocks on `tracker` until
// the test releases it (see the `ServerHandler` impl).
#[derive(Clone)]
struct BlockingInventoryServer {
// Name advertised for the single tool returned from `list_tools`.
tool_name: Arc<String>,
// Shared concurrency observations (see `InventoryConcurrencyTracker`).
tracker: Arc<InventoryConcurrencyTracker>,
}
impl ServerHandler for SlowInventoryServer {
fn get_info(&self) -> ServerInfo {
ServerInfo {
@@ -208,6 +226,74 @@ impl ServerHandler for SlowInventoryServer {
}
}
// Test double: an MCP server whose `resources/list` handler parks until the
// shared tracker's `release_resource_calls` flag is set, letting tests
// measure how many list calls were allowed to run at the same time.
impl ServerHandler for BlockingInventoryServer {
// Advertise both tool and resource capabilities so a status listing will
// invoke `list_tools` and `list_resources`.
fn get_info(&self) -> ServerInfo {
ServerInfo {
capabilities: ServerCapabilities::builder()
.enable_tools()
.enable_resources()
.build(),
..ServerInfo::default()
}
}
// Return a single read-only tool named after `self.tool_name` with an
// object input schema that allows no properties.
async fn list_tools(
&self,
_request: Option<PaginatedRequestParams>,
_context: RequestContext<rmcp::service::RoleServer>,
) -> Result<ListToolsResult, rmcp::ErrorData> {
let input_schema: JsonObject = serde_json::from_value(json!({
"type": "object",
"additionalProperties": false
}))
.map_err(|err| rmcp::ErrorData::internal_error(err.to_string(), None))?;
let mut tool = Tool::new(
Cow::Owned(self.tool_name.as_ref().clone()),
Cow::Borrowed("Look up test data."),
Arc::new(input_schema),
);
tool.annotations = Some(ToolAnnotations::new().read_only(true));
Ok(ListToolsResult {
tools: vec![tool],
next_cursor: None,
meta: None,
})
}
// Record concurrency bookkeeping, then block (polling every 10 ms) until
// the test flips `release_resource_calls`; finally return an empty list.
async fn list_resources(
&self,
_request: Option<PaginatedRequestParams>,
_context: RequestContext<rmcp::service::RoleServer>,
) -> Result<ListResourcesResult, rmcp::ErrorData> {
// `active` counts the handlers inside this section, including this one
// (fetch_add returns the previous value, hence the `+ 1`).
let active = self
.tracker
.active_resource_calls
.fetch_add(1, Ordering::AcqRel)
+ 1;
self.tracker
.started_resource_calls
.fetch_add(1, Ordering::AcqRel);
// Track the high-water mark of simultaneous calls for later assertions.
self.tracker
.max_resource_calls
.fetch_max(active, Ordering::AcqRel);
// Async sleep keeps the runtime free while waiting for the release flag.
while !self.tracker.release_resource_calls.load(Ordering::Acquire) {
tokio::time::sleep(Duration::from_millis(10)).await;
}
self.tracker
.active_resource_calls
.fetch_sub(1, Ordering::AcqRel);
Ok(ListResourcesResult {
resources: Vec::new(),
next_cursor: None,
meta: None,
})
}
}
#[tokio::test]
async fn mcp_server_status_list_tools_and_auth_only_skips_slow_inventory_calls() -> Result<()> {
let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await;
@@ -267,6 +353,87 @@ url = "{mcp_server_url}/mcp"
Ok(())
}
// Regression test: two overlapping `listMcpServerStatus` requests must run
// their MCP inventory work one at a time (peak concurrency of 1 on
// `resources/list`), and both requests must still complete once the blocked
// handler is released.
#[tokio::test]
async fn mcp_server_status_list_serializes_inventory_work() -> Result<()> {
// Mock model backend plus an MCP server whose `resources/list` blocks on
// `tracker.release_resource_calls`.
let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await;
let tracker = Arc::new(InventoryConcurrencyTracker::default());
let (mcp_server_url, mcp_server_handle) =
start_blocking_inventory_mcp_server("lookup", Arc::clone(&tracker)).await?;
// Write a config pointing at the mock backend, then append the blocking
// MCP server so status listing will contact it.
let codex_home = TempDir::new()?;
write_mock_responses_config_toml(
codex_home.path(),
&server.uri(),
&BTreeMap::new(),
/*auto_compact_limit*/ 1024,
/*requires_openai_auth*/ None,
"mock_provider",
"compact",
)?;
let config_path = codex_home.path().join("config.toml");
let mut config_toml = std::fs::read_to_string(&config_path)?;
config_toml.push_str(&format!(
r#"
[mcp_servers.some-server]
url = "{mcp_server_url}/mcp"
"#
));
std::fs::write(config_path, config_toml)?;
let mut mcp = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
// First request: wait until its inventory work has actually started (the
// handler is now parked inside `list_resources`).
let first_request_id = mcp
.send_list_mcp_server_status_request(ListMcpServerStatusParams {
cursor: None,
limit: None,
detail: None,
})
.await?;
wait_for_resource_call_count(&tracker, 1).await?;
// Second request while the first is still blocked.
let second_request_id = mcp
.send_list_mcp_server_status_request(ListMcpServerStatusParams {
cursor: None,
limit: None,
detail: None,
})
.await?;
// The second request's `resources/list` must NOT start while the first is
// blocked — the 750 ms wait is expected to time out.
assert!(
timeout(
Duration::from_millis(750),
wait_for_resource_call_count(&tracker, 2)
)
.await
.is_err()
);
assert_eq!(tracker.max_resource_calls.load(Ordering::Acquire), 1);
// Unblock the handler; both requests should now finish in order.
tracker
.release_resource_calls
.store(true, Ordering::Release);
let response = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(first_request_id)),
)
.await??;
let _: ListMcpServerStatusResponse = to_response(response)?;
let response = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(second_request_id)),
)
.await??;
let _: ListMcpServerStatusResponse = to_response(response)?;
// Even after both completed, the peak concurrency must have stayed at 1.
assert_eq!(tracker.max_resource_calls.load(Ordering::Acquire), 1);
mcp_server_handle.abort();
let _ = mcp_server_handle.await;
Ok(())
}
#[tokio::test]
async fn mcp_server_status_list_keeps_tools_for_sanitized_name_collisions() -> Result<()> {
let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await;
@@ -388,3 +555,43 @@ async fn start_slow_inventory_mcp_server(tool_name: &str) -> Result<(String, Joi
Ok((format!("http://{addr}"), handle))
}
/// Starts an HTTP MCP server backed by `BlockingInventoryServer` and returns
/// its base URL plus the task handle driving it. The served handler's
/// `resources/list` blocks until `tracker.release_resource_calls` is set.
async fn start_blocking_inventory_mcp_server(
    tool_name: &str,
    tracker: Arc<InventoryConcurrencyTracker>,
) -> Result<(String, JoinHandle<()>)> {
    // Bind an ephemeral port so concurrently running tests never collide.
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let base_url = format!("http://{}", listener.local_addr()?);
    let shared_tool_name = Arc::new(tool_name.to_string());
    // The factory closure builds one handler per incoming session, each
    // sharing the same tracker and tool name.
    let service = StreamableHttpService::new(
        move || {
            Ok(BlockingInventoryServer {
                tool_name: Arc::clone(&shared_tool_name),
                tracker: Arc::clone(&tracker),
            })
        },
        Arc::new(LocalSessionManager::default()),
        StreamableHttpServerConfig::default(),
    );
    let router = Router::new().nest_service("/mcp", service);
    // Serve in the background; callers abort the handle when done.
    let server_task = tokio::spawn(async move {
        let _ = axum::serve(listener, router).await;
    });
    Ok((base_url, server_task))
}
/// Polls until the blocking server has *started* at least `expected`
/// `resources/list` calls, failing after roughly one second.
async fn wait_for_resource_call_count(
    tracker: &InventoryConcurrencyTracker,
    expected: usize,
) -> Result<()> {
    const MAX_POLLS: usize = 100;
    const POLL_INTERVAL: Duration = Duration::from_millis(10);
    let mut attempts = 0;
    while attempts < MAX_POLLS {
        if tracker.started_resource_calls.load(Ordering::Acquire) >= expected {
            return Ok(());
        }
        tokio::time::sleep(POLL_INTERVAL).await;
        attempts += 1;
    }
    bail!("timed out waiting for {expected} resource/list call(s)");
}

View File

@@ -229,7 +229,7 @@ pub async fn read_mcp_resource(
.await;
let (tx_event, rx_event) = unbounded();
drop(rx_event);
let (manager, cancel_token) = McpConnectionManager::new(
let (mut manager, cancel_token) = McpConnectionManager::new(
&mcp_servers,
config.mcp_oauth_credentials_store_mode,
auth_statuses,
@@ -255,6 +255,7 @@ pub async fn read_mcp_resource(
)
.await;
cancel_token.cancel();
manager.shutdown().await;
result
}
@@ -294,7 +295,7 @@ pub async fn collect_mcp_server_status_snapshot_with_detail(
let (tx_event, rx_event) = unbounded();
drop(rx_event);
let (mcp_connection_manager, cancel_token) = McpConnectionManager::new(
let (mut mcp_connection_manager, cancel_token) = McpConnectionManager::new(
&mcp_servers,
config.mcp_oauth_credentials_store_mode,
auth_status_entries.clone(),
@@ -318,6 +319,7 @@ pub async fn collect_mcp_server_status_snapshot_with_detail(
.await;
cancel_token.cancel();
mcp_connection_manager.shutdown().await;
snapshot
}

View File

@@ -315,13 +315,15 @@ impl LocalProcessTerminator {
}
#[cfg(unix)]
fn terminate(&self) {
fn terminate(&self) -> io::Result<()> {
let process_group_id = self.process_group_id;
let should_escalate = match terminate_process_group(process_group_id) {
Ok(exists) => exists,
Err(error) => {
warn!("Failed to terminate MCP process group {process_group_id}: {error}");
false
return Err(io::Error::new(
error.kind(),
format!("terminating MCP process group {process_group_id}: {error}"),
));
}
};
if should_escalate {
@@ -332,20 +334,44 @@ impl LocalProcessTerminator {
}
});
}
Ok(())
}
#[cfg(windows)]
fn terminate(&self) {
let _ = std::process::Command::new("taskkill")
fn terminate(&self) -> io::Result<()> {
let output = std::process::Command::new("taskkill")
.arg("/PID")
.arg(self.pid.to_string())
.arg("/T")
.arg("/F")
.status();
.output()?;
if output.status.success() || taskkill_output_reports_missing_process(&output) {
return Ok(());
}
Err(io::Error::other(format!(
"taskkill exited with status {}",
output.status
)))
}
#[cfg(not(any(unix, windows)))]
fn terminate(&self) {}
fn terminate(&self) -> io::Result<()> {
Ok(())
}
}
#[cfg(windows)]
/// True when either taskkill output stream reports that the target process no
/// longer exists, i.e. it had already terminated before we tried to kill it.
fn taskkill_output_reports_missing_process(output: &std::process::Output) -> bool {
    // Check stdout first, then stderr, mirroring short-circuit `||`.
    [&output.stdout, &output.stderr]
        .into_iter()
        .any(|stream| taskkill_text_reports_missing_process(stream))
}
#[cfg(any(windows, test))]
/// Returns true when a taskkill output stream contains the English
/// "not found" marker, meaning the target process had already exited.
/// NOTE(review): taskkill localizes its messages, so this marker may be
/// absent on non-English Windows — confirm whether that matters here.
fn taskkill_text_reports_missing_process(bytes: &[u8]) -> bool {
    let text = String::from_utf8_lossy(bytes).to_ascii_lowercase();
    text.contains("not found")
}
impl StdioServerProcessHandle {
@@ -375,10 +401,13 @@ impl StdioServerProcessHandle {
}
match &self.inner.kind {
StdioServerProcessKind::Local(Some(terminator)) => {
terminator.terminate();
Ok(())
}
StdioServerProcessKind::Local(Some(terminator)) => match terminator.terminate() {
Ok(()) => Ok(()),
Err(error) => {
self.inner.terminated.store(false, Ordering::Release);
Err(error)
}
},
StdioServerProcessKind::Local(None) => Ok(()),
StdioServerProcessKind::Executor(process) => match process.terminate().await {
Ok(()) => Ok(()),
@@ -399,7 +428,12 @@ impl Drop for StdioServerProcessHandleInner {
match &self.kind {
StdioServerProcessKind::Local(Some(terminator)) => {
terminator.terminate();
if let Err(error) = terminator.terminate() {
warn!(
"Failed to terminate MCP process group on drop ({}): {error}",
self.program_name
);
}
}
StdioServerProcessKind::Local(None) => {}
StdioServerProcessKind::Executor(process) => {
@@ -651,4 +685,18 @@ mod tests {
);
assert!(!env.contains_key("UNREQUESTED_SECRET"));
}
// The canonical English taskkill "process not found" message must be
// classified as an already-terminated process.
#[test]
fn taskkill_missing_process_output_is_treated_as_already_terminated() {
    let missing_process_message = br#"ERROR: The process "1234" not found."#;
    assert!(taskkill_text_reports_missing_process(missing_process_message));
}
// Other taskkill failures (e.g. permission errors) must still surface as
// errors rather than being mistaken for an exited process.
#[test]
fn taskkill_other_failure_output_is_not_treated_as_missing_process() {
    let access_denied_message = b"ERROR: Access is denied.";
    assert!(!taskkill_text_reports_missing_process(access_denied_message));
}
}