codex-rs/app-server: graceful websocket restart on Ctrl-C (#12517)

## Summary
- add graceful websocket app-server restart on Ctrl-C by draining until
no assistant turns are running
- stop the websocket acceptor and disconnect existing connections once
the drain condition is met
- add a websocket integration test that verifies Ctrl-C waits for an
in-flight turn before exit

## Verification
- `cargo check -p codex-app-server --quiet`
- `cargo test -p codex-app-server --test all
suite::v2::connection_handling_websocket`
- I (maxj) tested remote and local Codex.app

---------

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Max Johnson
2026-02-24 16:27:59 -08:00
committed by GitHub
parent 3d356723c4
commit 5163850025
8 changed files with 493 additions and 42 deletions

View File

@@ -28,9 +28,9 @@ use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5);
pub(super) const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5);
type WsClient = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
pub(super) type WsClient = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
#[tokio::test]
async fn websocket_transport_routes_per_connection_handshake_and_responses() -> Result<()> {
@@ -78,7 +78,10 @@ async fn websocket_transport_routes_per_connection_handshake_and_responses() ->
Ok(())
}
async fn spawn_websocket_server(codex_home: &Path, bind_addr: SocketAddr) -> Result<Child> {
pub(super) async fn spawn_websocket_server(
codex_home: &Path,
bind_addr: SocketAddr,
) -> Result<Child> {
let program = codex_utils_cargo_bin::cargo_bin("codex-app-server")
.context("should find app-server binary")?;
let mut cmd = Command::new(program);
@@ -106,14 +109,14 @@ async fn spawn_websocket_server(codex_home: &Path, bind_addr: SocketAddr) -> Res
Ok(process)
}
fn reserve_local_addr() -> Result<SocketAddr> {
pub(super) fn reserve_local_addr() -> Result<SocketAddr> {
let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
let addr = listener.local_addr()?;
drop(listener);
Ok(addr)
}
async fn connect_websocket(bind_addr: SocketAddr) -> Result<WsClient> {
pub(super) async fn connect_websocket(bind_addr: SocketAddr) -> Result<WsClient> {
let url = format!("ws://{bind_addr}");
let deadline = Instant::now() + Duration::from_secs(10);
loop {
@@ -129,7 +132,11 @@ async fn connect_websocket(bind_addr: SocketAddr) -> Result<WsClient> {
}
}
async fn send_initialize_request(stream: &mut WsClient, id: i64, client_name: &str) -> Result<()> {
pub(super) async fn send_initialize_request(
stream: &mut WsClient,
id: i64,
client_name: &str,
) -> Result<()> {
let params = InitializeParams {
client_info: ClientInfo {
name: client_name.to_string(),
@@ -157,7 +164,7 @@ async fn send_config_read_request(stream: &mut WsClient, id: i64) -> Result<()>
.await
}
async fn send_request(
pub(super) async fn send_request(
stream: &mut WsClient,
method: &str,
id: i64,
@@ -179,7 +186,10 @@ async fn send_jsonrpc(stream: &mut WsClient, message: JSONRPCMessage) -> Result<
.context("failed to send websocket frame")
}
async fn read_response_for_id(stream: &mut WsClient, id: i64) -> Result<JSONRPCResponse> {
pub(super) async fn read_response_for_id(
stream: &mut WsClient,
id: i64,
) -> Result<JSONRPCResponse> {
let target_id = RequestId::Integer(id);
loop {
let message = read_jsonrpc_message(stream).await?;
@@ -235,7 +245,7 @@ async fn assert_no_message(stream: &mut WsClient, wait_for: Duration) -> Result<
}
}
fn create_config_toml(
pub(super) fn create_config_toml(
codex_home: &Path,
server_uri: &str,
approval_policy: &str,

View File

@@ -0,0 +1,237 @@
use super::connection_handling_websocket::DEFAULT_READ_TIMEOUT;
use super::connection_handling_websocket::WsClient;
use super::connection_handling_websocket::connect_websocket;
use super::connection_handling_websocket::create_config_toml;
use super::connection_handling_websocket::read_response_for_id;
use super::connection_handling_websocket::reserve_local_addr;
use super::connection_handling_websocket::send_initialize_request;
use super::connection_handling_websocket::send_request;
use super::connection_handling_websocket::spawn_websocket_server;
use anyhow::Context;
use anyhow::Result;
use anyhow::bail;
use app_test_support::create_final_assistant_message_sse_response;
use app_test_support::to_response;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::UserInput as V2UserInput;
use core_test_support::responses;
use futures::SinkExt;
use futures::StreamExt;
use std::process::Command as StdCommand;
use tempfile::TempDir;
use tokio::process::Child;
use tokio::time::Duration;
use tokio::time::Instant;
use tokio::time::sleep;
use tokio::time::timeout;
use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
use wiremock::Mock;
use wiremock::matchers::method;
use wiremock::matchers::path_regex;
#[tokio::test]
async fn websocket_transport_ctrl_c_waits_for_running_turn_before_exit() -> Result<()> {
let GracefulCtrlCFixture {
_codex_home,
_server,
mut process,
mut ws,
} = start_ctrl_c_restart_fixture(Duration::from_secs(3)).await?;
send_sigint(&process)?;
assert_process_does_not_exit_within(&mut process, Duration::from_millis(300)).await?;
let status = wait_for_process_exit_within(
&mut process,
Duration::from_secs(10),
"timed out waiting for graceful Ctrl-C restart shutdown",
)
.await?;
assert!(status.success(), "expected graceful exit, got {status}");
expect_websocket_disconnect(&mut ws).await?;
Ok(())
}
#[tokio::test]
async fn websocket_transport_second_ctrl_c_forces_exit_while_turn_running() -> Result<()> {
let GracefulCtrlCFixture {
_codex_home,
_server,
mut process,
mut ws,
} = start_ctrl_c_restart_fixture(Duration::from_secs(3)).await?;
send_sigint(&process)?;
assert_process_does_not_exit_within(&mut process, Duration::from_millis(300)).await?;
send_sigint(&process)?;
let status = wait_for_process_exit_within(
&mut process,
Duration::from_secs(2),
"timed out waiting for forced Ctrl-C restart shutdown",
)
.await?;
assert!(status.success(), "expected graceful exit, got {status}");
expect_websocket_disconnect(&mut ws).await?;
Ok(())
}
struct GracefulCtrlCFixture {
_codex_home: TempDir,
_server: wiremock::MockServer,
process: Child,
ws: WsClient,
}
async fn start_ctrl_c_restart_fixture(turn_delay: Duration) -> Result<GracefulCtrlCFixture> {
let server = responses::start_mock_server().await;
let delayed_turn_response = create_final_assistant_message_sse_response("Done")?;
Mock::given(method("POST"))
.and(path_regex(".*/responses$"))
.respond_with(responses::sse_response(delayed_turn_response).set_delay(turn_delay))
.up_to_n_times(1)
.mount(&server)
.await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri(), "never")?;
let bind_addr = reserve_local_addr()?;
let process = spawn_websocket_server(codex_home.path(), bind_addr).await?;
let mut ws = connect_websocket(bind_addr).await?;
send_initialize_request(&mut ws, 1, "ws_graceful_shutdown").await?;
let init_response = read_response_for_id(&mut ws, 1).await?;
assert_eq!(init_response.id, RequestId::Integer(1));
send_thread_start_request(&mut ws, 2).await?;
let thread_start_response = read_response_for_id(&mut ws, 2).await?;
let ThreadStartResponse { thread, .. } = to_response(thread_start_response)?;
send_turn_start_request(&mut ws, 3, &thread.id).await?;
let turn_start_response = read_response_for_id(&mut ws, 3).await?;
assert_eq!(turn_start_response.id, RequestId::Integer(3));
wait_for_responses_post(&server, Duration::from_secs(5)).await?;
Ok(GracefulCtrlCFixture {
_codex_home: codex_home,
_server: server,
process,
ws,
})
}
async fn send_thread_start_request(stream: &mut WsClient, id: i64) -> Result<()> {
send_request(
stream,
"thread/start",
id,
Some(serde_json::to_value(ThreadStartParams {
model: Some("mock-model".to_string()),
..Default::default()
})?),
)
.await
}
async fn send_turn_start_request(stream: &mut WsClient, id: i64, thread_id: &str) -> Result<()> {
send_request(
stream,
"turn/start",
id,
Some(serde_json::to_value(TurnStartParams {
thread_id: thread_id.to_string(),
input: vec![V2UserInput::Text {
text: "Hello".to_string(),
text_elements: Vec::new(),
}],
..Default::default()
})?),
)
.await
}
async fn wait_for_responses_post(server: &wiremock::MockServer, wait_for: Duration) -> Result<()> {
let deadline = Instant::now() + wait_for;
loop {
let requests = server
.received_requests()
.await
.context("failed to read mock server requests")?;
if requests
.iter()
.any(|request| request.method == "POST" && request.url.path().ends_with("/responses"))
{
return Ok(());
}
if Instant::now() >= deadline {
bail!("timed out waiting for /responses request");
}
sleep(Duration::from_millis(10)).await;
}
}
fn send_sigint(process: &Child) -> Result<()> {
let pid = process
.id()
.context("websocket app-server process has no pid")?;
let status = StdCommand::new("kill")
.arg("-INT")
.arg(pid.to_string())
.status()
.context("failed to invoke kill -INT")?;
if !status.success() {
bail!("kill -INT exited with {status}");
}
Ok(())
}
async fn assert_process_does_not_exit_within(process: &mut Child, window: Duration) -> Result<()> {
match timeout(window, process.wait()).await {
Err(_) => Ok(()),
Ok(Ok(status)) => bail!("process exited too early during graceful drain: {status}"),
Ok(Err(err)) => Err(err).context("failed waiting for process"),
}
}
async fn wait_for_process_exit_within(
process: &mut Child,
window: Duration,
timeout_context: &'static str,
) -> Result<std::process::ExitStatus> {
timeout(window, process.wait())
.await
.context(timeout_context)?
.context("failed waiting for websocket app-server process exit")
}
async fn expect_websocket_disconnect(stream: &mut WsClient) -> Result<()> {
loop {
let frame = timeout(DEFAULT_READ_TIMEOUT, stream.next())
.await
.context("timed out waiting for websocket disconnect")?;
match frame {
None => return Ok(()),
Some(Ok(WebSocketMessage::Close(_))) => return Ok(()),
Some(Ok(WebSocketMessage::Ping(payload))) => {
stream
.send(WebSocketMessage::Pong(payload))
.await
.context("failed to reply to ping while waiting for disconnect")?;
}
Some(Ok(WebSocketMessage::Pong(_))) => {}
Some(Ok(WebSocketMessage::Frame(_))) => {}
Some(Ok(WebSocketMessage::Text(_))) => {}
Some(Ok(WebSocketMessage::Binary(_))) => {}
Some(Err(_)) => return Ok(()),
}
}
}

View File

@@ -5,6 +5,8 @@ mod collaboration_mode_list;
mod compaction;
mod config_rpc;
mod connection_handling_websocket;
#[cfg(unix)]
mod connection_handling_websocket_unix;
mod dynamic_tools;
mod experimental_api;
mod experimental_feature_list;