app-server: switch remote control to protocol v3 segmentation (#20341)

## Why

Remote-control protocol v3 makes segmentation an explicit wire-level
feature. The app-server transport needs to support that protocol
directly so large messages can be chunked, acknowledged, replayed, and
reassembled consistently.

## What changed

- Bump the remote-control websocket protocol version from `2` to `3`.
- Add explicit client/server chunk envelope variants plus chunk-aware
acknowledgements.
- Split oversized outbound server messages into bounded transport
chunks.
- Reassemble ordered inbound client chunks with bounded memory usage and
stream/client invalidation handling.
- Track inbound chunk cursors and outbound ack cursors as `(seq_id,
segment_id)` so duplicate chunks and partial replays behave correctly.
- Add focused coverage for chunk splitting, reassembly, duplicate
suppression, and stream replacement behavior.

## Validation

- Added targeted unit coverage for segmented message handling in
`remote_control`.
- Local validation is currently blocked before compilation because
`packageproxy` does not serve the locked `rustls-webpki 0.103.13`
dependency required by the workspace.
This commit is contained in:
Ruslan Nigmatullin
2026-04-30 18:27:16 -07:00
committed by GitHub
parent af089fb21d
commit 972b819213
7 changed files with 1506 additions and 31 deletions

View File

@@ -195,7 +195,7 @@ impl ClientTracker {
})
.await
}
ClientEvent::Ack => Ok(()),
ClientEvent::ClientMessageChunk { .. } | ClientEvent::Ack { .. } => Ok(()),
ClientEvent::Ping => {
if let Some(client) = self.clients.get_mut(&client_key) {
client.last_activity_at = Instant::now();

View File

@@ -1,6 +1,7 @@
mod client_tracker;
mod enroll;
mod protocol;
mod segment;
mod websocket;
use crate::transport::remote_control::websocket::RemoteControlChannels;
@@ -121,5 +122,7 @@ pub(crate) async fn start_remote_control(
))
}
#[cfg(test)]
mod segment_tests;
#[cfg(test)]
mod tests;

View File

