app-server: harden disconnect cleanup paths (#12218)

Hardens codex-rs/app-server connection lifecycle and outbound routing
for websocket clients. Fixes some FUD I was having

- Added per-connection disconnect signaling (CancellationToken) for
websocket transports.
- Split websocket handling into independent inbound/outbound tasks
coordinated by cancellation.
- Changed outbound routing so websocket connections use non-blocking
try_send; slow/full websocket writers are disconnected instead of
stalling broadcast delivery.
- Kept stdio behavior blocking-on-send (no forced disconnect) so local
stdio clients are not dropped when queues are temporarily full.
- Simplified outbound router flow by removing deferred
pending_closed_connections handling.
- Added guards to drop incoming response/notification/error messages
from unknown connections.
- Fixed listener teardown race in thread listener tasks using a
listener_generation check so stale tasks do not clear newer listeners.

Fixes
https://linear.app/openai/issue/CODEX-4966/multiclient-handle-slow-notification-consumers

  ## Tests

  Added/updated transport tests covering:

  - broadcast does not block on a slow/full websocket connection
  - stdio connection waits instead of disconnecting on full queue

I (maxj) have tested manually and will retest before landing
This commit is contained in:
Max Johnson
2026-02-20 12:35:16 -08:00
committed by GitHub
parent d3cf8bd0fa
commit 6b1091fc92
6 changed files with 372 additions and 87 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -1327,6 +1327,7 @@ dependencies = [
"time",
"tokio",
"tokio-tungstenite",
"tokio-util",
"toml 0.9.11+spec-1.1.0",
"tracing",
"tracing-subscriber",

View File

@@ -48,6 +48,7 @@ tokio = { workspace = true, features = [
"rt-multi-thread",
"signal",
] }
tokio-util = { workspace = true }
tokio-tungstenite = { workspace = true }
tracing = { workspace = true, features = ["log"] }
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] }

View File

