Compare commits

...

4 Commits

Author SHA1 Message Date
Max Johnson
d5e1c66816 app-server: force restart on second ctrl-c
Co-authored-by: Codex <noreply@openai.com>
2026-02-22 15:30:14 -08:00
Max Johnson
ac9f520b57 app-server: gate ctrl-c websocket test on unix
Co-authored-by: Codex <noreply@openai.com>
2026-02-22 15:19:12 -08:00
Max Johnson
11886fcd12 app-server: add ctrl-c restart websocket test
Co-authored-by: Codex <noreply@openai.com>
2026-02-22 14:56:18 -08:00
Max Johnson
0a0a1c5efb impl 2026-02-22 14:56:18 -08:00
6 changed files with 416 additions and 28 deletions

View File

@@ -269,6 +269,7 @@ use std::time::SystemTime;
use tokio::sync::Mutex;
use tokio::sync::broadcast;
use tokio::sync::oneshot;
use tokio::sync::watch;
use toml::Value as TomlValue;
use tracing::error;
use tracing::info;
@@ -2794,6 +2795,10 @@ impl CodexMessageProcessor {
.await;
}
/// Returns a watch receiver for the running assistant turn count.
///
/// This simply forwards to the thread watch manager's
/// `subscribe_running_turn_count`.
pub(crate) fn subscribe_running_assistant_turn_count(&self) -> watch::Receiver<usize> {
self.thread_watch_manager.subscribe_running_turn_count()
}
/// Best-effort: ensure initialized connections are subscribed to this thread.
pub(crate) async fn try_attach_thread_listener(
&mut self,

View File

@@ -98,6 +98,8 @@ enum OutboundControlEvent {
},
/// Remove state for a closed/disconnected connection.
Closed { connection_id: ConnectionId },
/// Disconnect all connection-oriented clients during graceful restart.
DisconnectAll,
}
fn config_warning_from_error(
@@ -255,17 +257,27 @@ pub async fn run_main_with_transport(
let mut stdio_handles = Vec::<JoinHandle<()>>::new();
let mut websocket_accept_handle = None;
let mut websocket_accept_shutdown = None;
match transport {
AppServerTransport::Stdio => {
start_stdio_connection(transport_event_tx.clone(), &mut stdio_handles).await?;
}
AppServerTransport::WebSocket { bind_address } => {
websocket_accept_handle =
Some(start_websocket_acceptor(bind_address, transport_event_tx.clone()).await?);
let shutdown_token = CancellationToken::new();
websocket_accept_handle = Some(
start_websocket_acceptor(
bind_address,
transport_event_tx.clone(),
shutdown_token.clone(),
)
.await?,
);
websocket_accept_shutdown = Some(shutdown_token);
}
}
let single_client_mode = matches!(transport, AppServerTransport::Stdio);
let shutdown_when_no_connections = single_client_mode;
let graceful_ctrl_c_restart_enabled = !single_client_mode;
// Parse CLI overrides once and derive the base Config eagerly so later
// components do not need to work with raw TOML values.
@@ -434,6 +446,16 @@ pub async fn run_main_with_transport(
OutboundControlEvent::Closed { connection_id } => {
outbound_connections.remove(&connection_id);
}
OutboundControlEvent::DisconnectAll => {
info!(
"disconnecting {} outbound websocket connection(s) for graceful restart",
outbound_connections.len()
);
for connection_state in outbound_connections.values() {
connection_state.request_disconnect();
}
outbound_connections.clear();
}
}
}
envelope = outgoing_rx.recv() => {
@@ -464,11 +486,69 @@ pub async fn run_main_with_transport(
config_warnings,
});
let mut thread_created_rx = processor.thread_created_receiver();
let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count();
let mut connections = HashMap::<ConnectionId, ConnectionState>::new();
let websocket_accept_shutdown = websocket_accept_shutdown.clone();
async move {
let mut listen_for_threads = true;
let mut restart_requested = false;
let mut restart_forced = false;
let mut last_logged_running_turn_count = None;
loop {
if restart_requested {
let running_turn_count = *running_turn_count_rx.borrow();
if restart_forced || running_turn_count == 0 {
if restart_forced {
info!(
"received second Ctrl-C; forcing restart with {} running assistant turn(s) and {} connection(s)",
running_turn_count,
connections.len()
);
} else {
info!(
"Ctrl-C restart: no assistant turns running; stopping acceptor and disconnecting {} connection(s)",
connections.len()
);
}
if let Some(shutdown_token) = &websocket_accept_shutdown {
shutdown_token.cancel();
}
let _ = outbound_control_tx
.send(OutboundControlEvent::DisconnectAll)
.await;
break;
}
if last_logged_running_turn_count != Some(running_turn_count) {
info!(
"Ctrl-C restart: waiting for {running_turn_count} running assistant turn(s) to finish"
);
last_logged_running_turn_count = Some(running_turn_count);
}
}
tokio::select! {
ctrl_c_result = tokio::signal::ctrl_c(), if graceful_ctrl_c_restart_enabled && !restart_forced => {
if let Err(err) = ctrl_c_result {
warn!("failed to listen for Ctrl-C during graceful restart drain: {err}");
}
if restart_requested {
restart_forced = true;
} else {
restart_requested = true;
last_logged_running_turn_count = None;
let running_turn_count = *running_turn_count_rx.borrow();
info!(
"received Ctrl-C; entering graceful restart drain (connections={}, runningAssistantTurns={}, requests still accepted until no assistant turns are running)",
connections.len(),
running_turn_count,
);
}
}
changed = running_turn_count_rx.changed(), if graceful_ctrl_c_restart_enabled && restart_requested => {
if changed.is_err() {
warn!("running-turn watcher closed during graceful restart drain");
}
}
event = transport_event_rx.recv() => {
let Some(event) = event else {
break;
@@ -619,8 +699,11 @@ pub async fn run_main_with_transport(
let _ = processor_handle.await;
let _ = outbound_handle.await;
if let Some(shutdown_token) = websocket_accept_shutdown {
shutdown_token.cancel();
}
if let Some(handle) = websocket_accept_handle {
handle.abort();
let _ = handle.await;
}
for handle in stdio_handles {

View File

@@ -50,6 +50,7 @@ use codex_feedback::CodexFeedback;
use codex_protocol::ThreadId;
use codex_protocol::protocol::SessionSource;
use tokio::sync::broadcast;
use tokio::sync::watch;
use tokio::time::Duration;
use tokio::time::timeout;
use toml::Value as TomlValue;
@@ -427,6 +428,11 @@ impl MessageProcessor {
.await;
}
/// Returns a watch receiver for the running assistant turn count.
///
/// This simply forwards to the wrapped `codex_message_processor`.
pub(crate) fn subscribe_running_assistant_turn_count(&self) -> watch::Receiver<usize> {
self.codex_message_processor
.subscribe_running_assistant_turn_count()
}
/// Handle a standalone JSON-RPC response originating from the peer.
pub(crate) async fn process_response(&mut self, response: JSONRPCResponse) {
tracing::info!("<- response: {:?}", response);

View File

@@ -15,11 +15,13 @@ use std::sync::Arc;
use tokio::sync::Mutex;
#[cfg(test)]
use tokio::sync::mpsc;
use tokio::sync::watch;
/// Tracks per-thread runtime state and publishes the number of running turns.
///
/// Cloning is cheap with respect to `state` and `outgoing`: both are held
/// behind `Arc`, so clones share them.
#[derive(Clone)]
pub(crate) struct ThreadWatchManager {
// Shared, async-locked watch state.
state: Arc<Mutex<ThreadWatchState>>,
// Optional sender for status-change notifications; `None` when the manager
// is built with `new()`.
outgoing: Option<Arc<OutgoingMessageSender>>,
// Sending half of the watch channel that carries the count of runtimes
// whose `running` flag is set.
running_turn_count_tx: watch::Sender<usize>,
}
pub(crate) struct ThreadWatchActiveGuard {
@@ -71,16 +73,20 @@ impl Default for ThreadWatchManager {
impl ThreadWatchManager {
/// Creates a manager with default (empty) watch state and no outgoing sender.
///
/// The running-turn watch channel starts at 0. The receiver created alongside
/// the sender is discarded; observers obtain their own through
/// `subscribe_running_turn_count`.
///
/// NOTE(review): with no live receiver, tokio's `watch::Sender::send` returns
/// an error and leaves the stored value untouched. Confirm that count updates
/// published before the first subscription cannot be lost.
pub(crate) fn new() -> Self {
    let state = Arc::new(Mutex::new(ThreadWatchState::default()));
    let (running_turn_count_tx, _) = watch::channel(0);
    Self {
        state,
        outgoing: None,
        running_turn_count_tx,
    }
}
/// Creates a manager with default (empty) watch state that holds `outgoing`
/// as its outgoing message sender.
///
/// The running-turn watch channel starts at 0, and the receiver created
/// alongside the sender is discarded.
pub(crate) fn new_with_outgoing(outgoing: Arc<OutgoingMessageSender>) -> Self {
    let state = Arc::new(Mutex::new(ThreadWatchState::default()));
    let (running_turn_count_tx, _) = watch::channel(0);
    Self {
        state,
        outgoing: Some(outgoing),
        running_turn_count_tx,
    }
}
@@ -113,6 +119,21 @@ impl ThreadWatchManager {
.collect()
}
/// Test-only helper: counts the tracked runtimes whose `running` flag is set.
///
/// Reads the state under the async lock rather than going through the watch
/// channel.
#[cfg(test)]
pub(crate) async fn running_turn_count(&self) -> usize {
    let state = self.state.lock().await;
    state
        .runtime_by_thread_id
        .values()
        .filter(|runtime| runtime.running)
        .count()
}
/// Returns a new receiver subscribed to the running-turn count watch channel.
pub(crate) fn subscribe_running_turn_count(&self) -> watch::Receiver<usize> {
self.running_turn_count_tx.subscribe()
}
pub(crate) async fn note_turn_started(&self, thread_id: &str) {
self.update_runtime_for_thread(thread_id, |runtime| {
runtime.is_loaded = true;
@@ -193,10 +214,17 @@ impl ThreadWatchManager {
where
F: FnOnce(&mut ThreadWatchState) -> Option<ThreadStatusChangedNotification>,
{
let notification = {
let (notification, running_turn_count) = {
let mut state = self.state.lock().await;
mutate(&mut state)
let notification = mutate(&mut state);
let running_turn_count = state
.runtime_by_thread_id
.values()
.filter(|runtime| runtime.running)
.count();
(notification, running_turn_count)
};
let _ = self.running_turn_count_tx.send(running_turn_count);
if let Some(notification) = notification
&& let Some(outgoing) = &self.outgoing
@@ -588,6 +616,32 @@ mod tests {
);
}
/// The running-turn count must follow turn start/completion only; a pending
/// permission request on its own must not count as a running turn.
#[tokio::test]
async fn has_running_turns_tracks_runtime_running_flag_only() {
    let watch_manager = ThreadWatchManager::new();
    let interactive_thread = test_thread(
        INTERACTIVE_THREAD_ID,
        codex_app_server_protocol::SessionSource::Cli,
    );
    watch_manager.upsert_thread(interactive_thread).await;
    assert_eq!(watch_manager.running_turn_count().await, 0);

    // Keep the guard bound to a name so it stays alive until the end of the test.
    let _pending_permission = watch_manager
        .note_permission_requested(INTERACTIVE_THREAD_ID)
        .await;
    assert_eq!(watch_manager.running_turn_count().await, 0);

    watch_manager.note_turn_started(INTERACTIVE_THREAD_ID).await;
    assert_eq!(watch_manager.running_turn_count().await, 1);

    watch_manager
        .note_turn_completed(INTERACTIVE_THREAD_ID, false)
        .await;
    assert_eq!(watch_manager.running_turn_count().await, 0);
}
#[tokio::test]
async fn status_change_emits_notification() {
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(8);

View File

@@ -67,12 +67,6 @@ fn print_websocket_startup_banner(addr: SocketAddr) {
}
}
/// Prints a dimmed "websocket client connected from" label followed by the
/// peer address to stderr.
// The clippy allow is intentional: this helper writes straight to stderr
// with `eprintln!`.
#[allow(clippy::print_stderr)]
fn print_websocket_connection(peer_addr: SocketAddr) {
let connected_label = colorize("websocket client connected from", Style::new().dimmed());
eprintln!("{connected_label} {peer_addr}");
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AppServerTransport {
Stdio,
@@ -192,7 +186,7 @@ impl OutboundConnectionState {
self.disconnect_sender.is_some()
}
fn request_disconnect(&self) {
pub(crate) fn request_disconnect(&self) {
if let Some(disconnect_sender) = &self.disconnect_sender {
disconnect_sender.cancel();
}
@@ -270,6 +264,7 @@ pub(crate) async fn start_stdio_connection(
pub(crate) async fn start_websocket_acceptor(
bind_address: SocketAddr,
transport_event_tx: mpsc::Sender<TransportEvent>,
shutdown_token: CancellationToken,
) -> IoResult<JoinHandle<()>> {
let listener = TcpListener::bind(bind_address).await?;
let local_addr = listener.local_addr()?;
@@ -279,23 +274,31 @@ pub(crate) async fn start_websocket_acceptor(
let connection_counter = Arc::new(AtomicU64::new(1));
Ok(tokio::spawn(async move {
loop {
match listener.accept().await {
Ok((stream, peer_addr)) => {
print_websocket_connection(peer_addr);
let connection_id =
ConnectionId(connection_counter.fetch_add(1, Ordering::Relaxed));
let transport_event_tx_for_connection = transport_event_tx.clone();
tokio::spawn(async move {
run_websocket_connection(
connection_id,
stream,
transport_event_tx_for_connection,
)
.await;
});
tokio::select! {
_ = shutdown_token.cancelled() => {
info!("websocket acceptor shutting down");
break;
}
Err(err) => {
error!("failed to accept websocket connection: {err}");
accept_result = listener.accept() => {
match accept_result {
Ok((stream, peer_addr)) => {
info!(%peer_addr, "websocket client connected");
let connection_id =
ConnectionId(connection_counter.fetch_add(1, Ordering::Relaxed));
let transport_event_tx_for_connection = transport_event_tx.clone();
tokio::spawn(async move {
run_websocket_connection(
connection_id,
stream,
transport_event_tx_for_connection,
)
.await;
});
}
Err(err) => {
error!("failed to accept websocket connection: {err}");
}
}
}
}
}

View File

@@ -1,7 +1,11 @@
use anyhow::Context;
use anyhow::Result;
use anyhow::bail;
#[cfg(unix)]
use app_test_support::create_final_assistant_message_sse_response;
use app_test_support::create_mock_responses_server_sequence_unchecked;
#[cfg(unix)]
use app_test_support::to_response;
use codex_app_server_protocol::ClientInfo;
use codex_app_server_protocol::InitializeParams;
use codex_app_server_protocol::JSONRPCError;
@@ -9,11 +13,23 @@ use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
#[cfg(unix)]
use codex_app_server_protocol::ThreadStartParams;
#[cfg(unix)]
use codex_app_server_protocol::ThreadStartResponse;
#[cfg(unix)]
use codex_app_server_protocol::TurnStartParams;
#[cfg(unix)]
use codex_app_server_protocol::UserInput as V2UserInput;
#[cfg(unix)]
use core_test_support::responses;
use futures::SinkExt;
use futures::StreamExt;
use serde_json::json;
use std::net::SocketAddr;
use std::path::Path;
#[cfg(unix)]
use std::process::Command as StdCommand;
use std::process::Stdio;
use tempfile::TempDir;
use tokio::io::AsyncBufReadExt;
@@ -27,6 +43,12 @@ use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
#[cfg(unix)]
use wiremock::Mock;
#[cfg(unix)]
use wiremock::matchers::method;
#[cfg(unix)]
use wiremock::matchers::path_regex;
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5);
@@ -78,6 +100,59 @@ async fn websocket_transport_routes_per_connection_handshake_and_responses() ->
Ok(())
}
/// A single SIGINT must not terminate the server while the delayed mock turn
/// is still in flight; the process should exit successfully once it finishes
/// and the websocket should then be disconnected.
#[cfg(unix)]
#[tokio::test]
async fn websocket_transport_ctrl_c_waits_for_running_turn_before_exit() -> Result<()> {
    // The mock delays its /responses reply by 3s.
    let GracefulCtrlCFixture {
        _codex_home: _home_guard,
        _server: _mock_guard,
        process: mut server_process,
        ws: mut socket,
    } = start_ctrl_c_restart_fixture(Duration::from_secs(3)).await?;

    send_sigint(&server_process)?;
    // Shortly after the signal the server must still be alive.
    assert_process_does_not_exit_within(&mut server_process, Duration::from_millis(300)).await?;

    let exit_timeout = Duration::from_secs(10);
    let status = wait_for_process_exit_within(
        &mut server_process,
        exit_timeout,
        "timed out waiting for graceful Ctrl-C restart shutdown",
    )
    .await?;
    assert!(status.success(), "expected graceful exit, got {status}");

    expect_websocket_disconnect(&mut socket).await?;
    Ok(())
}
/// A second SIGINT sent while the delayed mock turn is still in flight must
/// make the server exit within 2s, i.e. before the 3s mock delay elapses.
#[cfg(unix)]
#[tokio::test]
async fn websocket_transport_second_ctrl_c_forces_exit_while_turn_running() -> Result<()> {
    // The mock delays its /responses reply by 3s.
    let GracefulCtrlCFixture {
        _codex_home: _home_guard,
        _server: _mock_guard,
        process: mut server_process,
        ws: mut socket,
    } = start_ctrl_c_restart_fixture(Duration::from_secs(3)).await?;

    // First signal: the server is expected to stay up for now.
    send_sigint(&server_process)?;
    assert_process_does_not_exit_within(&mut server_process, Duration::from_millis(300)).await?;

    // Second signal: the server is expected to exit promptly.
    send_sigint(&server_process)?;
    let exit_timeout = Duration::from_secs(2);
    let status = wait_for_process_exit_within(
        &mut server_process,
        exit_timeout,
        "timed out waiting for forced Ctrl-C restart shutdown",
    )
    .await?;
    assert!(status.success(), "expected graceful exit, got {status}");

    expect_websocket_disconnect(&mut socket).await?;
    Ok(())
}
async fn spawn_websocket_server(codex_home: &Path, bind_addr: SocketAddr) -> Result<Child> {
let program = codex_utils_cargo_bin::cargo_bin("codex-app-server")
.context("should find app-server binary")?;
@@ -157,6 +232,38 @@ async fn send_config_read_request(stream: &mut WsClient, id: i64) -> Result<()>
.await
}
/// Sends a `thread/start` request with the given id, asking for the
/// `mock-model` model and leaving every other parameter at its default.
#[cfg(unix)]
async fn send_thread_start_request(stream: &mut WsClient, id: i64) -> Result<()> {
    let params = ThreadStartParams {
        model: Some("mock-model".to_string()),
        ..Default::default()
    };
    let payload = serde_json::to_value(params)?;
    send_request(stream, "thread/start", id, Some(payload)).await
}
/// Sends a `turn/start` request for `thread_id` whose only input is the text
/// "Hello"; all other parameters are left at their defaults.
#[cfg(unix)]
async fn send_turn_start_request(stream: &mut WsClient, id: i64, thread_id: &str) -> Result<()> {
    let hello = V2UserInput::Text {
        text: "Hello".to_string(),
        text_elements: Vec::new(),
    };
    let params = TurnStartParams {
        thread_id: thread_id.to_string(),
        input: vec![hello],
        ..Default::default()
    };
    let payload = serde_json::to_value(params)?;
    send_request(stream, "turn/start", id, Some(payload)).await
}
async fn send_request(
stream: &mut WsClient,
method: &str,
@@ -235,6 +342,136 @@ async fn assert_no_message(stream: &mut WsClient, wait_for: Duration) -> Result<
}
}
/// Everything `start_ctrl_c_restart_fixture` sets up for the Ctrl-C tests.
#[cfg(unix)]
struct GracefulCtrlCFixture {
// Temporary directory the config file was written into; presumably kept
// here only so it is not removed while the server is still running.
_codex_home: TempDir,
// Mock HTTP server that answers the `/responses` POST after a delay.
_server: wiremock::MockServer,
// The spawned websocket app-server child process.
process: Child,
// Websocket client that has already initialized and started a turn.
ws: WsClient,
}
/// Boots a websocket app-server against a mock backend and leaves it with one
/// turn in flight.
///
/// The mock answers a single POST to `/responses` with a final assistant
/// message ("Done") after `turn_delay`. The function returns only after the
/// mock has recorded that POST, so callers can send signals while the turn is
/// still waiting on the delayed reply.
#[cfg(unix)]
async fn start_ctrl_c_restart_fixture(turn_delay: Duration) -> Result<GracefulCtrlCFixture> {
    // Mock backend: one delayed SSE reply for the turn.
    let mock_server = responses::start_mock_server().await;
    let sse_body = create_final_assistant_message_sse_response("Done")?;
    let delayed_reply = responses::sse_response(sse_body).set_delay(turn_delay);
    Mock::given(method("POST"))
        .and(path_regex(".*/responses$"))
        .respond_with(delayed_reply)
        .up_to_n_times(1)
        .mount(&mock_server)
        .await;

    // Server under test, configured to talk to the mock.
    let home = TempDir::new()?;
    create_config_toml(home.path(), &mock_server.uri(), "never")?;
    let listen_addr = reserve_local_addr()?;
    let child = spawn_websocket_server(home.path(), listen_addr).await?;
    let mut client = connect_websocket(listen_addr).await?;

    // Handshake.
    send_initialize_request(&mut client, 1, "ws_graceful_shutdown").await?;
    let init_reply = read_response_for_id(&mut client, 1).await?;
    assert_eq!(init_reply.id, RequestId::Integer(1));

    // Start a thread, then a turn on it.
    send_thread_start_request(&mut client, 2).await?;
    let thread_start_reply = read_response_for_id(&mut client, 2).await?;
    let ThreadStartResponse { thread, .. } = to_response(thread_start_reply)?;

    send_turn_start_request(&mut client, 3, &thread.id).await?;
    let turn_start_reply = read_response_for_id(&mut client, 3).await?;
    assert_eq!(turn_start_reply.id, RequestId::Integer(3));

    // Do not hand control back until the backend request has actually been made.
    wait_for_responses_post(&mock_server, Duration::from_secs(5)).await?;

    Ok(GracefulCtrlCFixture {
        _codex_home: home,
        _server: mock_server,
        process: child,
        ws: client,
    })
}
/// Polls the mock server every 10ms until it has recorded a POST whose path
/// ends with `/responses`.
///
/// Fails if the recorded requests cannot be read, or if `wait_for` elapses
/// without such a request being seen. The request list is always checked at
/// least once before the deadline is consulted.
#[cfg(unix)]
async fn wait_for_responses_post(server: &wiremock::MockServer, wait_for: Duration) -> Result<()> {
    let deadline = Instant::now() + wait_for;
    loop {
        let saw_post = server
            .received_requests()
            .await
            .context("failed to read mock server requests")?
            .iter()
            .any(|request| request.method == "POST" && request.url.path().ends_with("/responses"));
        if saw_post {
            return Ok(());
        }

        if Instant::now() >= deadline {
            bail!("timed out waiting for /responses request");
        }
        sleep(Duration::from_millis(10)).await;
    }
}
#[cfg(unix)]
fn send_sigint(process: &Child) -> Result<()> {
let pid = process
.id()
.context("websocket app-server process has no pid")?;
let status = StdCommand::new("kill")
.arg("-INT")
.arg(pid.to_string())
.status()
.context("failed to invoke kill -INT")?;
if !status.success() {
bail!("kill -INT exited with {status}");
}
Ok(())
}
/// Succeeds only if `process` is still running once `window` has elapsed.
///
/// An exit inside the window is reported as an error that includes the exit
/// status; a failure while waiting on the process is propagated with context.
#[cfg(unix)]
async fn assert_process_does_not_exit_within(process: &mut Child, window: Duration) -> Result<()> {
    // The timeout firing is the success case: the process outlived the window.
    let Ok(wait_result) = timeout(window, process.wait()).await else {
        return Ok(());
    };
    let status = wait_result.context("failed waiting for process")?;
    bail!("process exited too early during graceful drain: {status}")
}
/// Waits up to `window` for `process` to exit and returns its exit status.
///
/// If the window elapses first, the error carries `timeout_context`; if
/// waiting on the process itself fails, the error carries a fixed message.
#[cfg(unix)]
async fn wait_for_process_exit_within(
    process: &mut Child,
    window: Duration,
    timeout_context: &'static str,
) -> Result<std::process::ExitStatus> {
    let wait_result = timeout(window, process.wait())
        .await
        .context(timeout_context)?;
    let status = wait_result.context("failed waiting for websocket app-server process exit")?;
    Ok(status)
}
#[cfg(unix)]
async fn expect_websocket_disconnect(stream: &mut WsClient) -> Result<()> {
loop {
let frame = timeout(DEFAULT_READ_TIMEOUT, stream.next())
.await
.context("timed out waiting for websocket disconnect")?;
match frame {
None => return Ok(()),
Some(Ok(WebSocketMessage::Close(_))) => return Ok(()),
Some(Ok(WebSocketMessage::Ping(payload))) => {
stream
.send(WebSocketMessage::Pong(payload))
.await
.context("failed to reply to ping while waiting for disconnect")?;
}
Some(Ok(WebSocketMessage::Pong(_))) => {}
Some(Ok(WebSocketMessage::Frame(_))) => {}
Some(Ok(WebSocketMessage::Text(_))) => {}
Some(Ok(WebSocketMessage::Binary(_))) => {}
Some(Err(_)) => return Ok(()),
}
}
}
fn create_config_toml(
codex_home: &Path,
server_uri: &str,