mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
app-server: Add back pressure and batching to command/exec (#15547)
* Add `OutgoingMessageSender::send_server_notification_to_connection_and_wait` which returns only once message is written to websocket (or failed to do so) * Use this mechanism to apply back pressure to stdout/stderr streams of processes spawned by `command/exec`, to limit them to at most one message in-memory at a time * Use back pressure signal to also batch smaller chunks into ≈64KiB ones This should make commands execution more robust over high-latency/low-throughput networks
This commit is contained in:
committed by
GitHub
parent
daf5e584c2
commit
d61c03ca08
@@ -8871,6 +8871,7 @@ mod tests {
|
||||
request_id: sent_request_id,
|
||||
..
|
||||
}),
|
||||
..
|
||||
} = request_message
|
||||
else {
|
||||
panic!("expected tool request to be sent to the subscribed connection");
|
||||
|
||||
@@ -42,6 +42,7 @@ use crate::outgoing_message::ConnectionRequestId;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
const EXEC_TIMEOUT_EXIT_CODE: i32 = 124;
|
||||
const OUTPUT_CHUNK_SIZE_HINT: usize = 64 * 1024;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct CommandExecManager {
|
||||
@@ -577,13 +578,19 @@ fn spawn_process_output(params: SpawnProcessOutputParams) -> tokio::task::JoinHa
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
let mut observed_num_bytes = 0usize;
|
||||
loop {
|
||||
let chunk = tokio::select! {
|
||||
let mut chunk = tokio::select! {
|
||||
chunk = output_rx.recv() => match chunk {
|
||||
Some(chunk) => chunk,
|
||||
None => break,
|
||||
},
|
||||
_ = stdio_timeout_rx.wait_for(|&v| v) => break,
|
||||
};
|
||||
// Individual chunks are at most 8KiB, so overshooting a bit is acceptable.
|
||||
while chunk.len() < OUTPUT_CHUNK_SIZE_HINT
|
||||
&& let Ok(next_chunk) = output_rx.try_recv()
|
||||
{
|
||||
chunk.extend_from_slice(&next_chunk);
|
||||
}
|
||||
let capped_chunk = match output_bytes_cap {
|
||||
Some(output_bytes_cap) => {
|
||||
let capped_chunk_len = output_bytes_cap
|
||||
@@ -597,8 +604,8 @@ fn spawn_process_output(params: SpawnProcessOutputParams) -> tokio::task::JoinHa
|
||||
let cap_reached = Some(observed_num_bytes) == output_bytes_cap;
|
||||
if let (true, Some(process_id)) = (stream_output, process_id.as_ref()) {
|
||||
outgoing
|
||||
.send_server_notification_to_connections(
|
||||
&[connection_id],
|
||||
.send_server_notification_to_connection_and_wait(
|
||||
connection_id,
|
||||
ServerNotification::CommandExecOutputDelta(
|
||||
CommandExecOutputDeltaNotification {
|
||||
process_id: process_id.clone(),
|
||||
@@ -809,6 +816,7 @@ mod tests {
|
||||
let OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
..
|
||||
} = envelope
|
||||
else {
|
||||
panic!("expected connection-scoped outgoing message");
|
||||
@@ -891,6 +899,7 @@ mod tests {
|
||||
let OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
..
|
||||
} = envelope
|
||||
else {
|
||||
panic!("expected connection-scoped outgoing message");
|
||||
|
||||
@@ -60,6 +60,7 @@ use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use crate::transport::CHANNEL_CAPACITY;
|
||||
use crate::transport::OutboundConnectionState;
|
||||
use crate::transport::route_outgoing_envelope;
|
||||
@@ -353,7 +354,7 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingEnvelope>(channel_capacity);
|
||||
let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx));
|
||||
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<OutgoingMessage>(channel_capacity);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(channel_capacity);
|
||||
let outbound_initialized = Arc::new(AtomicBool::new(false));
|
||||
let outbound_experimental_api_enabled = Arc::new(AtomicBool::new(false));
|
||||
let outbound_opted_out_notification_methods = Arc::new(RwLock::new(HashSet::new()));
|
||||
@@ -547,10 +548,11 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
}
|
||||
}
|
||||
}
|
||||
outgoing_message = writer_rx.recv() => {
|
||||
let Some(outgoing_message) = outgoing_message else {
|
||||
queued_message = writer_rx.recv() => {
|
||||
let Some(queued_message) = queued_message else {
|
||||
break;
|
||||
};
|
||||
let outgoing_message = queued_message.message;
|
||||
match outgoing_message {
|
||||
OutgoingMessage::Response(response) => {
|
||||
if let Some(response_tx) = pending_request_responses.remove(&response.id) {
|
||||
@@ -629,6 +631,9 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(write_complete_tx) = queued_message.write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ use crate::message_processor::MessageProcessorArgs;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use crate::transport::CHANNEL_CAPACITY;
|
||||
use crate::transport::ConnectionState;
|
||||
use crate::transport::OutboundConnectionState;
|
||||
@@ -103,7 +104,7 @@ enum OutboundControlEvent {
|
||||
/// Register a new writer for an opened connection.
|
||||
Opened {
|
||||
connection_id: ConnectionId,
|
||||
writer: mpsc::Sender<crate::outgoing_message::OutgoingMessage>,
|
||||
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
initialized: Arc<AtomicBool>,
|
||||
experimental_api_enabled: Arc<AtomicBool>,
|
||||
|
||||
@@ -390,6 +390,7 @@ async fn read_response<T: serde::de::DeserializeOwned>(
|
||||
let crate::outgoing_message::OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
..
|
||||
} = envelope
|
||||
else {
|
||||
continue;
|
||||
@@ -420,6 +421,7 @@ async fn read_thread_started_notification(
|
||||
crate::outgoing_message::OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
..
|
||||
} => {
|
||||
if connection_id != TEST_CONNECTION_ID {
|
||||
continue;
|
||||
|
||||
@@ -81,17 +81,33 @@ impl RequestContext {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum OutgoingEnvelope {
|
||||
ToConnection {
|
||||
connection_id: ConnectionId,
|
||||
message: OutgoingMessage,
|
||||
write_complete_tx: Option<oneshot::Sender<()>>,
|
||||
},
|
||||
Broadcast {
|
||||
message: OutgoingMessage,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct QueuedOutgoingMessage {
|
||||
pub(crate) message: OutgoingMessage,
|
||||
pub(crate) write_complete_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
impl QueuedOutgoingMessage {
|
||||
pub(crate) fn new(message: OutgoingMessage) -> Self {
|
||||
Self {
|
||||
message,
|
||||
write_complete_tx: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends messages to the client and manages request callbacks.
|
||||
pub(crate) struct OutgoingMessageSender {
|
||||
next_server_request_id: AtomicI64,
|
||||
@@ -299,6 +315,7 @@ impl OutgoingMessageSender {
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id: *connection_id,
|
||||
message: outgoing_message.clone(),
|
||||
write_complete_tx: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
@@ -333,6 +350,7 @@ impl OutgoingMessageSender {
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::Request(request),
|
||||
write_complete_tx: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
@@ -519,6 +537,7 @@ impl OutgoingMessageSender {
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id: *connection_id,
|
||||
message: outgoing_message.clone(),
|
||||
write_complete_tx: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
@@ -527,6 +546,28 @@ impl OutgoingMessageSender {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_server_notification_to_connection_and_wait(
|
||||
&self,
|
||||
connection_id: ConnectionId,
|
||||
notification: ServerNotification,
|
||||
) {
|
||||
tracing::trace!("app-server event: {notification}");
|
||||
let outgoing_message = OutgoingMessage::AppServerNotification(notification);
|
||||
let (write_complete_tx, write_complete_rx) = oneshot::channel();
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: outgoing_message,
|
||||
write_complete_tx: Some(write_complete_tx),
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send server notification to client: {err:?}");
|
||||
}
|
||||
let _ = write_complete_rx.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn send_error(
|
||||
&self,
|
||||
request_id: ConnectionRequestId,
|
||||
@@ -566,6 +607,7 @@ impl OutgoingMessageSender {
|
||||
let send_fut = self.sender.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
write_complete_tx: None,
|
||||
});
|
||||
let send_result = if let Some(request_context) = request_context {
|
||||
send_fut.instrument(request_context.span()).await
|
||||
@@ -818,6 +860,7 @@ mod tests {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(connection_id, ConnectionId(42));
|
||||
let OutgoingMessage::Response(response) = message else {
|
||||
@@ -880,6 +923,7 @@ mod tests {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(connection_id, ConnectionId(9));
|
||||
let OutgoingMessage::Error(outgoing_error) = message else {
|
||||
@@ -892,6 +936,50 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_server_notification_to_connection_and_wait_tracks_write_completion() {
|
||||
let (tx, mut rx) = mpsc::channel::<OutgoingEnvelope>(4);
|
||||
let outgoing = OutgoingMessageSender::new(tx);
|
||||
let send_task = tokio::spawn(async move {
|
||||
outgoing
|
||||
.send_server_notification_to_connection_and_wait(
|
||||
ConnectionId(42),
|
||||
ServerNotification::ModelRerouted(ModelReroutedNotification {
|
||||
thread_id: "thread-1".to_string(),
|
||||
turn_id: "turn-1".to_string(),
|
||||
from_model: "gpt-5.3-codex".to_string(),
|
||||
to_model: "gpt-5.2".to_string(),
|
||||
reason: ModelRerouteReason::HighRiskCyberActivity,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
let envelope = timeout(Duration::from_secs(1), rx.recv())
|
||||
.await
|
||||
.expect("should receive envelope before timeout")
|
||||
.expect("channel should contain one message");
|
||||
let OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
write_complete_tx,
|
||||
} = envelope
|
||||
else {
|
||||
panic!("expected targeted server notification envelope");
|
||||
};
|
||||
assert_eq!(connection_id, ConnectionId(42));
|
||||
assert!(matches!(message, OutgoingMessage::AppServerNotification(_)));
|
||||
write_complete_tx
|
||||
.expect("write completion sender should be attached")
|
||||
.send(())
|
||||
.expect("receiver should still be waiting");
|
||||
|
||||
timeout(Duration::from_secs(1), send_task)
|
||||
.await
|
||||
.expect("send task should finish after write completion is signaled")
|
||||
.expect("send task should not panic");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connection_closed_clears_registered_request_contexts() {
|
||||
let (tx, _rx) = mpsc::channel::<OutgoingEnvelope>(4);
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingError;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use axum::Router;
|
||||
use axum::body::Body;
|
||||
use axum::extract::ConnectInfo;
|
||||
@@ -187,7 +188,7 @@ impl FromStr for AppServerTransport {
|
||||
pub(crate) enum TransportEvent {
|
||||
ConnectionOpened {
|
||||
connection_id: ConnectionId,
|
||||
writer: mpsc::Sender<OutgoingMessage>,
|
||||
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
},
|
||||
ConnectionClosed {
|
||||
@@ -225,13 +226,13 @@ pub(crate) struct OutboundConnectionState {
|
||||
pub(crate) initialized: Arc<AtomicBool>,
|
||||
pub(crate) experimental_api_enabled: Arc<AtomicBool>,
|
||||
pub(crate) opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
pub(crate) writer: mpsc::Sender<OutgoingMessage>,
|
||||
pub(crate) writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
}
|
||||
|
||||
impl OutboundConnectionState {
|
||||
pub(crate) fn new(
|
||||
writer: mpsc::Sender<OutgoingMessage>,
|
||||
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
initialized: Arc<AtomicBool>,
|
||||
experimental_api_enabled: Arc<AtomicBool>,
|
||||
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
@@ -262,7 +263,7 @@ pub(crate) async fn start_stdio_connection(
|
||||
stdio_handles: &mut Vec<JoinHandle<()>>,
|
||||
) -> IoResult<()> {
|
||||
let connection_id = ConnectionId(0);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
@@ -309,8 +310,8 @@ pub(crate) async fn start_stdio_connection(
|
||||
|
||||
stdio_handles.push(tokio::spawn(async move {
|
||||
let mut stdout = io::stdout();
|
||||
while let Some(outgoing_message) = writer_rx.recv().await {
|
||||
let Some(mut json) = serialize_outgoing_message(outgoing_message) else {
|
||||
while let Some(queued_message) = writer_rx.recv().await {
|
||||
let Some(mut json) = serialize_outgoing_message(queued_message.message) else {
|
||||
continue;
|
||||
};
|
||||
json.push('\n');
|
||||
@@ -318,6 +319,9 @@ pub(crate) async fn start_stdio_connection(
|
||||
error!("Failed to write to stdout: {err}");
|
||||
break;
|
||||
}
|
||||
if let Some(write_complete_tx) = queued_message.write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
}
|
||||
info!("stdout writer exited (channel closed)");
|
||||
}));
|
||||
@@ -364,7 +368,7 @@ async fn run_websocket_connection(
|
||||
websocket_stream: WebSocket,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
) {
|
||||
let (writer_tx, writer_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let (writer_tx, writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
let disconnect_token = CancellationToken::new();
|
||||
if transport_event_tx
|
||||
@@ -415,7 +419,7 @@ async fn run_websocket_connection(
|
||||
|
||||
async fn run_websocket_outbound_loop(
|
||||
mut websocket_writer: futures::stream::SplitSink<WebSocket, WebSocketMessage>,
|
||||
mut writer_rx: mpsc::Receiver<OutgoingMessage>,
|
||||
mut writer_rx: mpsc::Receiver<QueuedOutgoingMessage>,
|
||||
mut writer_control_rx: mpsc::Receiver<WebSocketMessage>,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
@@ -432,16 +436,19 @@ async fn run_websocket_outbound_loop(
|
||||
break;
|
||||
}
|
||||
}
|
||||
outgoing_message = writer_rx.recv() => {
|
||||
let Some(outgoing_message) = outgoing_message else {
|
||||
queued_message = writer_rx.recv() => {
|
||||
let Some(queued_message) = queued_message else {
|
||||
break;
|
||||
};
|
||||
let Some(json) = serialize_outgoing_message(outgoing_message) else {
|
||||
let Some(json) = serialize_outgoing_message(queued_message.message) else {
|
||||
continue;
|
||||
};
|
||||
if websocket_writer.send(WebSocketMessage::Text(json.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
if let Some(write_complete_tx) = queued_message.write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -450,7 +457,7 @@ async fn run_websocket_outbound_loop(
|
||||
async fn run_websocket_inbound_loop(
|
||||
mut websocket_reader: futures::stream::SplitStream<WebSocket>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
writer_tx_for_reader: mpsc::Sender<OutgoingMessage>,
|
||||
writer_tx_for_reader: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
writer_control_tx: mpsc::Sender<WebSocketMessage>,
|
||||
connection_id: ConnectionId,
|
||||
disconnect_token: CancellationToken,
|
||||
@@ -501,7 +508,7 @@ async fn run_websocket_inbound_loop(
|
||||
|
||||
async fn forward_incoming_message(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
writer: &mpsc::Sender<OutgoingMessage>,
|
||||
writer: &mpsc::Sender<QueuedOutgoingMessage>,
|
||||
connection_id: ConnectionId,
|
||||
payload: &str,
|
||||
) -> bool {
|
||||
@@ -518,7 +525,7 @@ async fn forward_incoming_message(
|
||||
|
||||
async fn enqueue_incoming_message(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
writer: &mpsc::Sender<OutgoingMessage>,
|
||||
writer: &mpsc::Sender<QueuedOutgoingMessage>,
|
||||
connection_id: ConnectionId,
|
||||
message: JSONRPCMessage,
|
||||
) -> bool {
|
||||
@@ -541,7 +548,7 @@ async fn enqueue_incoming_message(
|
||||
data: None,
|
||||
},
|
||||
});
|
||||
match writer.try_send(overload_error) {
|
||||
match writer.try_send(QueuedOutgoingMessage::new(overload_error)) {
|
||||
Ok(()) => true,
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => false,
|
||||
Err(mpsc::error::TrySendError::Full(_overload_error)) => {
|
||||
@@ -607,6 +614,7 @@ async fn send_message_to_connection(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
connection_id: ConnectionId,
|
||||
message: OutgoingMessage,
|
||||
write_complete_tx: Option<tokio::sync::oneshot::Sender<()>>,
|
||||
) -> bool {
|
||||
let Some(connection_state) = connections.get(&connection_id) else {
|
||||
warn!("dropping message for disconnected connection: {connection_id:?}");
|
||||
@@ -618,8 +626,12 @@ async fn send_message_to_connection(
|
||||
}
|
||||
|
||||
let writer = connection_state.writer.clone();
|
||||
let queued_message = QueuedOutgoingMessage {
|
||||
message,
|
||||
write_complete_tx,
|
||||
};
|
||||
if connection_state.can_disconnect() {
|
||||
match writer.try_send(message) {
|
||||
match writer.try_send(queued_message) {
|
||||
Ok(()) => false,
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
warn!(
|
||||
@@ -631,7 +643,7 @@ async fn send_message_to_connection(
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
}
|
||||
} else if writer.send(message).await.is_err() {
|
||||
} else if writer.send(queued_message).await.is_err() {
|
||||
disconnect_connection(connections, connection_id)
|
||||
} else {
|
||||
false
|
||||
@@ -670,8 +682,11 @@ pub(crate) async fn route_outgoing_envelope(
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
write_complete_tx,
|
||||
} => {
|
||||
let _ = send_message_to_connection(connections, connection_id, message).await;
|
||||
let _ =
|
||||
send_message_to_connection(connections, connection_id, message, write_complete_tx)
|
||||
.await;
|
||||
}
|
||||
OutgoingEnvelope::Broadcast { message } => {
|
||||
let target_connections: Vec<ConnectionId> = connections
|
||||
@@ -688,8 +703,13 @@ pub(crate) async fn route_outgoing_envelope(
|
||||
.collect();
|
||||
|
||||
for connection_id in target_connections {
|
||||
let _ =
|
||||
send_message_to_connection(connections, connection_id, message.clone()).await;
|
||||
let _ = send_message_to_connection(
|
||||
connections,
|
||||
connection_id,
|
||||
message.clone(),
|
||||
/*write_complete_tx*/ None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -800,7 +820,8 @@ mod tests {
|
||||
.recv()
|
||||
.await
|
||||
.expect("request should receive overload error");
|
||||
let overload_json = serde_json::to_value(overload).expect("serialize overload error");
|
||||
let overload_json =
|
||||
serde_json::to_value(overload.message).expect("serialize overload error");
|
||||
assert_eq!(
|
||||
overload_json,
|
||||
json!({
|
||||
@@ -904,13 +925,15 @@ mod tests {
|
||||
.expect("transport queue should accept first message");
|
||||
|
||||
writer_tx
|
||||
.send(OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::ConfigWarning(ConfigWarningNotification {
|
||||
summary: "queued".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
}),
|
||||
.send(QueuedOutgoingMessage::new(
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "queued".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.expect("writer queue should accept first message");
|
||||
@@ -934,7 +957,8 @@ mod tests {
|
||||
.recv()
|
||||
.await
|
||||
.expect("writer queue should still contain original message");
|
||||
let queued_json = serde_json::to_value(queued_outgoing).expect("serialize queued message");
|
||||
let queued_json =
|
||||
serde_json::to_value(queued_outgoing.message).expect("serialize queued message");
|
||||
assert_eq!(
|
||||
queued_json,
|
||||
json!({
|
||||
@@ -979,6 +1003,7 @@ mod tests {
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -989,6 +1014,92 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn to_connection_notifications_are_dropped_for_opted_out_clients() {
|
||||
let connection_id = ConnectionId(10);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()]))),
|
||||
None,
|
||||
),
|
||||
);
|
||||
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "task_started".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
writer_rx.try_recv().is_err(),
|
||||
"opted-out notifications should not reach clients"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn to_connection_notifications_are_preserved_for_non_opted_out_clients() {
|
||||
let connection_id = ConnectionId(11);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
None,
|
||||
),
|
||||
);
|
||||
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "task_started".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let message = writer_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("notification should reach non-opted-out clients");
|
||||
assert!(matches!(
|
||||
message.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "task_started"
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn command_execution_request_approval_strips_experimental_fields_without_capability() {
|
||||
let connection_id = ConnectionId(8);
|
||||
@@ -1042,6 +1153,7 @@ mod tests {
|
||||
available_decisions: None,
|
||||
},
|
||||
}),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -1050,7 +1162,7 @@ mod tests {
|
||||
.recv()
|
||||
.await
|
||||
.expect("request should be delivered to the connection");
|
||||
let json = serde_json::to_value(message).expect("request should serialize");
|
||||
let json = serde_json::to_value(message.message).expect("request should serialize");
|
||||
assert_eq!(json["params"].get("additionalPermissions"), None);
|
||||
assert_eq!(json["params"].get("skillMetadata"), None);
|
||||
}
|
||||
@@ -1108,6 +1220,7 @@ mod tests {
|
||||
available_decisions: None,
|
||||
},
|
||||
}),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -1116,7 +1229,7 @@ mod tests {
|
||||
.recv()
|
||||
.await
|
||||
.expect("request should be delivered to the connection");
|
||||
let json = serde_json::to_value(message).expect("request should serialize");
|
||||
let json = serde_json::to_value(message.message).expect("request should serialize");
|
||||
let allowed_path = absolute_path("/tmp/allowed").to_string_lossy().into_owned();
|
||||
assert_eq!(
|
||||
json["params"]["additionalPermissions"],
|
||||
@@ -1178,7 +1291,7 @@ mod tests {
|
||||
}),
|
||||
);
|
||||
slow_writer_tx
|
||||
.try_send(queued_message)
|
||||
.try_send(QueuedOutgoingMessage::new(queued_message))
|
||||
.expect("channel should have room");
|
||||
|
||||
let broadcast_message = OutgoingMessage::AppServerNotification(
|
||||
@@ -1207,7 +1320,7 @@ mod tests {
|
||||
.try_recv()
|
||||
.expect("fast connection should receive the broadcast notification");
|
||||
assert!(matches!(
|
||||
fast_message,
|
||||
fast_message.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "test"
|
||||
@@ -1217,7 +1330,7 @@ mod tests {
|
||||
.try_recv()
|
||||
.expect("slow connection should retain its original buffered message");
|
||||
assert!(matches!(
|
||||
slow_message,
|
||||
slow_message.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "already-buffered"
|
||||
@@ -1229,13 +1342,15 @@ mod tests {
|
||||
let connection_id = ConnectionId(3);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
writer_tx
|
||||
.send(OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::ConfigWarning(ConfigWarningNotification {
|
||||
summary: "queued".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
}),
|
||||
.send(QueuedOutgoingMessage::new(
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "queued".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.expect("channel should accept the first queued message");
|
||||
@@ -1265,6 +1380,7 @@ mod tests {
|
||||
range: None,
|
||||
}),
|
||||
),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
@@ -1280,7 +1396,7 @@ mod tests {
|
||||
.expect("routing task should succeed");
|
||||
|
||||
assert!(matches!(
|
||||
first,
|
||||
first.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "queued"
|
||||
@@ -1289,7 +1405,7 @@ mod tests {
|
||||
.try_recv()
|
||||
.expect("second notification should be delivered once the queue has room");
|
||||
assert!(matches!(
|
||||
second,
|
||||
second.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "second"
|
||||
|
||||
Reference in New Issue
Block a user