@@ -5811,7 +5811,7 @@ impl CodexMessageProcessor {
api_version: ApiVersion,
) {
let (cancel_tx, mut cancel_rx) = oneshot::channel();
let mut listener_command_rx = {
let (mut listener_command_rx, listener_generation) = {
let mut thread_state = thread_state.lock().await;
if thread_state.listener_matches(&conversation) {
return;
@@ -5927,6 +5927,11 @@ impl CodexMessageProcessor {
}
}
}
let mut thread_state = thread_state.lock().await;
if thread_state.listener_generation == listener_generation {
thread_state.clear_listener();
}
});
}
async fn git_diff_to_origin(&self, request_id: ConnectionRequestId, cwd: PathBuf) {

View File

@@ -10,7 +10,6 @@ use codex_core::config_loader::LoaderOverrides;
use codex_utils_cli::CliConfigOverrides;
use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::io::ErrorKind;
use std::io::Result as IoResult;
use std::path::PathBuf;
@@ -42,6 +41,7 @@ use codex_core::config_loader::TextRange as CoreTextRange;
use codex_feedback::CodexFeedback;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use toml::Value as TomlValue;
use tracing::error;
use tracing::info;
@@ -92,6 +92,7 @@ enum OutboundControlEvent {
Opened {
connection_id: ConnectionId,
writer: mpsc::Sender<crate::outgoing_message::OutgoingMessage>,
disconnect_sender: Option<CancellationToken>,
initialized: Arc<AtomicBool>,
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
},
@@ -403,61 +404,43 @@ pub async fn run_main_with_transport(
}
}
let transport_event_tx_for_outbound = transport_event_tx.clone();
let outbound_handle = tokio::spawn(async move {
let mut outbound_connections = HashMap::<ConnectionId, OutboundConnectionState>::new();
let mut pending_closed_connections = VecDeque::<ConnectionId>::new();
loop {
tokio::select! {
biased;
event = outbound_control_rx.recv() => {
let Some(event) = event else {
break;
};
match event {
OutboundControlEvent::Opened {
connection_id,
writer,
initialized,
opted_out_notification_methods,
} => {
outbound_connections.insert(
biased;
event = outbound_control_rx.recv() => {
let Some(event) = event else {
break;
};
match event {
OutboundControlEvent::Opened {
connection_id,
OutboundConnectionState::new(
writer,
initialized,
opted_out_notification_methods,
),
);
}
OutboundControlEvent::Closed { connection_id } => {
outbound_connections.remove(&connection_id);
writer,
disconnect_sender,
initialized,
opted_out_notification_methods,
} => {
outbound_connections.insert(
connection_id,
OutboundConnectionState::new(
writer,
initialized,
opted_out_notification_methods,
disconnect_sender,
),
);
}
OutboundControlEvent::Closed { connection_id } => {
outbound_connections.remove(&connection_id);
}
}
}
}
envelope = outgoing_rx.recv() => {
envelope = outgoing_rx.recv() => {
let Some(envelope) = envelope else {
break;
};
let disconnected_connections =
route_outgoing_envelope(&mut outbound_connections, envelope).await;
pending_closed_connections.extend(disconnected_connections);
}
}
while let Some(connection_id) = pending_closed_connections.front().copied() {
match transport_event_tx_for_outbound
.try_send(TransportEvent::ConnectionClosed { connection_id })
{
Ok(()) => {
pending_closed_connections.pop_front();
}
Err(mpsc::error::TrySendError::Full(_)) => {
break;
}
Err(mpsc::error::TrySendError::Closed(_)) => {
return;
}
route_outgoing_envelope(&mut outbound_connections, envelope).await;
}
}
}
@@ -491,7 +474,11 @@ pub async fn run_main_with_transport(
break;
};
match event {
TransportEvent::ConnectionOpened { connection_id, writer } => {
TransportEvent::ConnectionOpened {
connection_id,
writer,
disconnect_sender,
} => {
let outbound_initialized = Arc::new(AtomicBool::new(false));
let outbound_opted_out_notification_methods =
Arc::new(RwLock::new(HashSet::new()));
@@ -499,6 +486,7 @@ pub async fn run_main_with_transport(
.send(OutboundControlEvent::Opened {
connection_id,
writer,
disconnect_sender,
initialized: Arc::clone(&outbound_initialized),
opted_out_notification_methods: Arc::clone(
&outbound_opted_out_notification_methods,
@@ -518,6 +506,9 @@ pub async fn run_main_with_transport(
);
}
TransportEvent::ConnectionClosed { connection_id } => {
if connections.remove(&connection_id).is_none() {
continue;
}
if outbound_control_tx
.send(OutboundControlEvent::Closed { connection_id })
.await
@@ -526,7 +517,6 @@ pub async fn run_main_with_transport(
break;
}
processor.connection_closed(connection_id).await;
connections.remove(&connection_id);
if shutdown_when_no_connections && connections.is_empty() {
break;
}
@@ -535,7 +525,7 @@ pub async fn run_main_with_transport(
match message {
JSONRPCMessage::Request(request) => {
let Some(connection_state) = connections.get_mut(&connection_id) else {
warn!("dropping request from unknown connection: {:?}", connection_id);
warn!("dropping request from unknown connection: {connection_id:?}");
continue;
};
let was_initialized = connection_state.session.initialized;
@@ -565,12 +555,24 @@ pub async fn run_main_with_transport(
}
}
JSONRPCMessage::Response(response) => {
if !connections.contains_key(&connection_id) {
warn!("dropping response from unknown connection: {connection_id:?}");
continue;
}
processor.process_response(response).await;
}
JSONRPCMessage::Notification(notification) => {
if !connections.contains_key(&connection_id) {
warn!("dropping notification from unknown connection: {connection_id:?}");
continue;
}
processor.process_notification(notification).await;
}
JSONRPCMessage::Error(err) => {
if !connections.contains_key(&connection_id) {
warn!("dropping error from unknown connection: {connection_id:?}");
continue;
}
processor.process_error(err).await;
}
}

View File

@@ -47,6 +47,7 @@ pub(crate) struct ThreadState {
pub(crate) turn_summary: TurnSummary,
pub(crate) cancel_tx: Option<oneshot::Sender<()>>,
pub(crate) experimental_raw_events: bool,
pub(crate) listener_generation: u64,
listener_command_tx: Option<mpsc::UnboundedSender<ThreadListenerCommand>>,
current_turn_history: ThreadHistoryBuilder,
listener_thread: Option<Weak<CodexThread>>,
@@ -65,14 +66,15 @@ impl ThreadState {
&mut self,
cancel_tx: oneshot::Sender<()>,
conversation: &Arc<CodexThread>,
) -> mpsc::UnboundedReceiver<ThreadListenerCommand> {
) -> (mpsc::UnboundedReceiver<ThreadListenerCommand>, u64) {
if let Some(previous) = self.cancel_tx.replace(cancel_tx) {
let _ = previous.send(());
}
self.listener_generation = self.listener_generation.wrapping_add(1);
let (listener_command_tx, listener_command_rx) = mpsc::unbounded_channel();
self.listener_command_tx = Some(listener_command_tx);
self.listener_thread = Some(Arc::downgrade(conversation));
listener_command_rx
(listener_command_rx, self.listener_generation)
}
pub(crate) fn clear_listener(&mut self) {

View File

@@ -32,6 +32,7 @@ use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use tracing::error;
use tracing::info;
@@ -135,6 +136,7 @@ pub(crate) enum TransportEvent {
ConnectionOpened {
connection_id: ConnectionId,
writer: mpsc::Sender<OutgoingMessage>,
disconnect_sender: Option<CancellationToken>,
},
ConnectionClosed {
connection_id: ConnectionId,
@@ -168,6 +170,7 @@ pub(crate) struct OutboundConnectionState {
pub(crate) initialized: Arc<AtomicBool>,
pub(crate) opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
pub(crate) writer: mpsc::Sender<OutgoingMessage>,
disconnect_sender: Option<CancellationToken>,
}
impl OutboundConnectionState {
@@ -175,11 +178,23 @@ impl OutboundConnectionState {
writer: mpsc::Sender<OutgoingMessage>,
initialized: Arc<AtomicBool>,
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
disconnect_sender: Option<CancellationToken>,
) -> Self {
Self {
initialized,
opted_out_notification_methods,
writer,
disconnect_sender,
}
}
fn can_disconnect(&self) -> bool {
self.disconnect_sender.is_some()
}
fn request_disconnect(&self) {
if let Some(disconnect_sender) = &self.disconnect_sender {
disconnect_sender.cancel();
}
}
}
@@ -195,6 +210,7 @@ pub(crate) async fn start_stdio_connection(
.send(TransportEvent::ConnectionOpened {
connection_id,
writer: writer_tx,
disconnect_sender: None,
})
.await
.map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "processor unavailable"))?;
@@ -299,12 +315,14 @@ async fn run_websocket_connection(
}
};
let (writer_tx, mut writer_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
let (writer_tx, writer_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
let writer_tx_for_reader = writer_tx.clone();
let disconnect_token = CancellationToken::new();
if transport_event_tx
.send(TransportEvent::ConnectionOpened {
connection_id,
writer: writer_tx,
disconnect_sender: Some(disconnect_token.clone()),
})
.await
.is_err()
@@ -312,9 +330,62 @@ async fn run_websocket_connection(
return;
}
let (mut websocket_writer, mut websocket_reader) = websocket_stream.split();
let (websocket_writer, websocket_reader) = websocket_stream.split();
let (writer_control_tx, writer_control_rx) =
mpsc::channel::<WebSocketMessage>(CHANNEL_CAPACITY);
let mut outbound_task = tokio::spawn(run_websocket_outbound_loop(
websocket_writer,
writer_rx,
writer_control_rx,
disconnect_token.clone(),
));
let mut inbound_task = tokio::spawn(run_websocket_inbound_loop(
websocket_reader,
transport_event_tx.clone(),
writer_tx_for_reader,
writer_control_tx,
connection_id,
disconnect_token.clone(),
));
tokio::select! {
_ = &mut outbound_task => {
disconnect_token.cancel();
inbound_task.abort();
}
_ = &mut inbound_task => {
disconnect_token.cancel();
outbound_task.abort();
}
}
let _ = transport_event_tx
.send(TransportEvent::ConnectionClosed { connection_id })
.await;
}
async fn run_websocket_outbound_loop(
mut websocket_writer: futures::stream::SplitSink<
tokio_tungstenite::WebSocketStream<TcpStream>,
WebSocketMessage,
>,
mut writer_rx: mpsc::Receiver<OutgoingMessage>,
mut writer_control_rx: mpsc::Receiver<WebSocketMessage>,
disconnect_token: CancellationToken,
) {
loop {
tokio::select! {
_ = disconnect_token.cancelled() => {
break;
}
message = writer_control_rx.recv() => {
let Some(message) = message else {
break;
};
if websocket_writer.send(message).await.is_err() {
break;
}
}
outgoing_message = writer_rx.recv() => {
let Some(outgoing_message) = outgoing_message else {
break;
@@ -326,6 +397,25 @@ async fn run_websocket_connection(
break;
}
}
}
}
}
async fn run_websocket_inbound_loop(
mut websocket_reader: futures::stream::SplitStream<
tokio_tungstenite::WebSocketStream<TcpStream>,
>,
transport_event_tx: mpsc::Sender<TransportEvent>,
writer_tx_for_reader: mpsc::Sender<OutgoingMessage>,
writer_control_tx: mpsc::Sender<WebSocketMessage>,
connection_id: ConnectionId,
disconnect_token: CancellationToken,
) {
loop {
tokio::select! {
_ = disconnect_token.cancelled() => {
break;
}
incoming_message = websocket_reader.next() => {
match incoming_message {
Some(Ok(WebSocketMessage::Text(text))) => {
@@ -341,8 +431,13 @@ async fn run_websocket_connection(
}
}
Some(Ok(WebSocketMessage::Ping(payload))) => {
if websocket_writer.send(WebSocketMessage::Pong(payload)).await.is_err() {
break;
match writer_control_tx.try_send(WebSocketMessage::Pong(payload)) {
Ok(()) => {}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => break,
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
warn!("websocket control queue full while replying to ping; closing connection");
break;
}
}
}
Some(Ok(WebSocketMessage::Pong(_))) => {}
@@ -359,10 +454,6 @@ async fn run_websocket_connection(
}
}
}
let _ = transport_event_tx
.send(TransportEvent::ConnectionClosed { connection_id })
.await;
}
async fn forward_incoming_message(
@@ -461,30 +552,61 @@ fn should_skip_notification_for_connection(
}
}
fn disconnect_connection(
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
connection_id: ConnectionId,
) -> bool {
if let Some(connection_state) = connections.remove(&connection_id) {
connection_state.request_disconnect();
return true;
}
false
}
async fn send_message_to_connection(
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
connection_id: ConnectionId,
message: OutgoingMessage,
) -> bool {
let Some(connection_state) = connections.get(&connection_id) else {
warn!("dropping message for disconnected connection: {connection_id:?}");
return false;
};
if should_skip_notification_for_connection(connection_state, &message) {
return false;
}
let writer = connection_state.writer.clone();
if connection_state.can_disconnect() {
match writer.try_send(message) {
Ok(()) => false,
Err(mpsc::error::TrySendError::Full(_)) => {
warn!(
"disconnecting slow connection after outbound queue filled: {connection_id:?}"
);
disconnect_connection(connections, connection_id)
}
Err(mpsc::error::TrySendError::Closed(_)) => {
disconnect_connection(connections, connection_id)
}
}
} else if writer.send(message).await.is_err() {
disconnect_connection(connections, connection_id)
} else {
false
}
}
pub(crate) async fn route_outgoing_envelope(
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
envelope: OutgoingEnvelope,
) -> Vec<ConnectionId> {
let mut disconnected = Vec::new();
) {
match envelope {
OutgoingEnvelope::ToConnection {
connection_id,
message,
} => {
let Some(connection_state) = connections.get(&connection_id) else {
warn!(
"dropping message for disconnected connection: {:?}",
connection_id
);
return disconnected;
};
if should_skip_notification_for_connection(connection_state, &message) {
return disconnected;
}
if connection_state.writer.send(message).await.is_err() {
connections.remove(&connection_id);
disconnected.push(connection_id);
}
let _ = send_message_to_connection(connections, connection_id, message).await;
}
OutgoingEnvelope::Broadcast { message } => {
let target_connections: Vec<ConnectionId> = connections
@@ -501,17 +623,11 @@ pub(crate) async fn route_outgoing_envelope(
.collect();
for connection_id in target_connections {
let Some(connection_state) = connections.get(&connection_id) else {
continue;
};
if connection_state.writer.send(message.clone()).await.is_err() {
connections.remove(&connection_id);
disconnected.push(connection_id);
}
let _ =
send_message_to_connection(connections, connection_id, message.clone()).await;
}
}
}
disconnected
}
#[cfg(test)]
@@ -520,6 +636,8 @@ mod tests {
use crate::error_code::OVERLOADED_ERROR_CODE;
use pretty_assertions::assert_eq;
use serde_json::json;
use tokio::time::Duration;
use tokio::time::timeout;
#[test]
fn app_server_transport_parses_stdio_listen_url() {
@@ -754,10 +872,15 @@ mod tests {
let mut connections = HashMap::new();
connections.insert(
connection_id,
OutboundConnectionState::new(writer_tx, initialized, opted_out_notification_methods),
OutboundConnectionState::new(
writer_tx,
initialized,
opted_out_notification_methods,
None,
),
);
let disconnected = route_outgoing_envelope(
route_outgoing_envelope(
&mut connections,
OutgoingEnvelope::ToConnection {
connection_id,
@@ -771,10 +894,161 @@ mod tests {
)
.await;
assert_eq!(disconnected, Vec::<ConnectionId>::new());
assert!(
writer_rx.try_recv().is_err(),
"opted-out notification should be dropped"
);
}
#[tokio::test]
async fn broadcast_does_not_block_on_slow_connection() {
let fast_connection_id = ConnectionId(1);
let slow_connection_id = ConnectionId(2);
let (fast_writer_tx, mut fast_writer_rx) = mpsc::channel(1);
let (slow_writer_tx, mut slow_writer_rx) = mpsc::channel(1);
let fast_disconnect_token = CancellationToken::new();
let slow_disconnect_token = CancellationToken::new();
let mut connections = HashMap::new();
connections.insert(
fast_connection_id,
OutboundConnectionState::new(
fast_writer_tx,
Arc::new(AtomicBool::new(true)),
Arc::new(RwLock::new(HashSet::new())),
Some(fast_disconnect_token.clone()),
),
);
connections.insert(
slow_connection_id,
OutboundConnectionState::new(
slow_writer_tx.clone(),
Arc::new(AtomicBool::new(true)),
Arc::new(RwLock::new(HashSet::new())),
Some(slow_disconnect_token.clone()),
),
);
let queued_message =
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
method: "codex/event/already-buffered".to_string(),
params: None,
});
slow_writer_tx
.try_send(queued_message)
.expect("channel should have room");
let broadcast_message =
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
method: "codex/event/test".to_string(),
params: None,
});
timeout(
Duration::from_millis(100),
route_outgoing_envelope(
&mut connections,
OutgoingEnvelope::Broadcast {
message: broadcast_message,
},
),
)
.await
.expect("broadcast should not block on a full writer");
assert!(!connections.contains_key(&slow_connection_id));
assert!(slow_disconnect_token.is_cancelled());
assert!(!fast_disconnect_token.is_cancelled());
let fast_message = fast_writer_rx
.try_recv()
.expect("fast connection should receive broadcast");
assert!(matches!(
fast_message,
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
method,
params: None,
}) if method == "codex/event/test"
));
let slow_message = slow_writer_rx
.try_recv()
.expect("slow connection should retain its original buffered message");
assert!(matches!(
slow_message,
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
method,
params: None,
}) if method == "codex/event/already-buffered"
));
}
#[tokio::test]
async fn to_connection_stdio_waits_instead_of_disconnecting_when_writer_queue_is_full() {
let connection_id = ConnectionId(3);
let (writer_tx, mut writer_rx) = mpsc::channel(1);
writer_tx
.send(OutgoingMessage::Notification(
crate::outgoing_message::OutgoingNotification {
method: "queued".to_string(),
params: None,
},
))
.await
.expect("channel should accept the first queued message");
let mut connections = HashMap::new();
connections.insert(
connection_id,
OutboundConnectionState::new(
writer_tx,
Arc::new(AtomicBool::new(true)),
Arc::new(RwLock::new(HashSet::new())),
None,
),
);
let route_task = tokio::spawn(async move {
route_outgoing_envelope(
&mut connections,
OutgoingEnvelope::ToConnection {
connection_id,
message: OutgoingMessage::Notification(
crate::outgoing_message::OutgoingNotification {
method: "second".to_string(),
params: None,
},
),
},
)
.await
});
let first = timeout(Duration::from_millis(100), writer_rx.recv())
.await
.expect("first queued message should be readable")
.expect("first queued message should exist");
let second = timeout(Duration::from_millis(100), writer_rx.recv())
.await
.expect("second message should eventually be delivered")
.expect("second message should exist");
timeout(Duration::from_millis(100), route_task)
.await
.expect("routing should finish after writer drains")
.expect("routing task should succeed");
assert!(matches!(
first,
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
method,
params: None,
}) if method == "queued"
));
assert!(matches!(
second,
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
method,
params: None,
}) if method == "second"
));
}
}