Compare commits

...

3 Commits

Author SHA1 Message Date
Cooper Gamble
2e40d19f9e [app-server] fix argument comment lint in orphan unload path [ci changed_files] 2026-03-18 21:48:06 +00:00
cooper-oai
fb4a089dbb Merge branch 'main' into cooper/mcp-thread-orphan-reap 2026-03-18 12:57:13 -07:00
Cooper Gamble
2779adeb34 [app-server] Reap orphaned idle threads on websocket disconnect [ci changed_files]
Co-authored-by: Codex <noreply@openai.com>
2026-03-18 03:00:17 +00:00
5 changed files with 375 additions and 57 deletions

View File

@@ -3,5 +3,8 @@ load("//:defs.bzl", "codex_rust_crate")
codex_rust_crate(
name = "app-server",
crate_name = "codex_app_server",
extra_binaries = [
"//codex-rs/rmcp-client:test_stdio_server",
],
test_tags = ["no-sandbox"],
)

View File

@@ -3275,9 +3275,14 @@ impl CodexMessageProcessor {
self.command_exec_manager
.connection_closed(connection_id)
.await;
self.thread_state_manager
let orphaned_thread_ids = self
.thread_state_manager
.remove_connection(connection_id)
.await;
for thread_id in orphaned_thread_ids {
self.maybe_unload_orphaned_thread_after_disconnect(thread_id)
.await;
}
}
pub(crate) fn subscribe_running_assistant_turn_count(&self) -> watch::Receiver<usize> {
@@ -4837,6 +4842,93 @@ impl CodexMessageProcessor {
}
}
/// Decides whether an orphaned thread may be unloaded after its last
/// websocket connection disconnects.
///
/// The thread is only unloadable when its rollout file still exists on disk
/// (so nothing is lost by unloading), the agent is not actively running, and
/// the loaded status is `Idle` or `SystemError`.
fn should_unload_orphaned_thread_after_disconnect(
    rollout_path: Option<&Path>,
    agent_status: AgentStatus,
    loaded_status: ThreadStatus,
) -> bool {
    let rollout_persisted = rollout_path.is_some_and(std::path::Path::exists);
    let agent_is_running = matches!(agent_status, AgentStatus::Running);
    let status_is_unloadable = matches!(
        loaded_status,
        ThreadStatus::Idle | ThreadStatus::SystemError
    );
    rollout_persisted && !agent_is_running && status_is_unloadable
}
/// Tears down `thread` after it has lost all subscribers.
///
/// Marks the thread as pending-unload, cancels outstanding
/// app-server -> client requests for it, drops its per-thread state, and
/// then waits for the thread's shutdown on a spawned background task so the
/// caller is not blocked.
async fn begin_thread_shutdown(&self, thread_id: ThreadId, thread: Arc<CodexThread>) {
    info!("thread {thread_id} has no subscribers; shutting down");
    self.pending_thread_unloads.lock().await.insert(thread_id);
    // Any pending app-server -> client requests for this thread can no longer be
    // answered; cancel their callbacks before shutdown/unload.
    self.outgoing
        .cancel_requests_for_thread(thread_id, /*error*/ None)
        .await;
    self.thread_state_manager
        .remove_thread_state(thread_id)
        .await;
    // Clone the shared handles so the spawned task owns everything it needs.
    let outgoing = self.outgoing.clone();
    let pending_thread_unloads = self.pending_thread_unloads.clone();
    let thread_manager = self.thread_manager.clone();
    let thread_watch_manager = self.thread_watch_manager.clone();
    tokio::spawn(async move {
        match Self::wait_for_thread_shutdown(&thread).await {
            ThreadShutdownResult::Complete => {
                // Another path may have removed the thread already; in that
                // case only clean up watch state and the pending-unload
                // marker, without emitting a ThreadClosed notification.
                if thread_manager.remove_thread(&thread_id).await.is_none() {
                    info!("thread {thread_id} was already removed before shutdown finalized");
                    thread_watch_manager
                        .remove_thread(&thread_id.to_string())
                        .await;
                    pending_thread_unloads.lock().await.remove(&thread_id);
                    return;
                }
                thread_watch_manager
                    .remove_thread(&thread_id.to_string())
                    .await;
                let notification = ThreadClosedNotification {
                    thread_id: thread_id.to_string(),
                };
                outgoing
                    .send_server_notification(ServerNotification::ThreadClosed(notification))
                    .await;
                // Clear the pending-unload marker only after the
                // notification is sent.
                pending_thread_unloads.lock().await.remove(&thread_id);
            }
            ThreadShutdownResult::SubmitFailed => {
                // Shutdown could not even be submitted; the thread stays
                // loaded, but the pending-unload marker must be cleared.
                pending_thread_unloads.lock().await.remove(&thread_id);
                warn!("failed to submit Shutdown to thread {thread_id}");
            }
            ThreadShutdownResult::TimedOut => {
                pending_thread_unloads.lock().await.remove(&thread_id);
                warn!("thread {thread_id} shutdown timed out; leaving thread loaded");
            }
        }
    });
}
/// After a websocket disconnect leaves `thread_id` without its connection,
/// unloads the thread when it is safe to do so: the rollout is persisted on
/// disk, the agent is not running, the loaded status is Idle/SystemError,
/// and no other connection is still subscribed.
async fn maybe_unload_orphaned_thread_after_disconnect(&self, thread_id: ThreadId) {
    // The thread may already be gone; nothing to unload in that case.
    let Ok(thread) = self.thread_manager.get_thread(thread_id).await else {
        return;
    };
    let loaded_status = self
        .thread_watch_manager
        .loaded_status_for_thread(&thread_id.to_string())
        .await;
    let agent_status = thread.agent_status().await;
    if !Self::should_unload_orphaned_thread_after_disconnect(
        thread.rollout_path().as_deref(),
        agent_status,
        loaded_status,
    ) {
        return;
    }
    // Another connection may have subscribed since the disconnect; bail out.
    // NOTE(review): a subscriber could still attach between this check and
    // begin_thread_shutdown — presumably acceptable for this path; confirm.
    if self.thread_state_manager.has_subscribers(thread_id).await {
        return;
    }
    self.begin_thread_shutdown(thread_id, thread).await;
}
async fn finalize_thread_teardown(&mut self, thread_id: ThreadId) {
self.pending_thread_unloads.lock().await.remove(&thread_id);
self.outgoing
@@ -4898,57 +4990,7 @@ impl CodexMessageProcessor {
if !self.thread_state_manager.has_subscribers(thread_id).await {
// This connection was the last subscriber. Only now do we unload the thread.
info!("thread {thread_id} has no subscribers; shutting down");
self.pending_thread_unloads.lock().await.insert(thread_id);
// Any pending app-server -> client requests for this thread can no longer be
// answered; cancel their callbacks before shutdown/unload.
self.outgoing
.cancel_requests_for_thread(thread_id, /*error*/ None)
.await;
self.thread_state_manager
.remove_thread_state(thread_id)
.await;
let outgoing = self.outgoing.clone();
let pending_thread_unloads = self.pending_thread_unloads.clone();
let thread_manager = self.thread_manager.clone();
let thread_watch_manager = self.thread_watch_manager.clone();
tokio::spawn(async move {
match Self::wait_for_thread_shutdown(&thread).await {
ThreadShutdownResult::Complete => {
if thread_manager.remove_thread(&thread_id).await.is_none() {
info!(
"thread {thread_id} was already removed before unsubscribe finalized"
);
thread_watch_manager
.remove_thread(&thread_id.to_string())
.await;
pending_thread_unloads.lock().await.remove(&thread_id);
return;
}
thread_watch_manager
.remove_thread(&thread_id.to_string())
.await;
let notification = ThreadClosedNotification {
thread_id: thread_id.to_string(),
};
outgoing
.send_server_notification(ServerNotification::ThreadClosed(
notification,
))
.await;
pending_thread_unloads.lock().await.remove(&thread_id);
}
ThreadShutdownResult::SubmitFailed => {
pending_thread_unloads.lock().await.remove(&thread_id);
warn!("failed to submit Shutdown to thread {thread_id}");
}
ThreadShutdownResult::TimedOut => {
pending_thread_unloads.lock().await.remove(&thread_id);
warn!("thread {thread_id} shutdown timed out; leaving thread loaded");
}
}
});
self.begin_thread_shutdown(thread_id, thread).await;
}
self.outgoing
@@ -8620,12 +8662,13 @@ mod tests {
state.lock().await.cancel_tx = Some(cancel_tx);
}
manager.remove_connection(connection_a).await;
let orphaned_thread_ids = manager.remove_connection(connection_a).await;
assert!(
tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx)
.await
.is_err()
);
assert!(orphaned_thread_ids.is_empty());
assert_eq!(
manager.subscribed_connection_ids(thread_id).await,
@@ -8634,6 +8677,25 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn removing_last_connection_reports_orphaned_thread() -> Result<()> {
    // A thread whose sole subscribing connection is removed should be
    // reported back as orphaned by remove_connection.
    let manager = ThreadStateManager::new();
    let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
    let connection = ConnectionId(1);
    manager.connection_initialized(connection).await;
    manager
        .try_ensure_connection_subscribed(thread_id, connection, false)
        .await
        .expect("connection should be live");
    // Removing the last connection must surface the thread as orphaned and
    // leave it with no remaining subscribers.
    let orphaned_thread_ids = manager.remove_connection(connection).await;
    assert_eq!(orphaned_thread_ids, vec![thread_id]);
    assert!(!manager.has_subscribers(thread_id).await);
    Ok(())
}
#[tokio::test]
async fn closed_connection_cannot_be_reintroduced_by_auto_subscribe() -> Result<()> {
let manager = ThreadStateManager::new();
@@ -8641,7 +8703,8 @@ mod tests {
let connection = ConnectionId(1);
manager.connection_initialized(connection).await;
manager.remove_connection(connection).await;
let orphaned_thread_ids = manager.remove_connection(connection).await;
assert!(orphaned_thread_ids.is_empty());
assert!(
manager

View File

@@ -313,8 +313,8 @@ impl ThreadStateManager {
true
}
pub(crate) async fn remove_connection(&self, connection_id: ConnectionId) {
let thread_states = {
pub(crate) async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<ThreadId> {
let orphaned_threads = {
let mut state = self.state.lock().await;
state.live_connections.remove(&connection_id);
let thread_ids = state
@@ -344,7 +344,9 @@ impl ThreadStateManager {
.collect::<Vec<_>>()
};
for (thread_id, no_subscribers, thread_state) in thread_states {
let mut orphaned_thread_ids = Vec::new();
for (thread_id, no_subscribers, thread_state) in orphaned_threads {
if !no_subscribers {
continue;
}
@@ -358,6 +360,9 @@ impl ThreadStateManager {
listener_generation,
"retaining thread listener after connection disconnect left zero subscribers"
);
orphaned_thread_ids.push(thread_id);
}
orphaned_thread_ids
}
}

View File

@@ -30,6 +30,8 @@ mod review;
mod safety_check_downgrade;
mod skills_list;
mod thread_archive;
#[cfg(unix)]
mod thread_disconnect_websocket_unix;
mod thread_fork;
mod thread_list;
mod thread_loaded_list;

View File

@@ -0,0 +1,245 @@
use super::connection_handling_websocket::DEFAULT_READ_TIMEOUT;
use super::connection_handling_websocket::WsClient;
use super::connection_handling_websocket::connect_websocket;
use super::connection_handling_websocket::create_config_toml;
use super::connection_handling_websocket::read_response_for_id;
use super::connection_handling_websocket::send_initialize_request;
use super::connection_handling_websocket::send_request;
use super::connection_handling_websocket::spawn_websocket_server;
use anyhow::Context;
use anyhow::Result;
use app_test_support::create_fake_rollout_with_text_elements;
use app_test_support::create_mock_responses_server_repeating_assistant;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::ThreadLoadedListParams;
use codex_app_server_protocol::ThreadLoadedListResponse;
use codex_app_server_protocol::ThreadResumeParams;
use codex_app_server_protocol::ThreadResumeResponse;
use core_test_support::stdio_server_bin;
use pretty_assertions::assert_eq;
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::Command;
use tempfile::TempDir;
use tokio::time::Duration;
use tokio::time::sleep;
use tokio::time::timeout;
/// Reports whether `pid` refers to a live (non-zombie) process, by asking
/// `ps` for the process's state column. Returns false when `ps` cannot be
/// run, exits unsuccessfully, or prints no state (pid unknown).
fn process_is_running(pid: u32) -> bool {
    let Ok(output) = Command::new("ps")
        .args(["-o", "stat=", "-p", &pid.to_string()])
        .stderr(std::process::Stdio::null())
        .output()
    else {
        return false;
    };
    if !output.status.success() {
        return false;
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    for line in stdout.lines() {
        let stat = line.trim();
        if !stat.is_empty() {
            // A state starting with 'Z' means zombie: the pid exists but the
            // process is dead and merely awaiting reaping.
            return !stat.starts_with('Z');
        }
    }
    false
}
/// Polls `path` (100ms interval, up to 50 attempts) until it contains a
/// non-empty line, then parses that line as a pid.
///
/// A missing file or an existing-but-empty file triggers another attempt;
/// any other read error, or an unparsable pid line, fails immediately.
async fn wait_for_pid_entry(path: &Path) -> Result<u32> {
    for _ in 0..50 {
        let content = match std::fs::read_to_string(path) {
            Ok(content) => content,
            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {
                // Wrapper script has not written the file yet; retry.
                sleep(Duration::from_millis(100)).await;
                continue;
            }
            Err(error) => {
                return Err(error).with_context(|| format!("failed to read {}", path.display()));
            }
        };
        let first_entry = content
            .lines()
            .map(str::trim)
            .find(|line| !line.is_empty());
        match first_entry {
            Some(entry) => {
                return entry
                    .parse::<u32>()
                    .with_context(|| format!("failed to parse pid from {}", path.display()));
            }
            // File exists but holds no pid yet; retry.
            None => sleep(Duration::from_millis(100)).await,
        }
    }
    anyhow::bail!("timed out waiting for pid file at {}", path.display());
}
/// Polls (100ms interval, up to 50 attempts) until the process identified by
/// `pid` is no longer running, failing if it is still alive after ~5s.
async fn wait_for_process_exit(pid: u32) -> Result<()> {
    let mut attempts_remaining = 50;
    while attempts_remaining > 0 {
        if !process_is_running(pid) {
            return Ok(());
        }
        sleep(Duration::from_millis(100)).await;
        attempts_remaining -= 1;
    }
    anyhow::bail!("process {pid} still running after timeout");
}
/// Locates the `test_stdio_server` binary, building it with cargo on demand
/// when the initial lookup fails.
fn ensure_stdio_server_bin() -> Result<String> {
    if let Ok(bin) = stdio_server_bin() {
        return Ok(bin);
    }
    // Binary not found: build it once via cargo, then retry the lookup.
    let status = Command::new("cargo")
        .args([
            "build",
            "-p",
            "codex-rmcp-client",
            "--bin",
            "test_stdio_server",
        ])
        .status()
        .context("failed to invoke cargo to build test_stdio_server")?;
    anyhow::ensure!(
        status.success(),
        "cargo build -p codex-rmcp-client --bin test_stdio_server failed with status {status}"
    );
    stdio_server_bin().context("test_stdio_server binary still unavailable after cargo build")
}
/// Writes the standard websocket-test config into `codex_home`, then appends
/// a local stdio MCP server entry whose command is `wrapper_path`.
fn create_config_toml_with_local_mcp(
    codex_home: &Path,
    server_uri: &str,
    wrapper_path: &Path,
) -> std::io::Result<()> {
    create_config_toml(codex_home, server_uri, "never")?;
    let config_toml = codex_home.join("config.toml");
    // Read back what create_config_toml produced so the MCP section is
    // appended rather than overwriting it.
    let existing = std::fs::read_to_string(&config_toml)?;
    let updated = format!(
        r#"{existing}
[mcp_servers.test_stdio]
command = "{}"
"#,
        wrapper_path.display(),
    );
    std::fs::write(&config_toml, updated)
}
/// Writes an executable `/bin/sh` wrapper script into `dir` that appends its
/// own pid to `pid_file` and then execs `server_bin` with the original
/// arguments. Returns the wrapper's path.
fn create_pid_logging_stdio_wrapper(
    dir: &Path,
    pid_file: &Path,
    server_bin: &str,
) -> std::io::Result<std::path::PathBuf> {
    let wrapper_path = dir.join("mcp-wrapper.sh");
    let script = format!(
        "#!/bin/sh\nprintf '%s\\n' \"$$\" >> '{}'\nexec '{}' \"$@\"\n",
        pid_file.display(),
        server_bin,
    );
    std::fs::write(&wrapper_path, script)?;
    // Mark the script executable (rwxr-xr-x) so the app server can spawn it.
    let metadata = std::fs::metadata(&wrapper_path)?;
    let mut permissions = metadata.permissions();
    permissions.set_mode(0o755);
    std::fs::set_permissions(&wrapper_path, permissions)?;
    Ok(wrapper_path)
}
/// Seeds `codex_home` with a fake rollout file (timestamped `filename_ts`)
/// containing a single saved user message, returning the id string that the
/// test later uses as the thread id to resume.
fn create_rollout(codex_home: &Path, filename_ts: &str) -> Result<String> {
    create_fake_rollout_with_text_elements(
        codex_home,
        filename_ts,
        "2025-01-05T12:00:00Z",
        "Saved user message",
        Vec::new(),
        Some("mock_provider"),
        None,
    )
}
/// Connects a websocket client to `bind_addr` and completes the initialize
/// handshake (request id 1) before returning the ready client.
async fn initialize_ws_client(bind_addr: std::net::SocketAddr) -> Result<WsClient> {
    let mut ws = connect_websocket(bind_addr).await?;
    send_initialize_request(&mut ws, 1, "ws_disconnect_client").await?;
    // Wait (bounded by DEFAULT_READ_TIMEOUT) for the initialize response.
    timeout(DEFAULT_READ_TIMEOUT, read_response_for_id(&mut ws, 1)).await??;
    Ok(ws)
}
/// End-to-end check that closing the websocket connection unloads a resumed
/// thread and reaps the stdio MCP child process that was spawned for it.
///
/// Flow: resume a seeded rollout over websocket, capture the MCP child's pid
/// via a wrapper script, close the socket, wait for the child to exit, then
/// reconnect and verify `thread/loaded/list` is empty.
#[tokio::test]
async fn websocket_disconnect_unloads_resumed_thread_and_reaps_stdio_mcp() -> Result<()> {
    let server = create_mock_responses_server_repeating_assistant("Done").await;
    let codex_home = TempDir::new()?;
    // The wrapper script appends each spawned MCP server pid to this file.
    let pid_file = codex_home.path().join("mcp-pids.log");
    let stdio_server = ensure_stdio_server_bin()?;
    let wrapper_path =
        create_pid_logging_stdio_wrapper(codex_home.path(), &pid_file, &stdio_server)?;
    create_config_toml_with_local_mcp(codex_home.path(), &server.uri(), &wrapper_path)?;
    let conversation_id = create_rollout(codex_home.path(), "2025-01-05T12-00-00")?;
    let (mut process, bind_addr) = spawn_websocket_server(codex_home.path()).await?;
    // Run the body in an async block so the app-server process is killed
    // below regardless of whether any step fails.
    let result = async {
        let mut ws = initialize_ws_client(bind_addr).await?;
        send_request(
            &mut ws,
            "thread/resume",
            10,
            Some(serde_json::to_value(ThreadResumeParams {
                thread_id: conversation_id.clone(),
                ..Default::default()
            })?),
        )
        .await?;
        let resume_resp: JSONRPCResponse =
            timeout(DEFAULT_READ_TIMEOUT, read_response_for_id(&mut ws, 10)).await??;
        let resume: ThreadResumeResponse = to_response::<ThreadResumeResponse>(resume_resp)?;
        assert_eq!(resume.thread.id, conversation_id);
        // Resuming the thread should have started the stdio MCP server.
        let pid = wait_for_pid_entry(&pid_file).await?;
        assert!(
            process_is_running(pid),
            "expected stdio MCP process {pid} to be running before disconnect"
        );
        // Disconnect: this should orphan the thread and reap the MCP child.
        ws.close(None)
            .await
            .context("failed to close websocket connection")?;
        drop(ws);
        wait_for_process_exit(pid).await?;
        // Reconnect with a fresh client and confirm no threads remain loaded.
        let mut ws = connect_websocket(bind_addr).await?;
        send_initialize_request(&mut ws, 2, "ws_disconnect_client_reconnect").await?;
        timeout(DEFAULT_READ_TIMEOUT, read_response_for_id(&mut ws, 2)).await??;
        send_request(
            &mut ws,
            "thread/loaded/list",
            11,
            Some(serde_json::to_value(ThreadLoadedListParams::default())?),
        )
        .await?;
        let list_resp: JSONRPCResponse =
            timeout(DEFAULT_READ_TIMEOUT, read_response_for_id(&mut ws, 11)).await??;
        let ThreadLoadedListResponse { data, next_cursor } =
            to_response::<ThreadLoadedListResponse>(list_resp)?;
        assert_eq!(data, Vec::<String>::new());
        assert_eq!(next_cursor, None);
        Ok(())
    }
    .await;
    process
        .kill()
        .await
        .context("failed to stop websocket app-server process")?;
    result
}