Compare commits

...

1 Commits

Author SHA1 Message Date
Richard Lee
c968b85fab Add exec-server SIGTERM shutdown 2026-05-05 21:59:23 -07:00
7 changed files with 222 additions and 19 deletions

View File

@@ -36,6 +36,7 @@ tokio = { workspace = true, features = [
"net",
"process",
"rt-multi-thread",
"signal",
"sync",
"time",
] }

View File

@@ -9,6 +9,7 @@ use serde_json::Value;
use sha2::Digest as _;
use tokio::time::sleep;
use tokio_tungstenite::connect_async;
use tokio_util::sync::CancellationToken;
use tracing::warn;
use uuid::Uuid;
@@ -194,10 +195,13 @@ pub async fn run_remote_executor(
Ok((websocket, _)) => {
backoff = Duration::from_secs(1);
processor
.run_connection(JsonRpcConnection::from_websocket(
websocket,
"remote exec-server websocket".to_string(),
))
.run_connection(
JsonRpcConnection::from_websocket(
websocket,
"remote exec-server websocket".to_string(),
),
CancellationToken::new(),
)
.await;
}
Err(err) => {

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use tracing::warn;
@@ -31,20 +32,30 @@ impl ConnectionProcessor {
}
}
pub(crate) async fn run_connection(&self, connection: JsonRpcConnection) {
pub(crate) async fn run_connection(
&self,
connection: JsonRpcConnection,
shutdown_token: CancellationToken,
) {
run_connection(
connection,
Arc::clone(&self.session_registry),
self.runtime_paths.clone(),
shutdown_token,
)
.await;
}
pub(crate) async fn shutdown(&self) {
self.session_registry.shutdown_all().await;
}
}
async fn run_connection(
connection: JsonRpcConnection,
session_registry: Arc<SessionRegistry>,
runtime_paths: ExecServerRuntimePaths,
shutdown_token: CancellationToken,
) {
let router = Arc::new(build_router());
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) =
@@ -74,7 +85,17 @@ async fn run_connection(
});
// Process inbound events sequentially to preserve initialize/initialized ordering.
while let Some(event) = incoming_rx.recv().await {
loop {
let event = tokio::select! {
event = incoming_rx.recv() => event,
_ = shutdown_token.cancelled() => {
debug!("exec-server connection shutting down after signal");
break;
}
};
let Some(event) = event else {
break;
};
if !handler.is_session_attached() {
debug!("exec-server connection evicted after session resume");
break;
@@ -102,6 +123,10 @@ async fn run_connection(
debug!("exec-server transport disconnected while handling request");
break;
}
_ = shutdown_token.cancelled() => {
debug!("exec-server shutdown while handling request");
break;
}
};
if let Some(message) = message
&& outgoing_tx.send(message).await.is_err()
@@ -139,6 +164,10 @@ async fn run_connection(
);
break;
}
_ = shutdown_token.cancelled() => {
debug!("exec-server shutdown while handling notification");
break;
}
};
if let Err(err) = result {
warn!("closing exec-server connection after protocol error: {err}");
@@ -200,6 +229,7 @@ mod tests {
use tokio::io::duplex;
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tokio_util::sync::CancellationToken;
use super::run_connection;
use crate::ExecServerRuntimePaths;
@@ -317,7 +347,12 @@ mod tests {
let (server_writer, client_reader) = duplex(1 << 20);
let connection =
JsonRpcConnection::from_stdio(server_reader, server_writer, label.to_string());
let task = tokio::spawn(run_connection(connection, registry, test_runtime_paths()));
let task = tokio::spawn(run_connection(
connection,
registry,
test_runtime_paths(),
CancellationToken::new(),
));
(client_writer, BufReader::new(client_reader).lines(), task)
}

View File

@@ -116,6 +116,17 @@ impl SessionRegistry {
})
}
pub(crate) async fn shutdown_all(&self) {
let sessions = {
let mut sessions = self.sessions.lock().await;
sessions.drain().map(|(_, entry)| entry).collect::<Vec<_>>()
};
for session in sessions {
session.process.shutdown().await;
}
}
async fn expire_if_detached(&self, session_id: String, connection_id: ConnectionId) {
tokio::time::sleep(DETACHED_SESSION_TTL).await;

View File

@@ -1,3 +1,4 @@
use std::io::Result as IoResult;
use std::io::Write as _;
use std::net::SocketAddr;
use tokio::io;
@@ -5,6 +6,9 @@ use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::net::TcpListener;
use tokio_tungstenite::accept_async;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tracing::info;
use tracing::warn;
use crate::ExecServerRuntimePaths;
@@ -92,13 +96,27 @@ where
{
let processor = ConnectionProcessor::new(runtime_paths);
tracing::info!("codex-exec-server listening on stdio");
let shutdown_token = CancellationToken::new();
let signal_shutdown_token = shutdown_token.clone();
let signal_task = tokio::spawn(async move {
match shutdown_signal().await {
Ok(()) => {
info!("received SIGTERM; shutting down codex-exec-server");
signal_shutdown_token.cancel();
}
Err(err) => {
warn!("failed to listen for exec-server shutdown signal: {err}");
}
}
});
processor
.run_connection(JsonRpcConnection::from_stdio(
reader,
writer,
"exec-server stdio".to_string(),
))
.run_connection(
JsonRpcConnection::from_stdio(reader, writer, "exec-server stdio".to_string()),
shutdown_token,
)
.await;
signal_task.abort();
processor.shutdown().await;
Ok(())
}
@@ -113,17 +131,41 @@ async fn run_websocket_listener(
println!("ws://{local_addr}");
std::io::stdout().flush()?;
let shutdown_token = CancellationToken::new();
let connection_tasks = TaskTracker::new();
let shutdown_signal = shutdown_signal();
tokio::pin!(shutdown_signal);
loop {
let (stream, peer_addr) = listener.accept().await?;
let accepted = tokio::select! {
accepted = listener.accept() => accepted?,
shutdown_result = &mut shutdown_signal => {
if let Err(err) = shutdown_result {
warn!("failed to listen for exec-server shutdown signal: {err}");
}
info!("received SIGTERM; shutting down codex-exec-server");
break;
}
};
let (stream, peer_addr) = accepted;
let processor = processor.clone();
tokio::spawn(async move {
match accept_async(stream).await {
let connection_shutdown_token = shutdown_token.clone();
connection_tasks.spawn(async move {
let websocket = tokio::select! {
websocket = accept_async(stream) => websocket,
_ = connection_shutdown_token.cancelled() => {
return;
}
};
match websocket {
Ok(websocket) => {
processor
.run_connection(JsonRpcConnection::from_websocket(
websocket,
format!("exec-server websocket {peer_addr}"),
))
.run_connection(
JsonRpcConnection::from_websocket(
websocket,
format!("exec-server websocket {peer_addr}"),
),
connection_shutdown_token,
)
.await;
}
Err(err) => {
@@ -134,6 +176,31 @@ async fn run_websocket_listener(
}
});
}
shutdown_token.cancel();
connection_tasks.close();
connection_tasks.wait().await;
processor.shutdown().await;
info!("codex-exec-server shutdown complete");
Ok(())
}
async fn shutdown_signal() -> IoResult<()> {
#[cfg(unix)]
{
use tokio::signal::unix::SignalKind;
use tokio::signal::unix::signal;
let mut term = signal(SignalKind::terminate())?;
let _ = term.recv().await;
Ok(())
}
#[cfg(not(unix))]
{
std::future::pending::<()>().await;
Ok(())
}
}
#[cfg(test)]

View File

@@ -1,6 +1,8 @@
#![allow(dead_code)]
use std::path::PathBuf;
#[cfg(unix)]
use std::process::Command as StdCommand;
use std::process::Stdio;
use std::time::Duration;
@@ -177,6 +179,27 @@ impl ExecServerHarness {
Ok(())
}
#[cfg(unix)]
pub(crate) async fn send_sigterm_and_wait(
&mut self,
) -> anyhow::Result<std::process::ExitStatus> {
let pid = self
.child
.id()
.ok_or_else(|| anyhow!("exec-server process has no pid"))?;
let status = StdCommand::new("kill")
.arg("-TERM")
.arg(pid.to_string())
.status()?;
if !status.success() {
return Err(anyhow!("kill -TERM exited with {status}"));
}
timeout(CONNECT_TIMEOUT, self.child.wait())
.await
.map_err(|_| anyhow!("timed out waiting for exec-server SIGTERM shutdown"))?
.map_err(Into::into)
}
async fn send_message(&mut self, message: JSONRPCMessage) -> anyhow::Result<()> {
let encoded = serde_json::to_string(&message)?;
self.websocket.send(Message::Text(encoded.into())).await?;

View File

@@ -78,6 +78,68 @@ async fn exec_server_starts_process_over_websocket() -> anyhow::Result<()> {
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_server_sigterm_shuts_down_gracefully() -> anyhow::Result<()> {
let mut server = exec_server().await?;
let initialize_id = server
.send_request(
"initialize",
serde_json::to_value(InitializeParams {
client_name: "exec-server-test".to_string(),
resume_session_id: None,
})?,
)
.await?;
let _ = server
.wait_for_event(|event| {
matches!(
event,
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &initialize_id
)
})
.await?;
server
.send_notification("initialized", serde_json::json!({}))
.await?;
let process_start_id = server
.send_request(
"process/start",
serde_json::json!({
"processId": "proc-sigterm",
"argv": ["/bin/sh", "-c", "sleep 30"],
"cwd": std::env::current_dir()?,
"env": {},
"tty": false,
"pipeStdin": false,
"arg0": null
}),
)
.await?;
let _ = server
.wait_for_event(|event| {
matches!(
event,
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &process_start_id
)
})
.await?;
let status = server.send_sigterm_and_wait().await?;
assert!(
status.success(),
"expected graceful SIGTERM exit, got {status}"
);
server
.next_event()
.await
.expect_err("websocket should disconnect during SIGTERM shutdown");
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_server_defaults_omitted_pipe_stdin_to_closed_stdin() -> anyhow::Result<()> {
let mut server = exec_server().await?;