@@ -47,10 +47,20 @@ pub enum ClientEvent {
ClientMessage {
message: JSONRPCMessage,
},
ClientMessageChunk {
segment_id: usize,
segment_count: usize,
message_size_bytes: usize,
message_chunk_base64: String,
},
/// Backend-generated acknowledgement for all server envelopes addressed to
/// `client_id` and `stream_id` whose envelope `seq_id` is less than or equal
/// to this ack's `seq_id`. This cursor is stream-scoped.
Ack,
/// to this ack's `seq_id`. Chunk acknowledgements carry `segment_id` so the
/// sender can retain only the still-unacked wire chunks on reconnect.
Ack {
#[serde(skip_serializing_if = "Option::is_none")]
segment_id: Option<usize>,
},
Ping,
ClientClosed,
}
@@ -85,6 +95,12 @@ pub enum ServerEvent {
ServerMessage {
message: Box<OutgoingMessage>,
},
ServerMessageChunk {
segment_id: usize,
segment_count: usize,
message_size_bytes: usize,
message_chunk_base64: String,
},
#[allow(dead_code)]
Ack,
Pong {
@@ -92,6 +108,15 @@ pub enum ServerEvent {
},
}
impl ServerEvent {
pub(crate) fn segment_id(&self) -> Option<usize> {
match self {
Self::ServerMessageChunk { segment_id, .. } => Some(*segment_id),
Self::ServerMessage { .. } | Self::Ack | Self::Pong { .. } => None,
}
}
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub(crate) struct ServerEnvelope {

View File

@@ -0,0 +1,449 @@
use super::protocol::ClientEnvelope;
use super::protocol::ClientEvent;
use super::protocol::ClientId;
use super::protocol::ServerEnvelope;
use super::protocol::ServerEvent;
use super::protocol::StreamId;
use base64::DecodeSliceError;
use base64::Engine;
use codex_app_server_protocol::JSONRPCMessage;
use std::collections::HashMap;
use std::io;
use std::io::ErrorKind;
use std::io::Write;
use tokio::time::Instant;
use tracing::warn;
pub(super) const REMOTE_CONTROL_SEGMENT_TARGET_BYTES: usize = 100 * 1024;
pub(super) const REMOTE_CONTROL_SEGMENT_MAX_BYTES: usize = 150 * 1024;
pub(super) const REMOTE_CONTROL_REASSEMBLED_MAX_BYTES: usize = 100 * 1024 * 1024;
pub(super) const REMOTE_CONTROL_SEGMENT_COUNT_MAX: usize = 1024;
const REMOTE_CONTROL_SEGMENT_ASSEMBLY_MAX_COUNT: usize = 128;
#[derive(Debug)]
struct ClientSegmentAssembly {
stream_id: StreamId,
metadata: ClientSegmentMetadata,
raw: Vec<u8>,
next_segment_id: usize,
last_chunk_seen_at: Instant,
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct ClientSegmentMetadata {
seq_id: u64,
segment_count: usize,
message_size_bytes: usize,
}
#[derive(Default)]
pub(super) struct ClientSegmentReassembler {
assemblies: HashMap<ClientId, ClientSegmentAssembly>,
}
pub(super) enum ClientSegmentObservation {
Forward(Box<ClientEnvelope>),
Pending,
Dropped,
}
impl ClientSegmentReassembler {
pub(super) fn observe(&mut self, envelope: ClientEnvelope) -> ClientSegmentObservation {
let ClientEvent::ClientMessageChunk {
segment_id,
segment_count,
message_size_bytes,
message_chunk_base64,
} = &envelope.event
else {
return ClientSegmentObservation::Forward(Box::new(envelope));
};
let segment_id = *segment_id;
let segment_count = *segment_count;
let message_size_bytes = *message_size_bytes;
let Some(metadata) = ClientSegmentMetadata::from_envelope(&envelope) else {
warn!(
client_id = envelope.client_id.0.as_str(),
"dropping segmented remote-control client envelope without seq_id"
);
return ClientSegmentObservation::Dropped;
};
let Some(stream_id) = envelope.stream_id.clone() else {
warn!(
client_id = envelope.client_id.0.as_str(),
"dropping segmented remote-control client envelope without stream_id"
);
return ClientSegmentObservation::Dropped;
};
if self.should_ignore_chunk(&envelope.client_id, &stream_id, metadata.seq_id, segment_id) {
return ClientSegmentObservation::Dropped;
}
if segment_count == 0
|| segment_count > REMOTE_CONTROL_SEGMENT_COUNT_MAX
|| segment_id >= segment_count
|| message_size_bytes == 0
|| message_size_bytes > REMOTE_CONTROL_REASSEMBLED_MAX_BYTES
|| message_chunk_base64.is_empty()
{
warn!(
client_id = envelope.client_id.0.as_str(),
"dropping invalid segmented remote-control client envelope"
);
self.remove_assembly(&envelope.client_id, &stream_id);
return ClientSegmentObservation::Dropped;
}
let now = Instant::now();
match self.assemblies.get(&envelope.client_id) {
Some(assembly) if assembly.stream_id != stream_id => {
warn!(
client_id = envelope.client_id.0.as_str(),
"resetting segmented remote-control client envelope after stream change"
);
self.assemblies.insert(
envelope.client_id.clone(),
ClientSegmentAssembly {
stream_id: stream_id.clone(),
metadata: metadata.clone(),
raw: Vec::new(),
next_segment_id: 0,
last_chunk_seen_at: now,
},
);
}
Some(_) => {}
None => {
self.evict_assemblies_if_full();
self.assemblies.insert(
envelope.client_id.clone(),
ClientSegmentAssembly {
stream_id: stream_id.clone(),
metadata: metadata.clone(),
raw: Vec::new(),
next_segment_id: 0,
last_chunk_seen_at: now,
},
);
}
}
let result = {
let Some(assembly) = self.assemblies.get_mut(&envelope.client_id) else {
warn!(
client_id = envelope.client_id.0.as_str(),
"dropping segmented remote-control client envelope without assembly"
);
return ClientSegmentObservation::Dropped;
};
if metadata.seq_id < assembly.metadata.seq_id {
AssemblyUpdate::Ignore
} else if assembly.metadata != metadata {
warn!(
client_id = envelope.client_id.0.as_str(),
"resetting segmented remote-control client envelope after metadata mismatch"
);
AssemblyUpdate::Drop
} else if segment_id < assembly.next_segment_id {
AssemblyUpdate::Pending
} else if segment_id != assembly.next_segment_id {
warn!(
client_id = envelope.client_id.0.as_str(),
"dropping out-of-order segmented remote-control client envelope"
);
AssemblyUpdate::Drop
} else {
assembly.last_chunk_seen_at = now;
let chunk_start = assembly.raw.len();
let decoded_chunk_len = base64::decoded_len_estimate(message_chunk_base64.len());
let chunk_end = usize::min(
message_size_bytes,
chunk_start.saturating_add(decoded_chunk_len),
);
assembly.raw.resize(chunk_end, 0);
match base64::engine::general_purpose::STANDARD.decode_slice(
message_chunk_base64.as_bytes(),
&mut assembly.raw[chunk_start..],
) {
Ok(decoded_chunk_len) => {
assembly.raw.truncate(chunk_start + decoded_chunk_len);
assembly.next_segment_id += 1;
if assembly.next_segment_id < segment_count {
AssemblyUpdate::Pending
} else if assembly.raw.len() != message_size_bytes {
warn!(
client_id = envelope.client_id.0.as_str(),
"dropping reassembled remote-control client envelope with mismatched size"
);
AssemblyUpdate::Drop
} else {
match serde_json::from_slice::<JSONRPCMessage>(&assembly.raw) {
Ok(message) => AssemblyUpdate::Complete(message),
Err(err) => {
warn!(
client_id = envelope.client_id.0.as_str(),
"dropping invalid reassembled remote-control client envelope: {err}"
);
AssemblyUpdate::Drop
}
}
}
}
Err(DecodeSliceError::OutputSliceTooSmall) => {
warn!(
client_id = envelope.client_id.0.as_str(),
"dropping segmented remote-control client envelope after size overflow"
);
AssemblyUpdate::Drop
}
Err(err) => {
warn!(
client_id = envelope.client_id.0.as_str(),
"dropping segmented remote-control client envelope with invalid base64: {err}"
);
AssemblyUpdate::Drop
}
}
}
};
match result {
AssemblyUpdate::Pending => ClientSegmentObservation::Pending,
AssemblyUpdate::Ignore => ClientSegmentObservation::Dropped,
AssemblyUpdate::Drop => {
self.remove_assembly(&envelope.client_id, &stream_id);
ClientSegmentObservation::Dropped
}
AssemblyUpdate::Complete(message) => {
self.remove_assembly(&envelope.client_id, &stream_id);
ClientSegmentObservation::Forward(Box::new(ClientEnvelope {
event: ClientEvent::ClientMessage { message },
..envelope
}))
}
}
}
pub(super) fn invalidate_stream(&mut self, client_id: &ClientId, stream_id: &StreamId) {
self.remove_assembly(client_id, stream_id);
}
pub(super) fn invalidate_client(&mut self, client_id: &ClientId) {
self.assemblies.remove(client_id);
}
pub(super) fn should_ignore_chunk(
&self,
client_id: &ClientId,
stream_id: &StreamId,
seq_id: u64,
segment_id: usize,
) -> bool {
self.assemblies.get(client_id).is_some_and(|assembly| {
assembly.stream_id == *stream_id
&& (seq_id < assembly.metadata.seq_id
|| (seq_id == assembly.metadata.seq_id
&& segment_id < assembly.next_segment_id))
})
}
fn remove_assembly(&mut self, client_id: &ClientId, stream_id: &StreamId) {
if self
.assemblies
.get(client_id)
.is_some_and(|assembly| &assembly.stream_id == stream_id)
{
self.assemblies.remove(client_id);
}
}
fn evict_assemblies_if_full(&mut self) {
while self.assemblies.len() >= REMOTE_CONTROL_SEGMENT_ASSEMBLY_MAX_COUNT {
let Some(client_id) = self
.assemblies
.iter()
.min_by_key(|(_, assembly)| assembly.last_chunk_seen_at)
.map(|(client_id, _)| client_id.clone())
else {
return;
};
self.assemblies.remove(&client_id);
}
}
}
enum AssemblyUpdate {
Pending,
Ignore,
Drop,
Complete(JSONRPCMessage),
}
impl ClientSegmentMetadata {
fn from_envelope(envelope: &ClientEnvelope) -> Option<Self> {
let ClientEvent::ClientMessageChunk {
segment_count,
message_size_bytes,
..
} = &envelope.event
else {
return None;
};
Some(Self {
seq_id: envelope.seq_id?,
segment_count: *segment_count,
message_size_bytes: *message_size_bytes,
})
}
}
pub(super) fn split_server_envelope_for_transport(
envelope: ServerEnvelope,
) -> io::Result<Vec<ServerEnvelope>> {
if !matches!(envelope.event, ServerEvent::ServerMessage { .. }) {
return Ok(vec![envelope]);
}
let envelope_size_bytes = serialized_len(&envelope)?;
if envelope_size_bytes <= REMOTE_CONTROL_SEGMENT_MAX_BYTES {
return Ok(vec![envelope]);
}
let ServerEvent::ServerMessage { message } = envelope.event.clone() else {
unreachable!("server message variant checked above");
};
let raw = serde_json::to_vec(message.as_ref()).map_err(io::Error::other)?;
let message_size_bytes = raw.len();
if message_size_bytes > REMOTE_CONTROL_REASSEMBLED_MAX_BYTES {
warn!("dropping remote-control server envelope that exceeds reassembled size limit");
return Ok(Vec::new());
}
let minimal_segment_count =
usize::min(message_size_bytes.max(1), REMOTE_CONTROL_SEGMENT_COUNT_MAX);
let minimal_chunk = &raw[..usize::min(raw.len(), 1)];
if serialized_chunk_len(
&envelope,
/*segment_id*/ 0,
minimal_segment_count,
message_size_bytes,
minimal_chunk,
)? > REMOTE_CONTROL_SEGMENT_MAX_BYTES
{
warn!("dropping remote-control server envelope that cannot fit within segment size limit");
return Ok(Vec::new());
}
let mut segment_count = usize::max(
2,
message_size_bytes.div_ceil(REMOTE_CONTROL_SEGMENT_TARGET_BYTES),
);
loop {
let chunk_size = usize::max(1, message_size_bytes.div_ceil(segment_count));
segment_count = message_size_bytes.div_ceil(chunk_size);
let segments_fit = raw
.chunks(chunk_size)
.enumerate()
.all(|(segment_id, chunk)| {
serialized_chunk_len(
&envelope,
segment_id,
segment_count,
message_size_bytes,
chunk,
)
.is_ok_and(|size| size <= REMOTE_CONTROL_SEGMENT_MAX_BYTES)
});
if segments_fit {
return raw
.chunks(chunk_size)
.enumerate()
.map(|(segment_id, chunk)| {
build_chunk_envelope(
&envelope,
segment_id,
segment_count,
message_size_bytes,
chunk,
)
})
.collect();
}
if chunk_size == 1 {
warn!(
"dropping remote-control server envelope that cannot fit within segment size limit"
);
return Ok(Vec::new());
}
let next_segment_count = segment_count + 1;
let next_chunk_size = usize::max(1, message_size_bytes.div_ceil(next_segment_count));
segment_count = if next_chunk_size == chunk_size {
message_size_bytes
} else {
next_segment_count
};
}
}
fn serialized_chunk_len(
envelope: &ServerEnvelope,
segment_id: usize,
segment_count: usize,
message_size_bytes: usize,
chunk: &[u8],
) -> io::Result<usize> {
serialized_len(&build_chunk_envelope(
envelope,
segment_id,
segment_count,
message_size_bytes,
chunk,
)?)
}
#[derive(Default)]
struct CountingWriter {
len: usize,
}
impl Write for CountingWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.len += buf.len();
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
fn serialized_len(value: &impl serde::Serialize) -> io::Result<usize> {
let mut writer = CountingWriter::default();
serde_json::to_writer(&mut writer, value).map_err(io::Error::other)?;
Ok(writer.len)
}
fn build_chunk_envelope(
envelope: &ServerEnvelope,
segment_id: usize,
segment_count: usize,
message_size_bytes: usize,
chunk: &[u8],
) -> io::Result<ServerEnvelope> {
if segment_count > REMOTE_CONTROL_SEGMENT_COUNT_MAX {
return Err(io::Error::new(
ErrorKind::InvalidData,
"remote-control segment count exceeds maximum",
));
}
Ok(ServerEnvelope {
event: ServerEvent::ServerMessageChunk {
segment_id,
segment_count,
message_size_bytes,
message_chunk_base64: base64::engine::general_purpose::STANDARD.encode(chunk),
},
client_id: envelope.client_id.clone(),
stream_id: envelope.stream_id.clone(),
seq_id: envelope.seq_id,
})
}

View File

@@ -0,0 +1,386 @@
use super::protocol::ClientEnvelope;
use super::protocol::ClientEvent;
use super::protocol::ClientId;
use super::protocol::ServerEnvelope;
use super::protocol::ServerEvent;
use super::protocol::StreamId;
use super::segment::ClientSegmentObservation;
use super::segment::ClientSegmentReassembler;
use super::segment::REMOTE_CONTROL_SEGMENT_MAX_BYTES;
use super::segment::split_server_envelope_for_transport;
use crate::outgoing_message::OutgoingMessage;
use base64::Engine;
use codex_app_server_protocol::ConfigWarningNotification;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::ServerNotification;
use pretty_assertions::assert_eq;
#[test]
fn reassembles_client_message_chunks() {
let message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let raw = serde_json::to_vec(&message).expect("message should serialize");
let split = raw.len() / 2;
let client_id = ClientId("client-1".to_string());
let stream_id = Some(StreamId("stream-1".to_string()));
let mut reassembler = ClientSegmentReassembler::default();
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id.clone(),
stream_id.clone(),
/*seq_id*/ 7,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
)),
ClientSegmentObservation::Pending
));
let reassembled = match reassembler.observe(chunk_envelope(
client_id.clone(),
stream_id,
/*seq_id*/ 7,
/*segment_id*/ 1,
/*segment_count*/ 2,
raw.len(),
&raw[split..],
)) {
ClientSegmentObservation::Forward(reassembled) => *reassembled,
ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => {
panic!("message should reassemble")
}
};
assert_eq!(reassembled.client_id, client_id);
assert_eq!(
reassembled.stream_id,
Some(StreamId("stream-1".to_string()))
);
assert_eq!(reassembled.seq_id, Some(7));
assert_eq!(reassembled.cursor, None);
match reassembled.event {
ClientEvent::ClientMessage {
message: reassembled_message,
} => assert_eq!(reassembled_message, message),
other => panic!("expected client message, got {other:?}"),
}
}
#[test]
fn splits_large_server_messages_into_wire_chunks() {
let envelope = ServerEnvelope {
event: ServerEvent::ServerMessage {
message: Box::new(OutgoingMessage::AppServerNotification(
ServerNotification::ConfigWarning(ConfigWarningNotification {
summary: "x".repeat(REMOTE_CONTROL_SEGMENT_MAX_BYTES),
details: None,
path: None,
range: None,
}),
)),
},
client_id: ClientId("client-1".to_string()),
stream_id: StreamId("stream-1".to_string()),
seq_id: 9,
};
let segments = split_server_envelope_for_transport(envelope).expect("split should succeed");
assert!(segments.len() > 1);
assert!(
segments
.iter()
.all(|segment| matches!(segment.event, ServerEvent::ServerMessageChunk { .. }))
);
assert!(segments.iter().all(|segment| segment.seq_id == 9));
assert!(segments.iter().all(|segment| {
serde_json::to_vec(segment)
.expect("segment should serialize")
.len()
<= REMOTE_CONTROL_SEGMENT_MAX_BYTES
}));
}
#[test]
fn invalidates_incomplete_stream_assemblies() {
let message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let raw = serde_json::to_vec(&message).expect("message should serialize");
let split = raw.len() / 2;
let client_id = ClientId("client-1".to_string());
let stream_id = StreamId("stream-1".to_string());
let mut reassembler = ClientSegmentReassembler::default();
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id.clone(),
Some(stream_id.clone()),
/*seq_id*/ 7,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
)),
ClientSegmentObservation::Pending
));
reassembler.invalidate_stream(&client_id, &stream_id);
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id,
Some(stream_id),
/*seq_id*/ 7,
/*segment_id*/ 1,
/*segment_count*/ 2,
raw.len(),
&raw[split..],
)),
ClientSegmentObservation::Dropped
));
}
#[test]
fn resets_incomplete_client_assembly_when_stream_changes() {
let message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let raw = serde_json::to_vec(&message).expect("message should serialize");
let split = raw.len() / 2;
let client_id = ClientId("client-1".to_string());
let first_stream_id = StreamId("stream-1".to_string());
let second_stream_id = StreamId("stream-2".to_string());
let mut reassembler = ClientSegmentReassembler::default();
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id.clone(),
Some(first_stream_id.clone()),
/*seq_id*/ 7,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
)),
ClientSegmentObservation::Pending
));
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id.clone(),
Some(second_stream_id.clone()),
/*seq_id*/ 8,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
)),
ClientSegmentObservation::Pending
));
let reassembled = match reassembler.observe(chunk_envelope(
client_id.clone(),
Some(second_stream_id),
/*seq_id*/ 8,
/*segment_id*/ 1,
/*segment_count*/ 2,
raw.len(),
&raw[split..],
)) {
ClientSegmentObservation::Forward(reassembled) => *reassembled,
ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => {
panic!("replacement stream should reassemble")
}
};
assert_eq!(
reassembled.stream_id,
Some(StreamId("stream-2".to_string()))
);
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id,
Some(first_stream_id),
/*seq_id*/ 7,
/*segment_id*/ 1,
/*segment_count*/ 2,
raw.len(),
&raw[split..],
)),
ClientSegmentObservation::Dropped
));
}
#[test]
fn ignores_stale_chunks_without_dropping_newer_assembly() {
let message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let raw = serde_json::to_vec(&message).expect("message should serialize");
let split = raw.len() / 2;
let client_id = ClientId("client-1".to_string());
let stream_id = Some(StreamId("stream-1".to_string()));
let mut reassembler = ClientSegmentReassembler::default();
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id.clone(),
stream_id.clone(),
/*seq_id*/ 8,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
)),
ClientSegmentObservation::Pending
));
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id.clone(),
stream_id.clone(),
/*seq_id*/ 7,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
)),
ClientSegmentObservation::Dropped
));
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id,
stream_id,
/*seq_id*/ 8,
/*segment_id*/ 1,
/*segment_count*/ 2,
raw.len(),
&raw[split..],
)),
ClientSegmentObservation::Forward(_)
));
}
#[test]
fn ignores_invalid_stale_chunks_without_dropping_newer_assembly() {
let message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let raw = serde_json::to_vec(&message).expect("message should serialize");
let split = raw.len() / 2;
let client_id = ClientId("client-1".to_string());
let stream_id = Some(StreamId("stream-1".to_string()));
let mut reassembler = ClientSegmentReassembler::default();
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id.clone(),
stream_id.clone(),
/*seq_id*/ 8,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
)),
ClientSegmentObservation::Pending
));
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id.clone(),
stream_id.clone(),
/*seq_id*/ 7,
/*segment_id*/ 1,
/*segment_count*/ 2,
raw.len(),
b"",
)),
ClientSegmentObservation::Dropped
));
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id,
stream_id,
/*seq_id*/ 8,
/*segment_id*/ 1,
/*segment_count*/ 2,
raw.len(),
&raw[split..],
)),
ClientSegmentObservation::Forward(_)
));
}
#[test]
fn ignores_invalid_duplicate_chunks_without_dropping_current_assembly() {
let message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let raw = serde_json::to_vec(&message).expect("message should serialize");
let split = raw.len() / 2;
let client_id = ClientId("client-1".to_string());
let stream_id = Some(StreamId("stream-1".to_string()));
let mut reassembler = ClientSegmentReassembler::default();
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id.clone(),
stream_id.clone(),
/*seq_id*/ 8,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
)),
ClientSegmentObservation::Pending
));
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id.clone(),
stream_id.clone(),
/*seq_id*/ 8,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
b"",
)),
ClientSegmentObservation::Dropped
));
assert!(matches!(
reassembler.observe(chunk_envelope(
client_id,
stream_id,
/*seq_id*/ 8,
/*segment_id*/ 1,
/*segment_count*/ 2,
raw.len(),
&raw[split..],
)),
ClientSegmentObservation::Forward(_)
));
}
fn chunk_envelope(
client_id: ClientId,
stream_id: Option<StreamId>,
seq_id: u64,
segment_id: usize,
segment_count: usize,
message_size_bytes: usize,
chunk: &[u8],
) -> ClientEnvelope {
ClientEnvelope {
event: ClientEvent::ClientMessageChunk {
segment_id,
segment_count,
message_size_bytes,
message_chunk_base64: base64::engine::general_purpose::STANDARD.encode(chunk),
},
client_id,
stream_id,
seq_id: Some(seq_id),
cursor: None,
}
}

