mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
app-server: harden disconnect cleanup paths (#12218)
Hardens codex-rs/app-server connection lifecycle and outbound routing for websocket clients. Fixes some FUD I was having - Added per-connection disconnect signaling (CancellationToken) for websocket transports. - Split websocket handling into independent inbound/outbound tasks coordinated by cancellation. - Changed outbound routing so websocket connections use non-blocking try_send; slow/full websocket writers are disconnected instead of stalling broadcast delivery. - Kept stdio behavior blocking-on-send (no forced disconnect) so local stdio clients are not dropped when queues are temporarily full. - Simplified outbound router flow by removing deferred pending_closed_connections handling. - Added guards to drop incoming response/notification/error messages from unknown connections. - Fixed listener teardown race in thread listener tasks using a listener_generation check so stale tasks do not clear newer listeners. Fixes https://linear.app/openai/issue/CODEX-4966/multiclient-handle-slow-notification-consumers ## Tests Added/updated transport tests covering: - broadcast does not block on a slow/full websocket connection - stdio connection waits instead of disconnecting on full queue I (maxj) have tested manually and will retest before landing
This commit is contained in:
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -1327,6 +1327,7 @@ dependencies = [
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util",
|
||||
"toml 0.9.11+spec-1.1.0",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
||||
@@ -48,6 +48,7 @@ tokio = { workspace = true, features = [
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
tokio-util = { workspace = true }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] }
|
||||
|
||||
@@ -5811,7 +5811,7 @@ impl CodexMessageProcessor {
|
||||
api_version: ApiVersion,
|
||||
) {
|
||||
let (cancel_tx, mut cancel_rx) = oneshot::channel();
|
||||
let mut listener_command_rx = {
|
||||
let (mut listener_command_rx, listener_generation) = {
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
if thread_state.listener_matches(&conversation) {
|
||||
return;
|
||||
@@ -5927,6 +5927,11 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
if thread_state.listener_generation == listener_generation {
|
||||
thread_state.clear_listener();
|
||||
}
|
||||
});
|
||||
}
|
||||
async fn git_diff_to_origin(&self, request_id: ConnectionRequestId, cwd: PathBuf) {
|
||||
|
||||
@@ -10,7 +10,6 @@ use codex_core::config_loader::LoaderOverrides;
|
||||
use codex_utils_cli::CliConfigOverrides;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::VecDeque;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
use std::path::PathBuf;
|
||||
@@ -42,6 +41,7 @@ use codex_core::config_loader::TextRange as CoreTextRange;
|
||||
use codex_feedback::CodexFeedback;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use toml::Value as TomlValue;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
@@ -92,6 +92,7 @@ enum OutboundControlEvent {
|
||||
Opened {
|
||||
connection_id: ConnectionId,
|
||||
writer: mpsc::Sender<crate::outgoing_message::OutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
initialized: Arc<AtomicBool>,
|
||||
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
},
|
||||
@@ -403,61 +404,43 @@ pub async fn run_main_with_transport(
|
||||
}
|
||||
}
|
||||
|
||||
let transport_event_tx_for_outbound = transport_event_tx.clone();
|
||||
let outbound_handle = tokio::spawn(async move {
|
||||
let mut outbound_connections = HashMap::<ConnectionId, OutboundConnectionState>::new();
|
||||
let mut pending_closed_connections = VecDeque::<ConnectionId>::new();
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
event = outbound_control_rx.recv() => {
|
||||
let Some(event) = event else {
|
||||
break;
|
||||
};
|
||||
match event {
|
||||
OutboundControlEvent::Opened {
|
||||
connection_id,
|
||||
writer,
|
||||
initialized,
|
||||
opted_out_notification_methods,
|
||||
} => {
|
||||
outbound_connections.insert(
|
||||
biased;
|
||||
event = outbound_control_rx.recv() => {
|
||||
let Some(event) = event else {
|
||||
break;
|
||||
};
|
||||
match event {
|
||||
OutboundControlEvent::Opened {
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer,
|
||||
initialized,
|
||||
opted_out_notification_methods,
|
||||
),
|
||||
);
|
||||
}
|
||||
OutboundControlEvent::Closed { connection_id } => {
|
||||
outbound_connections.remove(&connection_id);
|
||||
writer,
|
||||
disconnect_sender,
|
||||
initialized,
|
||||
opted_out_notification_methods,
|
||||
} => {
|
||||
outbound_connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer,
|
||||
initialized,
|
||||
opted_out_notification_methods,
|
||||
disconnect_sender,
|
||||
),
|
||||
);
|
||||
}
|
||||
OutboundControlEvent::Closed { connection_id } => {
|
||||
outbound_connections.remove(&connection_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
envelope = outgoing_rx.recv() => {
|
||||
envelope = outgoing_rx.recv() => {
|
||||
let Some(envelope) = envelope else {
|
||||
break;
|
||||
};
|
||||
let disconnected_connections =
|
||||
route_outgoing_envelope(&mut outbound_connections, envelope).await;
|
||||
pending_closed_connections.extend(disconnected_connections);
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(connection_id) = pending_closed_connections.front().copied() {
|
||||
match transport_event_tx_for_outbound
|
||||
.try_send(TransportEvent::ConnectionClosed { connection_id })
|
||||
{
|
||||
Ok(()) => {
|
||||
pending_closed_connections.pop_front();
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
break;
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
return;
|
||||
}
|
||||
route_outgoing_envelope(&mut outbound_connections, envelope).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -491,7 +474,11 @@ pub async fn run_main_with_transport(
|
||||
break;
|
||||
};
|
||||
match event {
|
||||
TransportEvent::ConnectionOpened { connection_id, writer } => {
|
||||
TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer,
|
||||
disconnect_sender,
|
||||
} => {
|
||||
let outbound_initialized = Arc::new(AtomicBool::new(false));
|
||||
let outbound_opted_out_notification_methods =
|
||||
Arc::new(RwLock::new(HashSet::new()));
|
||||
@@ -499,6 +486,7 @@ pub async fn run_main_with_transport(
|
||||
.send(OutboundControlEvent::Opened {
|
||||
connection_id,
|
||||
writer,
|
||||
disconnect_sender,
|
||||
initialized: Arc::clone(&outbound_initialized),
|
||||
opted_out_notification_methods: Arc::clone(
|
||||
&outbound_opted_out_notification_methods,
|
||||
@@ -518,6 +506,9 @@ pub async fn run_main_with_transport(
|
||||
);
|
||||
}
|
||||
TransportEvent::ConnectionClosed { connection_id } => {
|
||||
if connections.remove(&connection_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
if outbound_control_tx
|
||||
.send(OutboundControlEvent::Closed { connection_id })
|
||||
.await
|
||||
@@ -526,7 +517,6 @@ pub async fn run_main_with_transport(
|
||||
break;
|
||||
}
|
||||
processor.connection_closed(connection_id).await;
|
||||
connections.remove(&connection_id);
|
||||
if shutdown_when_no_connections && connections.is_empty() {
|
||||
break;
|
||||
}
|
||||
@@ -535,7 +525,7 @@ pub async fn run_main_with_transport(
|
||||
match message {
|
||||
JSONRPCMessage::Request(request) => {
|
||||
let Some(connection_state) = connections.get_mut(&connection_id) else {
|
||||
warn!("dropping request from unknown connection: {:?}", connection_id);
|
||||
warn!("dropping request from unknown connection: {connection_id:?}");
|
||||
continue;
|
||||
};
|
||||
let was_initialized = connection_state.session.initialized;
|
||||
@@ -565,12 +555,24 @@ pub async fn run_main_with_transport(
|
||||
}
|
||||
}
|
||||
JSONRPCMessage::Response(response) => {
|
||||
if !connections.contains_key(&connection_id) {
|
||||
warn!("dropping response from unknown connection: {connection_id:?}");
|
||||
continue;
|
||||
}
|
||||
processor.process_response(response).await;
|
||||
}
|
||||
JSONRPCMessage::Notification(notification) => {
|
||||
if !connections.contains_key(&connection_id) {
|
||||
warn!("dropping notification from unknown connection: {connection_id:?}");
|
||||
continue;
|
||||
}
|
||||
processor.process_notification(notification).await;
|
||||
}
|
||||
JSONRPCMessage::Error(err) => {
|
||||
if !connections.contains_key(&connection_id) {
|
||||
warn!("dropping error from unknown connection: {connection_id:?}");
|
||||
continue;
|
||||
}
|
||||
processor.process_error(err).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ pub(crate) struct ThreadState {
|
||||
pub(crate) turn_summary: TurnSummary,
|
||||
pub(crate) cancel_tx: Option<oneshot::Sender<()>>,
|
||||
pub(crate) experimental_raw_events: bool,
|
||||
pub(crate) listener_generation: u64,
|
||||
listener_command_tx: Option<mpsc::UnboundedSender<ThreadListenerCommand>>,
|
||||
current_turn_history: ThreadHistoryBuilder,
|
||||
listener_thread: Option<Weak<CodexThread>>,
|
||||
@@ -65,14 +66,15 @@ impl ThreadState {
|
||||
&mut self,
|
||||
cancel_tx: oneshot::Sender<()>,
|
||||
conversation: &Arc<CodexThread>,
|
||||
) -> mpsc::UnboundedReceiver<ThreadListenerCommand> {
|
||||
) -> (mpsc::UnboundedReceiver<ThreadListenerCommand>, u64) {
|
||||
if let Some(previous) = self.cancel_tx.replace(cancel_tx) {
|
||||
let _ = previous.send(());
|
||||
}
|
||||
self.listener_generation = self.listener_generation.wrapping_add(1);
|
||||
let (listener_command_tx, listener_command_rx) = mpsc::unbounded_channel();
|
||||
self.listener_command_tx = Some(listener_command_tx);
|
||||
self.listener_thread = Some(Arc::downgrade(conversation));
|
||||
listener_command_rx
|
||||
(listener_command_rx, self.listener_generation)
|
||||
}
|
||||
|
||||
pub(crate) fn clear_listener(&mut self) {
|
||||
|
||||
@@ -32,6 +32,7 @@ use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_tungstenite::accept_async;
|
||||
use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
@@ -135,6 +136,7 @@ pub(crate) enum TransportEvent {
|
||||
ConnectionOpened {
|
||||
connection_id: ConnectionId,
|
||||
writer: mpsc::Sender<OutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
},
|
||||
ConnectionClosed {
|
||||
connection_id: ConnectionId,
|
||||
@@ -168,6 +170,7 @@ pub(crate) struct OutboundConnectionState {
|
||||
pub(crate) initialized: Arc<AtomicBool>,
|
||||
pub(crate) opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
pub(crate) writer: mpsc::Sender<OutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
}
|
||||
|
||||
impl OutboundConnectionState {
|
||||
@@ -175,11 +178,23 @@ impl OutboundConnectionState {
|
||||
writer: mpsc::Sender<OutgoingMessage>,
|
||||
initialized: Arc<AtomicBool>,
|
||||
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
) -> Self {
|
||||
Self {
|
||||
initialized,
|
||||
opted_out_notification_methods,
|
||||
writer,
|
||||
disconnect_sender,
|
||||
}
|
||||
}
|
||||
|
||||
fn can_disconnect(&self) -> bool {
|
||||
self.disconnect_sender.is_some()
|
||||
}
|
||||
|
||||
fn request_disconnect(&self) {
|
||||
if let Some(disconnect_sender) = &self.disconnect_sender {
|
||||
disconnect_sender.cancel();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -195,6 +210,7 @@ pub(crate) async fn start_stdio_connection(
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: None,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "processor unavailable"))?;
|
||||
@@ -299,12 +315,14 @@ async fn run_websocket_connection(
|
||||
}
|
||||
};
|
||||
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let (writer_tx, writer_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
let disconnect_token = CancellationToken::new();
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: Some(disconnect_token.clone()),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
@@ -312,9 +330,62 @@ async fn run_websocket_connection(
|
||||
return;
|
||||
}
|
||||
|
||||
let (mut websocket_writer, mut websocket_reader) = websocket_stream.split();
|
||||
let (websocket_writer, websocket_reader) = websocket_stream.split();
|
||||
let (writer_control_tx, writer_control_rx) =
|
||||
mpsc::channel::<WebSocketMessage>(CHANNEL_CAPACITY);
|
||||
let mut outbound_task = tokio::spawn(run_websocket_outbound_loop(
|
||||
websocket_writer,
|
||||
writer_rx,
|
||||
writer_control_rx,
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
let mut inbound_task = tokio::spawn(run_websocket_inbound_loop(
|
||||
websocket_reader,
|
||||
transport_event_tx.clone(),
|
||||
writer_tx_for_reader,
|
||||
writer_control_tx,
|
||||
connection_id,
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut outbound_task => {
|
||||
disconnect_token.cancel();
|
||||
inbound_task.abort();
|
||||
}
|
||||
_ = &mut inbound_task => {
|
||||
disconnect_token.cancel();
|
||||
outbound_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
let _ = transport_event_tx
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn run_websocket_outbound_loop(
|
||||
mut websocket_writer: futures::stream::SplitSink<
|
||||
tokio_tungstenite::WebSocketStream<TcpStream>,
|
||||
WebSocketMessage,
|
||||
>,
|
||||
mut writer_rx: mpsc::Receiver<OutgoingMessage>,
|
||||
mut writer_control_rx: mpsc::Receiver<WebSocketMessage>,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
message = writer_control_rx.recv() => {
|
||||
let Some(message) = message else {
|
||||
break;
|
||||
};
|
||||
if websocket_writer.send(message).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
outgoing_message = writer_rx.recv() => {
|
||||
let Some(outgoing_message) = outgoing_message else {
|
||||
break;
|
||||
@@ -326,6 +397,25 @@ async fn run_websocket_connection(
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_websocket_inbound_loop(
|
||||
mut websocket_reader: futures::stream::SplitStream<
|
||||
tokio_tungstenite::WebSocketStream<TcpStream>,
|
||||
>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
writer_tx_for_reader: mpsc::Sender<OutgoingMessage>,
|
||||
writer_control_tx: mpsc::Sender<WebSocketMessage>,
|
||||
connection_id: ConnectionId,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
incoming_message = websocket_reader.next() => {
|
||||
match incoming_message {
|
||||
Some(Ok(WebSocketMessage::Text(text))) => {
|
||||
@@ -341,8 +431,13 @@ async fn run_websocket_connection(
|
||||
}
|
||||
}
|
||||
Some(Ok(WebSocketMessage::Ping(payload))) => {
|
||||
if websocket_writer.send(WebSocketMessage::Pong(payload)).await.is_err() {
|
||||
break;
|
||||
match writer_control_tx.try_send(WebSocketMessage::Pong(payload)) {
|
||||
Ok(()) => {}
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => break,
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
|
||||
warn!("websocket control queue full while replying to ping; closing connection");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(WebSocketMessage::Pong(_))) => {}
|
||||
@@ -359,10 +454,6 @@ async fn run_websocket_connection(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = transport_event_tx
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn forward_incoming_message(
|
||||
@@ -461,30 +552,61 @@ fn should_skip_notification_for_connection(
|
||||
}
|
||||
}
|
||||
|
||||
fn disconnect_connection(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
connection_id: ConnectionId,
|
||||
) -> bool {
|
||||
if let Some(connection_state) = connections.remove(&connection_id) {
|
||||
connection_state.request_disconnect();
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
async fn send_message_to_connection(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
connection_id: ConnectionId,
|
||||
message: OutgoingMessage,
|
||||
) -> bool {
|
||||
let Some(connection_state) = connections.get(&connection_id) else {
|
||||
warn!("dropping message for disconnected connection: {connection_id:?}");
|
||||
return false;
|
||||
};
|
||||
if should_skip_notification_for_connection(connection_state, &message) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let writer = connection_state.writer.clone();
|
||||
if connection_state.can_disconnect() {
|
||||
match writer.try_send(message) {
|
||||
Ok(()) => false,
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
warn!(
|
||||
"disconnecting slow connection after outbound queue filled: {connection_id:?}"
|
||||
);
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
}
|
||||
} else if writer.send(message).await.is_err() {
|
||||
disconnect_connection(connections, connection_id)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn route_outgoing_envelope(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
envelope: OutgoingEnvelope,
|
||||
) -> Vec<ConnectionId> {
|
||||
let mut disconnected = Vec::new();
|
||||
) {
|
||||
match envelope {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
} => {
|
||||
let Some(connection_state) = connections.get(&connection_id) else {
|
||||
warn!(
|
||||
"dropping message for disconnected connection: {:?}",
|
||||
connection_id
|
||||
);
|
||||
return disconnected;
|
||||
};
|
||||
if should_skip_notification_for_connection(connection_state, &message) {
|
||||
return disconnected;
|
||||
}
|
||||
if connection_state.writer.send(message).await.is_err() {
|
||||
connections.remove(&connection_id);
|
||||
disconnected.push(connection_id);
|
||||
}
|
||||
let _ = send_message_to_connection(connections, connection_id, message).await;
|
||||
}
|
||||
OutgoingEnvelope::Broadcast { message } => {
|
||||
let target_connections: Vec<ConnectionId> = connections
|
||||
@@ -501,17 +623,11 @@ pub(crate) async fn route_outgoing_envelope(
|
||||
.collect();
|
||||
|
||||
for connection_id in target_connections {
|
||||
let Some(connection_state) = connections.get(&connection_id) else {
|
||||
continue;
|
||||
};
|
||||
if connection_state.writer.send(message.clone()).await.is_err() {
|
||||
connections.remove(&connection_id);
|
||||
disconnected.push(connection_id);
|
||||
}
|
||||
let _ =
|
||||
send_message_to_connection(connections, connection_id, message.clone()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
disconnected
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -520,6 +636,8 @@ mod tests {
|
||||
use crate::error_code::OVERLOADED_ERROR_CODE;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_parses_stdio_listen_url() {
|
||||
@@ -754,10 +872,15 @@ mod tests {
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(writer_tx, initialized, opted_out_notification_methods),
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
initialized,
|
||||
opted_out_notification_methods,
|
||||
None,
|
||||
),
|
||||
);
|
||||
|
||||
let disconnected = route_outgoing_envelope(
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
@@ -771,10 +894,161 @@ mod tests {
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(disconnected, Vec::<ConnectionId>::new());
|
||||
assert!(
|
||||
writer_rx.try_recv().is_err(),
|
||||
"opted-out notification should be dropped"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn broadcast_does_not_block_on_slow_connection() {
|
||||
let fast_connection_id = ConnectionId(1);
|
||||
let slow_connection_id = ConnectionId(2);
|
||||
|
||||
let (fast_writer_tx, mut fast_writer_rx) = mpsc::channel(1);
|
||||
let (slow_writer_tx, mut slow_writer_rx) = mpsc::channel(1);
|
||||
let fast_disconnect_token = CancellationToken::new();
|
||||
let slow_disconnect_token = CancellationToken::new();
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
fast_connection_id,
|
||||
OutboundConnectionState::new(
|
||||
fast_writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
Some(fast_disconnect_token.clone()),
|
||||
),
|
||||
);
|
||||
connections.insert(
|
||||
slow_connection_id,
|
||||
OutboundConnectionState::new(
|
||||
slow_writer_tx.clone(),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
Some(slow_disconnect_token.clone()),
|
||||
),
|
||||
);
|
||||
|
||||
let queued_message =
|
||||
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
|
||||
method: "codex/event/already-buffered".to_string(),
|
||||
params: None,
|
||||
});
|
||||
slow_writer_tx
|
||||
.try_send(queued_message)
|
||||
.expect("channel should have room");
|
||||
|
||||
let broadcast_message =
|
||||
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
|
||||
method: "codex/event/test".to_string(),
|
||||
params: None,
|
||||
});
|
||||
timeout(
|
||||
Duration::from_millis(100),
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::Broadcast {
|
||||
message: broadcast_message,
|
||||
},
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("broadcast should not block on a full writer");
|
||||
assert!(!connections.contains_key(&slow_connection_id));
|
||||
assert!(slow_disconnect_token.is_cancelled());
|
||||
assert!(!fast_disconnect_token.is_cancelled());
|
||||
let fast_message = fast_writer_rx
|
||||
.try_recv()
|
||||
.expect("fast connection should receive broadcast");
|
||||
assert!(matches!(
|
||||
fast_message,
|
||||
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
|
||||
method,
|
||||
params: None,
|
||||
}) if method == "codex/event/test"
|
||||
));
|
||||
|
||||
let slow_message = slow_writer_rx
|
||||
.try_recv()
|
||||
.expect("slow connection should retain its original buffered message");
|
||||
assert!(matches!(
|
||||
slow_message,
|
||||
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
|
||||
method,
|
||||
params: None,
|
||||
}) if method == "codex/event/already-buffered"
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn to_connection_stdio_waits_instead_of_disconnecting_when_writer_queue_is_full() {
|
||||
let connection_id = ConnectionId(3);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
writer_tx
|
||||
.send(OutgoingMessage::Notification(
|
||||
crate::outgoing_message::OutgoingNotification {
|
||||
method: "queued".to_string(),
|
||||
params: None,
|
||||
},
|
||||
))
|
||||
.await
|
||||
.expect("channel should accept the first queued message");
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
None,
|
||||
),
|
||||
);
|
||||
|
||||
let route_task = tokio::spawn(async move {
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::Notification(
|
||||
crate::outgoing_message::OutgoingNotification {
|
||||
method: "second".to_string(),
|
||||
params: None,
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
let first = timeout(Duration::from_millis(100), writer_rx.recv())
|
||||
.await
|
||||
.expect("first queued message should be readable")
|
||||
.expect("first queued message should exist");
|
||||
let second = timeout(Duration::from_millis(100), writer_rx.recv())
|
||||
.await
|
||||
.expect("second message should eventually be delivered")
|
||||
.expect("second message should exist");
|
||||
|
||||
timeout(Duration::from_millis(100), route_task)
|
||||
.await
|
||||
.expect("routing should finish after writer drains")
|
||||
.expect("routing task should succeed");
|
||||
|
||||
assert!(matches!(
|
||||
first,
|
||||
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
|
||||
method,
|
||||
params: None,
|
||||
}) if method == "queued"
|
||||
));
|
||||
assert!(matches!(
|
||||
second,
|
||||
OutgoingMessage::Notification(crate::outgoing_message::OutgoingNotification {
|
||||
method,
|
||||
params: None,
|
||||
}) if method == "second"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user