mirror of
https://github.com/openai/codex.git
synced 2026-05-01 18:06:47 +00:00
Compare commits
7 Commits
codex-fix/
...
etraut/web
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d051a4221 | ||
|
|
997a23330f | ||
|
|
bd282df268 | ||
|
|
64872b998a | ||
|
|
add5ac1325 | ||
|
|
d896f5cffc | ||
|
|
c850aa8d31 |
@@ -371,11 +371,13 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
outbound_connections.insert(
|
||||
IN_PROCESS_CONNECTION_ID,
|
||||
OutboundConnectionState::new(
|
||||
IN_PROCESS_CONNECTION_ID,
|
||||
writer_tx,
|
||||
Arc::clone(&outbound_initialized),
|
||||
Arc::clone(&outbound_experimental_api_enabled),
|
||||
Arc::clone(&outbound_opted_out_notification_methods),
|
||||
/*disconnect_sender*/ None,
|
||||
/*disconnect_notifier*/ None,
|
||||
),
|
||||
);
|
||||
let mut outbound_handle = tokio::spawn(async move {
|
||||
|
||||
@@ -29,6 +29,7 @@ use crate::transport::ConnectionState;
|
||||
use crate::transport::OutboundConnectionState;
|
||||
use crate::transport::TransportEvent;
|
||||
use crate::transport::auth::policy_from_settings;
|
||||
use crate::transport::disconnect_connection;
|
||||
use crate::transport::route_outgoing_envelope;
|
||||
use crate::transport::start_remote_control;
|
||||
use crate::transport::start_stdio_connection;
|
||||
@@ -594,49 +595,62 @@ pub async fn run_main_with_transport(
|
||||
|
||||
let outbound_handle = tokio::spawn(async move {
|
||||
let mut outbound_connections = HashMap::<ConnectionId, OutboundConnectionState>::new();
|
||||
// Overflow workers run outside this router task. This side channel lets
|
||||
// them remove a slow connection from routing before the transport loop's
|
||||
// eventual ConnectionClosed event catches up.
|
||||
let (outbound_disconnect_tx, mut outbound_disconnect_rx) =
|
||||
mpsc::channel::<ConnectionId>(CHANNEL_CAPACITY);
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
event = outbound_control_rx.recv() => {
|
||||
let Some(event) = event else {
|
||||
break;
|
||||
};
|
||||
match event {
|
||||
OutboundControlEvent::Opened {
|
||||
biased;
|
||||
event = outbound_control_rx.recv() => {
|
||||
let Some(event) = event else {
|
||||
break;
|
||||
};
|
||||
match event {
|
||||
OutboundControlEvent::Opened {
|
||||
connection_id,
|
||||
writer,
|
||||
disconnect_sender,
|
||||
initialized,
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
} => {
|
||||
outbound_connections.insert(
|
||||
connection_id,
|
||||
writer,
|
||||
disconnect_sender,
|
||||
initialized,
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
} => {
|
||||
outbound_connections.insert(
|
||||
OutboundConnectionState::new(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer,
|
||||
initialized,
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
disconnect_sender,
|
||||
),
|
||||
);
|
||||
}
|
||||
OutboundControlEvent::Closed { connection_id } => {
|
||||
outbound_connections.remove(&connection_id);
|
||||
}
|
||||
OutboundControlEvent::DisconnectAll => {
|
||||
info!(
|
||||
"disconnecting {} outbound websocket connection(s) for graceful restart",
|
||||
outbound_connections.len()
|
||||
);
|
||||
for connection_state in outbound_connections.values() {
|
||||
connection_state.request_disconnect();
|
||||
}
|
||||
outbound_connections.clear();
|
||||
writer,
|
||||
initialized,
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
disconnect_sender,
|
||||
Some(outbound_disconnect_tx.clone()),
|
||||
),
|
||||
);
|
||||
}
|
||||
OutboundControlEvent::Closed { connection_id } => {
|
||||
outbound_connections.remove(&connection_id);
|
||||
}
|
||||
OutboundControlEvent::DisconnectAll => {
|
||||
info!(
|
||||
"disconnecting {} outbound websocket connection(s) for graceful restart",
|
||||
outbound_connections.len()
|
||||
);
|
||||
for connection_state in outbound_connections.values() {
|
||||
connection_state.request_disconnect();
|
||||
}
|
||||
outbound_connections.clear();
|
||||
}
|
||||
}
|
||||
envelope = outgoing_rx.recv() => {
|
||||
}
|
||||
connection_id = outbound_disconnect_rx.recv() => {
|
||||
let Some(connection_id) = connection_id else {
|
||||
break;
|
||||
};
|
||||
disconnect_connection(&mut outbound_connections, connection_id);
|
||||
}
|
||||
envelope = outgoing_rx.recv() => {
|
||||
let Some(envelope) = envelope else {
|
||||
break;
|
||||
};
|
||||
|
||||
@@ -18,7 +18,9 @@ use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::error;
|
||||
@@ -29,6 +31,11 @@ use tracing::warn;
|
||||
/// plenty for an interactive CLI.
|
||||
pub(crate) const CHANNEL_CAPACITY: usize = 128;
|
||||
|
||||
#[cfg(not(test))]
|
||||
const OUTBOUND_QUEUE_FULL_GRACE: Duration = Duration::from_secs(2);
|
||||
#[cfg(test)]
|
||||
const OUTBOUND_QUEUE_FULL_GRACE: Duration = Duration::from_millis(200);
|
||||
|
||||
mod remote_control;
|
||||
mod stdio;
|
||||
mod websocket;
|
||||
@@ -144,22 +151,75 @@ pub(crate) struct OutboundConnectionState {
|
||||
pub(crate) experimental_api_enabled: Arc<AtomicBool>,
|
||||
pub(crate) opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
pub(crate) writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
overflow_writer: Option<mpsc::Sender<QueuedOutgoingMessage>>,
|
||||
overflow_depth: Arc<AtomicUsize>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
}
|
||||
|
||||
impl OutboundConnectionState {
|
||||
pub(crate) fn new(
|
||||
connection_id: ConnectionId,
|
||||
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
initialized: Arc<AtomicBool>,
|
||||
experimental_api_enabled: Arc<AtomicBool>,
|
||||
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
disconnect_notifier: Option<mpsc::Sender<ConnectionId>>,
|
||||
) -> Self {
|
||||
let overflow_depth = Arc::new(AtomicUsize::new(0));
|
||||
let overflow_writer = disconnect_sender.as_ref().map(|disconnect_sender| {
|
||||
let (overflow_tx, mut overflow_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let writer = writer.clone();
|
||||
let disconnect_sender = disconnect_sender.clone();
|
||||
let disconnect_notifier = disconnect_notifier.clone();
|
||||
let overflow_depth = Arc::clone(&overflow_depth);
|
||||
tokio::spawn(async move {
|
||||
while let Some(queued_message) = overflow_rx.recv().await {
|
||||
match writer
|
||||
.send_timeout(queued_message, OUTBOUND_QUEUE_FULL_GRACE)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
overflow_depth.fetch_sub(1, Ordering::AcqRel);
|
||||
}
|
||||
Err(mpsc::error::SendTimeoutError::Timeout(_)) => {
|
||||
overflow_depth.fetch_sub(1, Ordering::AcqRel);
|
||||
warn!(
|
||||
"disconnecting slow connection after outbound queue remained full for {:?}: {connection_id:?}",
|
||||
OUTBOUND_QUEUE_FULL_GRACE
|
||||
);
|
||||
disconnect_sender.cancel();
|
||||
// The websocket task will eventually report ConnectionClosed,
|
||||
// but notify the outbound router now so no newer messages are
|
||||
// routed after this timed-out one is dropped.
|
||||
if let Some(disconnect_notifier) = &disconnect_notifier {
|
||||
let _ = disconnect_notifier.send(connection_id).await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(mpsc::error::SendTimeoutError::Closed(_)) => {
|
||||
overflow_depth.fetch_sub(1, Ordering::AcqRel);
|
||||
disconnect_sender.cancel();
|
||||
// Drop outbound routing state promptly even if the transport's
|
||||
// close event is delayed behind other incoming events.
|
||||
if let Some(disconnect_notifier) = &disconnect_notifier {
|
||||
let _ = disconnect_notifier.send(connection_id).await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
overflow_tx
|
||||
});
|
||||
|
||||
Self {
|
||||
initialized,
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
writer,
|
||||
overflow_writer,
|
||||
overflow_depth,
|
||||
disconnect_sender,
|
||||
}
|
||||
}
|
||||
@@ -274,7 +334,7 @@ fn should_skip_notification_for_connection(
|
||||
}
|
||||
}
|
||||
|
||||
fn disconnect_connection(
|
||||
pub(crate) fn disconnect_connection(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
connection_id: ConnectionId,
|
||||
) -> bool {
|
||||
@@ -306,16 +366,17 @@ async fn send_message_to_connection(
|
||||
write_complete_tx,
|
||||
};
|
||||
if connection_state.can_disconnect() {
|
||||
match writer.try_send(queued_message) {
|
||||
Ok(()) => false,
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
warn!(
|
||||
"disconnecting slow connection after outbound queue filled: {connection_id:?}"
|
||||
);
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
disconnect_connection(connections, connection_id)
|
||||
if connection_state.overflow_depth.load(Ordering::Acquire) > 0 {
|
||||
queue_overflow_message(connections, connection_id, queued_message).await
|
||||
} else {
|
||||
match writer.try_send(queued_message) {
|
||||
Ok(()) => false,
|
||||
Err(mpsc::error::TrySendError::Full(queued_message)) => {
|
||||
queue_overflow_message(connections, connection_id, queued_message).await
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if writer.send(queued_message).await.is_err() {
|
||||
@@ -325,6 +386,58 @@ async fn send_message_to_connection(
|
||||
}
|
||||
}
|
||||
|
||||
async fn queue_overflow_message(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
connection_id: ConnectionId,
|
||||
queued_message: QueuedOutgoingMessage,
|
||||
) -> bool {
|
||||
let Some(connection_state) = connections.get(&connection_id) else {
|
||||
warn!("dropping overflow message for disconnected connection: {connection_id:?}");
|
||||
return false;
|
||||
};
|
||||
let Some(overflow_writer) = connection_state.overflow_writer.clone() else {
|
||||
unreachable!("disconnectable connection must have an overflow writer");
|
||||
};
|
||||
let overflow_depth = Arc::clone(&connection_state.overflow_depth);
|
||||
|
||||
// WebSocket clients are marked disconnectable so a stuck writer cannot
|
||||
// block the outbound router forever. Still, normal turns can briefly burst
|
||||
// past the per-connection queue capacity while the writer task is healthy.
|
||||
// Queue the overflow on a bounded, ordered side channel so the router stays
|
||||
// non-blocking without creating unbounded detached send waiters.
|
||||
overflow_depth.fetch_add(1, Ordering::AcqRel);
|
||||
match overflow_writer.try_send(queued_message) {
|
||||
Ok(()) => false,
|
||||
Err(mpsc::error::TrySendError::Full(queued_message)) => {
|
||||
// Both bounded queues are full now. Give the overflow worker the
|
||||
// same grace window to make room before deciding this connection is
|
||||
// slow enough to disconnect.
|
||||
match overflow_writer
|
||||
.send_timeout(queued_message, OUTBOUND_QUEUE_FULL_GRACE)
|
||||
.await
|
||||
{
|
||||
Ok(()) => false,
|
||||
Err(mpsc::error::SendTimeoutError::Timeout(_)) => {
|
||||
overflow_depth.fetch_sub(1, Ordering::AcqRel);
|
||||
warn!(
|
||||
"disconnecting slow connection after outbound overflow queue remained full for {:?}: {connection_id:?}",
|
||||
OUTBOUND_QUEUE_FULL_GRACE
|
||||
);
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
Err(mpsc::error::SendTimeoutError::Closed(_)) => {
|
||||
overflow_depth.fetch_sub(1, Ordering::AcqRel);
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
overflow_depth.fetch_sub(1, Ordering::AcqRel);
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_outgoing_message_for_connection(
|
||||
connection_state: &OutboundConnectionState,
|
||||
message: OutgoingMessage,
|
||||
@@ -623,11 +736,13 @@ mod tests {
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
connection_id,
|
||||
writer_tx,
|
||||
initialized,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
opted_out_notification_methods,
|
||||
/*disconnect_sender*/ None,
|
||||
/*disconnect_notifier*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
@@ -663,11 +778,13 @@ mod tests {
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
connection_id,
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()]))),
|
||||
/*disconnect_sender*/ None,
|
||||
/*disconnect_notifier*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
@@ -703,11 +820,13 @@ mod tests {
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
connection_id,
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*disconnect_sender*/ None,
|
||||
/*disconnect_notifier*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
@@ -749,11 +868,13 @@ mod tests {
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
connection_id,
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*disconnect_sender*/ None,
|
||||
/*disconnect_notifier*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
@@ -811,11 +932,13 @@ mod tests {
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
connection_id,
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*disconnect_sender*/ None,
|
||||
/*disconnect_notifier*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
@@ -875,86 +998,349 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn broadcast_does_not_block_on_slow_connection() {
|
||||
let fast_connection_id = ConnectionId(1);
|
||||
let slow_connection_id = ConnectionId(2);
|
||||
async fn disconnectable_connection_waits_for_queue_to_drain() {
|
||||
let connection_id = ConnectionId(1);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
let disconnect_token = CancellationToken::new();
|
||||
|
||||
let (fast_writer_tx, mut fast_writer_rx) = mpsc::channel(1);
|
||||
let (slow_writer_tx, mut slow_writer_rx) = mpsc::channel(1);
|
||||
let fast_disconnect_token = CancellationToken::new();
|
||||
let slow_disconnect_token = CancellationToken::new();
|
||||
writer_tx
|
||||
.send(QueuedOutgoingMessage::new(
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "queued".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.expect("channel should accept the first queued message");
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
fast_connection_id,
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
fast_writer_tx,
|
||||
connection_id,
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
Some(fast_disconnect_token.clone()),
|
||||
),
|
||||
);
|
||||
connections.insert(
|
||||
slow_connection_id,
|
||||
OutboundConnectionState::new(
|
||||
slow_writer_tx.clone(),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
Some(slow_disconnect_token.clone()),
|
||||
Some(disconnect_token.clone()),
|
||||
/*disconnect_notifier*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
let queued_message = OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::ConfigWarning(ConfigWarningNotification {
|
||||
summary: "already-buffered".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
}),
|
||||
);
|
||||
slow_writer_tx
|
||||
.try_send(QueuedOutgoingMessage::new(queued_message))
|
||||
.expect("channel should have room");
|
||||
|
||||
let broadcast_message = OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::ConfigWarning(ConfigWarningNotification {
|
||||
summary: "test".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
}),
|
||||
);
|
||||
timeout(
|
||||
Duration::from_millis(100),
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::Broadcast {
|
||||
message: broadcast_message,
|
||||
},
|
||||
),
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "second".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("broadcast should return even when one connection is slow");
|
||||
assert!(!connections.contains_key(&slow_connection_id));
|
||||
assert!(slow_disconnect_token.is_cancelled());
|
||||
assert!(!fast_disconnect_token.is_cancelled());
|
||||
let fast_message = fast_writer_rx
|
||||
.try_recv()
|
||||
.expect("fast connection should receive the broadcast notification");
|
||||
.await;
|
||||
|
||||
let first = writer_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("first queued message should be readable");
|
||||
let second = timeout(Duration::from_millis(100), writer_rx.recv())
|
||||
.await
|
||||
.expect("second notification should be delivered after queue capacity returns")
|
||||
.expect("second notification should exist");
|
||||
|
||||
assert!(connections.contains_key(&connection_id));
|
||||
assert!(!disconnect_token.is_cancelled());
|
||||
assert!(matches!(
|
||||
fast_message.message,
|
||||
first.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "test"
|
||||
)) if summary == "queued"
|
||||
));
|
||||
|
||||
let slow_message = slow_writer_rx
|
||||
.try_recv()
|
||||
.expect("slow connection should retain its original buffered message");
|
||||
assert!(matches!(
|
||||
slow_message.message,
|
||||
second.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "second"
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn disconnectable_connection_preserves_order_while_overflow_is_draining() {
|
||||
let connection_id = ConnectionId(12);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
let disconnect_token = CancellationToken::new();
|
||||
|
||||
writer_tx
|
||||
.send(QueuedOutgoingMessage::new(
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "queued".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.expect("channel should accept the first queued message");
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
connection_id,
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
Some(disconnect_token.clone()),
|
||||
/*disconnect_notifier*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
for summary in ["second", "third"] {
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::ConfigWarning(ConfigWarningNotification {
|
||||
summary: summary.to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
}),
|
||||
),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let mut summaries = Vec::new();
|
||||
for _ in 0..3 {
|
||||
let message = timeout(Duration::from_millis(100), writer_rx.recv())
|
||||
.await
|
||||
.expect("queued notification should be delivered")
|
||||
.expect("queued notification should exist");
|
||||
let OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. },
|
||||
)) = message.message
|
||||
else {
|
||||
panic!("expected config warning notification");
|
||||
};
|
||||
summaries.push(summary);
|
||||
}
|
||||
|
||||
assert_eq!(summaries, vec!["queued", "second", "third"]);
|
||||
assert!(connections.contains_key(&connection_id));
|
||||
assert!(!disconnect_token.is_cancelled());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn disconnectable_connection_applies_grace_when_overflow_queue_fills() {
|
||||
let connection_id = ConnectionId(13);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
let disconnect_token = CancellationToken::new();
|
||||
|
||||
writer_tx
|
||||
.send(QueuedOutgoingMessage::new(
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "already-buffered".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.expect("channel should accept the first queued message");
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
connection_id,
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
Some(disconnect_token.clone()),
|
||||
/*disconnect_notifier*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "overflow-active".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
for index in 0..CHANNEL_CAPACITY {
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::ConfigWarning(ConfigWarningNotification {
|
||||
summary: format!("overflow-{index}"),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
}),
|
||||
),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
assert!(connections.contains_key(&connection_id));
|
||||
assert!(!disconnect_token.is_cancelled());
|
||||
|
||||
let mut route_task = tokio::spawn(async move {
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::ConfigWarning(ConfigWarningNotification {
|
||||
summary: "too-many".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
}),
|
||||
),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
connections
|
||||
});
|
||||
|
||||
assert!(
|
||||
timeout(Duration::from_millis(50), &mut route_task)
|
||||
.await
|
||||
.is_err(),
|
||||
"saturated overflow queue should be given a grace window before disconnecting"
|
||||
);
|
||||
assert!(!disconnect_token.is_cancelled());
|
||||
|
||||
let connections = timeout(
|
||||
OUTBOUND_QUEUE_FULL_GRACE + Duration::from_millis(100),
|
||||
route_task,
|
||||
)
|
||||
.await
|
||||
.expect("saturated overflow queue should eventually disconnect")
|
||||
.expect("routing task should not panic");
|
||||
|
||||
assert!(!connections.contains_key(&connection_id));
|
||||
assert!(disconnect_token.is_cancelled());
|
||||
let original_message = writer_rx
|
||||
.try_recv()
|
||||
.expect("full queue should retain its original buffered message");
|
||||
assert!(matches!(
|
||||
original_message.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "already-buffered"
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn disconnectable_connection_requests_disconnect_after_queue_grace_expires() {
|
||||
let connection_id = ConnectionId(2);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
let (disconnect_notifier_tx, mut disconnect_notifier_rx) = mpsc::channel(1);
|
||||
let disconnect_token = CancellationToken::new();
|
||||
|
||||
writer_tx
|
||||
.send(QueuedOutgoingMessage::new(
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "already-buffered".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.expect("channel should accept the first queued message");
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
connection_id,
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
Some(disconnect_token.clone()),
|
||||
/*disconnect_notifier*/ Some(disconnect_notifier_tx),
|
||||
),
|
||||
);
|
||||
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "second".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(connections.contains_key(&connection_id));
|
||||
let notified_connection_id = timeout(
|
||||
OUTBOUND_QUEUE_FULL_GRACE + Duration::from_millis(100),
|
||||
disconnect_notifier_rx.recv(),
|
||||
)
|
||||
.await
|
||||
.expect("full queue should notify the router after the grace expires")
|
||||
.expect("disconnect notification should contain a connection id");
|
||||
assert_eq!(notified_connection_id, connection_id);
|
||||
assert!(disconnect_connection(
|
||||
&mut connections,
|
||||
notified_connection_id
|
||||
));
|
||||
assert!(!connections.contains_key(&connection_id));
|
||||
assert!(disconnect_token.is_cancelled());
|
||||
let original_message = writer_rx
|
||||
.try_recv()
|
||||
.expect("full queue should retain its original buffered message");
|
||||
assert!(matches!(
|
||||
original_message.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "already-buffered"
|
||||
@@ -983,11 +1369,13 @@ mod tests {
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
connection_id,
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*disconnect_sender*/ None,
|
||||
/*disconnect_notifier*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user