View File

@@ -831,7 +831,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_backend_acks() {
send_client_event(
&mut first_websocket,
ClientEnvelope {
event: ClientEvent::Ack,
event: ClientEvent::Ack { segment_id: None },
client_id: client_id.clone(),
stream_id: Some(stream_id),
seq_id: Some(1),

View File

@@ -15,6 +15,10 @@ use super::protocol::ClientId;
use super::protocol::RemoteControlTarget;
use super::protocol::ServerEnvelope;
use super::protocol::StreamId;
use super::segment::ClientSegmentObservation;
use super::segment::ClientSegmentReassembler;
use super::segment::REMOTE_CONTROL_SEGMENT_MAX_BYTES;
use super::segment::split_server_envelope_for_transport;
use axum::http::HeaderValue;
use base64::Engine;
use codex_app_server_protocol::RemoteControlConnectionStatus;
@@ -49,7 +53,7 @@ use tracing::error;
use tracing::info;
use tracing::warn;
pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2";
pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "3";
pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor";
const REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL: std::time::Duration =
@@ -85,17 +89,29 @@ impl BoundedOutboundBuffer {
self.used_tx.send_modify(|used| *used += 1);
}
fn ack(&mut self, client_id: &ClientId, stream_id: &StreamId, acked_seq_id: u64) {
fn ack(
&mut self,
client_id: &ClientId,
stream_id: &StreamId,
acked_seq_id: u64,
acked_segment_id: Option<usize>,
) {
let key = (client_id.clone(), stream_id.clone());
let Some(buffer) = self.buffer_by_stream.get_mut(&key) else {
return;
};
while let Some(server_envelope) = buffer.front()
&& server_envelope.seq_id <= acked_seq_id
{
buffer.pop_front();
self.used_tx.send_modify(|used| *used -= 1);
}
let acked_cursor = (acked_seq_id, acked_segment_id.unwrap_or(usize::MAX));
buffer.retain(|server_envelope| {
let envelope_cursor = (
server_envelope.seq_id,
server_envelope.event.segment_id().unwrap_or_default(),
);
let is_acked = envelope_cursor <= acked_cursor;
if is_acked {
self.used_tx.send_modify(|used| *used -= 1);
}
!is_acked
});
if buffer.is_empty() {
self.buffer_by_stream.remove(&key);
}
@@ -112,6 +128,88 @@ struct WebsocketState {
outbound_buffer: BoundedOutboundBuffer,
subscribe_cursor: Option<String>,
next_seq_id_by_stream: HashMap<(ClientId, StreamId), u64>,
last_completed_client_chunk_seq_id_by_stream: HashMap<(ClientId, Option<StreamId>), u64>,
client_segment_reassembler: ClientSegmentReassembler,
}
impl WebsocketState {
fn observe_client_message(
&mut self,
client_envelope: ClientEnvelope,
wire_size_bytes: usize,
) -> ClientSegmentObservation {
let client_message_key = Self::client_message_key(&client_envelope);
if let Some((key, seq_id)) = client_message_key.as_ref()
&& self
.last_completed_client_chunk_seq_id_by_stream
.get(key)
.is_some_and(|last_seq_id| last_seq_id >= seq_id)
{
return ClientSegmentObservation::Dropped;
}
if let (
Some((_, seq_id)),
Some(stream_id),
ClientEvent::ClientMessageChunk { segment_id, .. },
) = (
client_message_key.as_ref(),
client_envelope.stream_id.as_ref(),
&client_envelope.event,
) && self.client_segment_reassembler.should_ignore_chunk(
&client_envelope.client_id,
stream_id,
*seq_id,
*segment_id,
) {
return ClientSegmentObservation::Dropped;
}
if client_message_key.is_some() && wire_size_bytes > REMOTE_CONTROL_SEGMENT_MAX_BYTES {
warn!(
client_id = client_envelope.client_id.0.as_str(),
"dropping oversized segmented remote-control client envelope"
);
if let Some(stream_id) = client_envelope.stream_id.as_ref() {
self.client_segment_reassembler
.invalidate_stream(&client_envelope.client_id, stream_id);
}
return ClientSegmentObservation::Dropped;
}
let observation = self.client_segment_reassembler.observe(client_envelope);
if matches!(observation, ClientSegmentObservation::Forward(_))
&& let Some((key, seq_id)) = client_message_key
{
self.last_completed_client_chunk_seq_id_by_stream
.insert(key, seq_id);
}
observation
}
fn invalidate_client_message_stream(&mut self, client_id: &ClientId, stream_id: &StreamId) {
self.last_completed_client_chunk_seq_id_by_stream
.remove(&(client_id.clone(), Some(stream_id.clone())));
}
fn invalidate_client_message_client(&mut self, client_id: &ClientId) {
self.last_completed_client_chunk_seq_id_by_stream
.retain(|(cursor_client_id, _), _| cursor_client_id != client_id);
}
fn client_message_key(
client_envelope: &ClientEnvelope,
) -> Option<((ClientId, Option<StreamId>), u64)> {
let seq_id = match (&client_envelope.event, client_envelope.seq_id) {
(ClientEvent::ClientMessageChunk { .. }, Some(seq_id)) => seq_id,
_ => return None,
};
Some((
(
client_envelope.client_id.clone(),
client_envelope.stream_id.clone(),
),
seq_id,
))
}
}
pub(crate) struct RemoteControlWebsocket {
@@ -231,6 +329,8 @@ impl RemoteControlWebsocket {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
})),
server_event_rx: Arc::new(Mutex::new(server_event_rx)),
used_rx,
@@ -556,7 +656,7 @@ impl RemoteControlWebsocket {
}
}
};
let (payload, write_complete_tx) = {
let (payloads, write_complete_tx) = {
let mut state = state.lock().await;
let seq_key = (
queued_server_envelope.client_id.clone(),
@@ -573,29 +673,42 @@ impl RemoteControlWebsocket {
seq_id,
stream_id: queued_server_envelope.stream_id,
};
let payload = match serde_json::to_string(&server_envelope) {
Ok(payload) => payload,
let server_envelopes = match split_server_envelope_for_transport(server_envelope) {
Ok(server_envelopes) => server_envelopes,
Err(err) => {
error!("failed to serialize remote-control server event: {err}");
error!("failed to split remote-control server event: {err}");
continue;
}
};
let mut payloads = Vec::with_capacity(server_envelopes.len());
for server_envelope in server_envelopes {
let payload = match serde_json::to_string(&server_envelope) {
Ok(payload) => payload,
Err(err) => {
error!("failed to serialize remote-control server event: {err}");
continue;
}
};
state.outbound_buffer.insert(&server_envelope);
payloads.push(payload);
}
state
.next_seq_id_by_stream
.insert(seq_key, seq_id.saturating_add(1));
state.outbound_buffer.insert(&server_envelope);
(payload, queued_server_envelope.write_complete_tx)
(payloads, queued_server_envelope.write_complete_tx)
};
tokio::select! {
_ = shutdown_token.cancelled() => return Ok(()),
send_result = websocket_writer.send(tungstenite::Message::Text(payload.into())) => {
if let Err(err) = send_result {
return Err(io::Error::other(err));
for payload in payloads {
tokio::select! {
_ = shutdown_token.cancelled() => return Ok(()),
send_result = websocket_writer.send(tungstenite::Message::Text(payload.into())) => {
if let Err(err) = send_result {
return Err(io::Error::other(err));
}
}
}
};
}
if let Some(write_complete_tx) = write_complete_tx {
let _ = write_complete_tx.send(());
}
@@ -657,11 +770,30 @@ impl RemoteControlWebsocket {
if client_tracker.close_client(&client_key).await.is_err() {
return Ok(());
}
state
.lock()
.await
.client_segment_reassembler
.invalidate_stream(&client_key.0, &client_key.1);
state
.lock()
.await
.invalidate_client_message_stream(&client_key.0, &client_key.1);
continue;
}
_ = idle_sweep_interval.tick() => {
if client_tracker.close_expired_clients().await.is_err() {
return Ok(());
match client_tracker.close_expired_clients().await {
Ok(client_keys) => {
let mut websocket_state = state.lock().await;
for (client_id, stream_id) in client_keys {
websocket_state
.client_segment_reassembler
.invalidate_stream(&client_id, &stream_id);
websocket_state
.invalidate_client_message_stream(&client_id, &stream_id);
}
}
Err(_) => return Ok(()),
}
continue;
}
@@ -672,10 +804,11 @@ impl RemoteControlWebsocket {
}
}
};
let client_envelope = match incoming_message {
let (client_envelope, wire_size_bytes) = match incoming_message {
Ok(tungstenite::Message::Text(text)) => {
let wire_size_bytes = text.len();
match serde_json::from_str::<ClientEnvelope>(&text) {
Ok(client_envelope) => client_envelope,
Ok(client_envelope) => (client_envelope, wire_size_bytes),
Err(err) => {
warn!("failed to deserialize remote-control client event: {err}");
continue;
@@ -707,12 +840,21 @@ impl RemoteControlWebsocket {
}
};
let observation = {
let mut websocket_state = state.lock().await;
websocket_state.observe_client_message(client_envelope, wire_size_bytes)
};
let client_envelope = match observation {
ClientSegmentObservation::Forward(client_envelope) => *client_envelope,
ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => continue,
};
{
let mut websocket_state = state.lock().await;
if let Some(cursor) = client_envelope.cursor.as_deref() {
websocket_state.subscribe_cursor = Some(cursor.to_string());
}
if let ClientEvent::Ack = &client_envelope.event
if let ClientEvent::Ack { segment_id } = &client_envelope.event
&& let Some(acked_seq_id) = client_envelope.seq_id
&& let Some(stream_id) = client_envelope.stream_id.as_ref()
{
@@ -720,10 +862,18 @@ impl RemoteControlWebsocket {
&client_envelope.client_id,
stream_id,
acked_seq_id,
*segment_id,
);
}
}
let closed_client =
matches!(&client_envelope.event, ClientEvent::ClientClosed).then(|| {
(
client_envelope.client_id.clone(),
client_envelope.stream_id.clone(),
)
});
if client_tracker
.handle_message(client_envelope)
.await
@@ -731,6 +881,20 @@ impl RemoteControlWebsocket {
{
return Ok(());
}
if let Some((client_id, stream_id)) = closed_client {
let mut websocket_state = state.lock().await;
if let Some(stream_id) = stream_id {
websocket_state
.client_segment_reassembler
.invalidate_stream(&client_id, &stream_id);
websocket_state.invalidate_client_message_stream(&client_id, &stream_id);
} else {
websocket_state
.client_segment_reassembler
.invalidate_client(&client_id);
websocket_state.invalidate_client_message_client(&client_id);
}
}
}
}
}
@@ -1052,6 +1216,8 @@ mod tests {
use chrono::Utc;
use codex_app_server_protocol::AuthMode;
use codex_app_server_protocol::ConfigWarningNotification;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::ServerNotification;
use codex_config::types::AuthCredentialsStoreMode;
use codex_core::test_support::auth_manager_from_auth;
@@ -1603,6 +1769,8 @@ mod tests {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
}));
let (_server_event_tx, server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY);
let server_event_rx = Arc::new(Mutex::new(server_event_rx));
@@ -1639,6 +1807,8 @@ mod tests {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
}));
let (server_event_tx, server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY);
let server_event_rx = Arc::new(Mutex::new(server_event_rx));
@@ -1716,6 +1886,8 @@ mod tests {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
}));
let (server_event_tx, _server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY);
let (transport_event_tx, _transport_event_rx) =
@@ -1771,7 +1943,9 @@ mod tests {
"first-client-new-stream",
));
outbound_buffer.ack(&client_1, &stream_1, /*acked_seq_id*/ 3);
outbound_buffer.ack(
&client_1, &stream_1, /*acked_seq_id*/ 3, /*acked_segment_id*/ None,
);
let mut retained = outbound_buffer
.server_envelopes()
@@ -1814,7 +1988,9 @@ mod tests {
&client_2, "stream-1", /*seq_id*/ 3, "second",
));
outbound_buffer.ack(&client_1, &stream_1, /*acked_seq_id*/ 1);
outbound_buffer.ack(
&client_1, &stream_1, /*acked_seq_id*/ 1, /*acked_segment_id*/ None,
);
let mut retained = outbound_buffer
.server_envelopes()
@@ -1834,6 +2010,390 @@ mod tests {
assert_eq!(*used_rx.borrow(), 2);
}
#[test]
fn outbound_buffer_advances_segmented_acks_by_wire_cursor() {
let (mut outbound_buffer, used_rx) = BoundedOutboundBuffer::new();
let client_id = ClientId("client-1".to_string());
let stream_id = StreamId("stream-1".to_string());
outbound_buffer.insert(&server_chunk_envelope(
&client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
));
outbound_buffer.insert(&server_chunk_envelope(
&client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 1,
));
outbound_buffer.ack(
&client_id,
&stream_id,
/*acked_seq_id*/ 4,
/*acked_segment_id*/ Some(1),
);
let retained = outbound_buffer
.server_envelopes()
.map(|server_envelope| server_envelope.event.segment_id())
.collect::<Vec<_>>();
assert_eq!(retained, Vec::<Option<usize>>::new());
assert_eq!(*used_rx.borrow(), 0);
}
#[test]
fn outbound_buffer_treats_segmentless_acks_as_seq_level_acks() {
let (mut outbound_buffer, used_rx) = BoundedOutboundBuffer::new();
let client_id = ClientId("client-1".to_string());
let stream_id = StreamId("stream-1".to_string());
outbound_buffer.insert(&server_chunk_envelope(
&client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
));
outbound_buffer.insert(&server_chunk_envelope(
&client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 1,
));
outbound_buffer.ack(
&client_id, &stream_id, /*acked_seq_id*/ 4, /*acked_segment_id*/ None,
);
let retained = outbound_buffer
.server_envelopes()
.map(|server_envelope| server_envelope.event.segment_id())
.collect::<Vec<_>>();
assert_eq!(retained, Vec::<Option<usize>>::new());
assert_eq!(*used_rx.borrow(), 0);
}
#[test]
fn websocket_state_drops_duplicate_client_chunks_while_pending() {
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
let mut state = WebsocketState {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
};
let first_chunk = client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
);
let second_chunk = client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 1,
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"y",
);
assert!(matches!(
observe_client_message(&mut state, first_chunk.clone()),
ClientSegmentObservation::Pending
));
assert!(matches!(
observe_client_message(&mut state, first_chunk.clone()),
ClientSegmentObservation::Dropped
));
assert!(matches!(
observe_client_message(&mut state, second_chunk),
ClientSegmentObservation::Dropped
));
assert!(matches!(
observe_client_message(&mut state, first_chunk),
ClientSegmentObservation::Pending
));
}
#[test]
fn websocket_state_drops_replayed_client_chunks_after_completion() {
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
let mut state = WebsocketState {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
};
let message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let raw = serde_json::to_vec(&message).expect("message should serialize");
let split = raw.len() / 2;
let first_chunk = client_chunk_envelope(
"client-1",
"stream-1",
/*seq_id*/ 4,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
);
let second_chunk = client_chunk_envelope(
"client-1",
"stream-1",
/*seq_id*/ 4,
/*segment_id*/ 1,
/*segment_count*/ 2,
raw.len(),
&raw[split..],
);
assert!(matches!(
observe_client_message(&mut state, first_chunk.clone()),
ClientSegmentObservation::Pending
));
assert!(matches!(
observe_client_message(&mut state, second_chunk),
ClientSegmentObservation::Forward(_)
));
assert!(matches!(
observe_client_message(&mut state, first_chunk),
ClientSegmentObservation::Dropped
));
}
#[test]
fn websocket_state_allows_replay_after_rejected_out_of_order_chunk() {
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
let mut state = WebsocketState {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
};
let first_chunk = client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
);
let second_chunk = client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 1,
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"y",
);
assert!(matches!(
observe_client_message(&mut state, second_chunk),
ClientSegmentObservation::Dropped
));
assert!(matches!(
observe_client_message(&mut state, first_chunk),
ClientSegmentObservation::Pending
));
}
#[test]
fn websocket_state_allows_replay_after_later_chunk_drops() {
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
let mut state = WebsocketState {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
};
let first_chunk = client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
);
let invalid_second_chunk = client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 1,
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"",
);
assert!(matches!(
observe_client_message(&mut state, first_chunk.clone()),
ClientSegmentObservation::Pending
));
assert!(matches!(
observe_client_message(&mut state, invalid_second_chunk),
ClientSegmentObservation::Dropped
));
assert!(matches!(
observe_client_message(&mut state, first_chunk),
ClientSegmentObservation::Pending
));
}
#[test]
fn websocket_state_drops_oversized_client_chunk_frames() {
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
let mut state = WebsocketState {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
};
let chunk = client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
/*segment_count*/ 1, /*message_size_bytes*/ 1, b"x",
);
assert!(matches!(
state.observe_client_message(chunk, REMOTE_CONTROL_SEGMENT_MAX_BYTES + 1),
ClientSegmentObservation::Dropped
));
}
#[test]
fn websocket_state_ignores_oversized_stale_chunks_without_dropping_newer_assembly() {
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
let mut state = WebsocketState {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
};
let message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let raw = serde_json::to_vec(&message).expect("message should serialize");
let split = raw.len() / 2;
let first_newer_chunk = client_chunk_envelope(
"client-1",
"stream-1",
/*seq_id*/ 8,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
);
let oversized_stale_chunk = client_chunk_envelope(
"client-1",
"stream-1",
/*seq_id*/ 7,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
);
let second_newer_chunk = client_chunk_envelope(
"client-1",
"stream-1",
/*seq_id*/ 8,
/*segment_id*/ 1,
/*segment_count*/ 2,
raw.len(),
&raw[split..],
);
assert!(matches!(
observe_client_message(&mut state, first_newer_chunk),
ClientSegmentObservation::Pending
));
assert!(matches!(
state.observe_client_message(
oversized_stale_chunk,
REMOTE_CONTROL_SEGMENT_MAX_BYTES + 1,
),
ClientSegmentObservation::Dropped
));
assert!(matches!(
observe_client_message(&mut state, second_newer_chunk),
ClientSegmentObservation::Forward(_)
));
}
#[test]
fn websocket_state_ignores_oversized_duplicate_chunks_without_dropping_current_assembly() {
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
let mut state = WebsocketState {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
};
let message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let raw = serde_json::to_vec(&message).expect("message should serialize");
let split = raw.len() / 2;
let first_chunk = client_chunk_envelope(
"client-1",
"stream-1",
/*seq_id*/ 8,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
);
let oversized_duplicate_chunk = client_chunk_envelope(
"client-1",
"stream-1",
/*seq_id*/ 8,
/*segment_id*/ 0,
/*segment_count*/ 2,
raw.len(),
&raw[..split],
);
let second_chunk = client_chunk_envelope(
"client-1",
"stream-1",
/*seq_id*/ 8,
/*segment_id*/ 1,
/*segment_count*/ 2,
raw.len(),
&raw[split..],
);
assert!(matches!(
observe_client_message(&mut state, first_chunk),
ClientSegmentObservation::Pending
));
assert!(matches!(
state.observe_client_message(
oversized_duplicate_chunk,
REMOTE_CONTROL_SEGMENT_MAX_BYTES + 1,
),
ClientSegmentObservation::Dropped
));
assert!(matches!(
observe_client_message(&mut state, second_chunk),
ClientSegmentObservation::Forward(_)
));
}
#[test]
fn websocket_state_clears_chunk_cursor_when_stream_is_invalidated() {
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
let mut state = WebsocketState {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
};
let client_id = ClientId("client-1".to_string());
let stream_id = StreamId("stream-1".to_string());
assert!(matches!(
observe_client_message(
&mut state,
client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
)
),
ClientSegmentObservation::Pending
));
state.invalidate_client_message_stream(&client_id, &stream_id);
state
.client_segment_reassembler
.invalidate_stream(&client_id, &stream_id);
assert!(matches!(
observe_client_message(
&mut state,
client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 1, /*segment_id*/ 0,
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
)
),
ClientSegmentObservation::Pending
));
}
fn server_envelope(
client_id: &ClientId,
stream_id: &str,
@@ -1857,6 +2417,58 @@ mod tests {
}
}
fn server_chunk_envelope(
client_id: &ClientId,
stream_id: &str,
seq_id: u64,
segment_id: usize,
) -> ServerEnvelope {
ServerEnvelope {
event: ServerEvent::ServerMessageChunk {
segment_id,
segment_count: 2,
message_size_bytes: 2,
message_chunk_base64: String::new(),
},
client_id: client_id.clone(),
stream_id: StreamId(stream_id.to_string()),
seq_id,
}
}
fn client_chunk_envelope(
client_id: &str,
stream_id: &str,
seq_id: u64,
segment_id: usize,
segment_count: usize,
message_size_bytes: usize,
chunk: &[u8],
) -> ClientEnvelope {
ClientEnvelope {
event: ClientEvent::ClientMessageChunk {
segment_id,
segment_count,
message_size_bytes,
message_chunk_base64: base64::engine::general_purpose::STANDARD.encode(chunk),
},
client_id: ClientId(client_id.to_string()),
stream_id: Some(StreamId(stream_id.to_string())),
seq_id: Some(seq_id),
cursor: None,
}
}
fn observe_client_message(
state: &mut WebsocketState,
envelope: ClientEnvelope,
) -> ClientSegmentObservation {
let wire_size_bytes = serde_json::to_vec(&envelope)
.expect("client envelope should serialize")
.len();
state.observe_client_message(envelope, wire_size_bytes)
}
async fn accept_http_request(listener: &TcpListener) -> (TcpStream, String) {
let (stream, _) = timeout(TEST_HTTP_ACCEPT_TIMEOUT, listener.accept())
.await