diff --git a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs index cbd74c2fd9..4639942b08 100644 --- a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs +++ b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs @@ -195,7 +195,7 @@ impl ClientTracker { }) .await } - ClientEvent::Ack => Ok(()), + ClientEvent::ClientMessageChunk { .. } | ClientEvent::Ack { .. } => Ok(()), ClientEvent::Ping => { if let Some(client) = self.clients.get_mut(&client_key) { client.last_activity_at = Instant::now(); diff --git a/codex-rs/app-server/src/transport/remote_control/mod.rs b/codex-rs/app-server/src/transport/remote_control/mod.rs index ef517e1ae2..2d0eb7dfb9 100644 --- a/codex-rs/app-server/src/transport/remote_control/mod.rs +++ b/codex-rs/app-server/src/transport/remote_control/mod.rs @@ -1,6 +1,7 @@ mod client_tracker; mod enroll; mod protocol; +mod segment; mod websocket; use crate::transport::remote_control::websocket::RemoteControlChannels; @@ -121,5 +122,7 @@ pub(crate) async fn start_remote_control( )) } +#[cfg(test)] +mod segment_tests; #[cfg(test)] mod tests; diff --git a/codex-rs/app-server/src/transport/remote_control/protocol.rs b/codex-rs/app-server/src/transport/remote_control/protocol.rs index f0db5ecacb..dea5404ab1 100644 --- a/codex-rs/app-server/src/transport/remote_control/protocol.rs +++ b/codex-rs/app-server/src/transport/remote_control/protocol.rs @@ -47,10 +47,20 @@ pub enum ClientEvent { ClientMessage { message: JSONRPCMessage, }, + ClientMessageChunk { + segment_id: usize, + segment_count: usize, + message_size_bytes: usize, + message_chunk_base64: String, + }, /// Backend-generated acknowledgement for all server envelopes addressed to /// `client_id` and `stream_id` whose envelope `seq_id` is less than or equal - /// to this ack's `seq_id`. This cursor is stream-scoped. - Ack, + /// to this ack's `seq_id`. Chunk acknowledgements carry `segment_id` so the + /// sender can retain only the still-unacked wire chunks on reconnect. + Ack { + #[serde(skip_serializing_if = "Option::is_none")] + segment_id: Option, + }, Ping, ClientClosed, } @@ -85,6 +95,12 @@ pub enum ServerEvent { ServerMessage { message: Box, }, + ServerMessageChunk { + segment_id: usize, + segment_count: usize, + message_size_bytes: usize, + message_chunk_base64: String, + }, #[allow(dead_code)] Ack, Pong { @@ -92,6 +108,15 @@ pub enum ServerEvent { }, } +impl ServerEvent { + pub(crate) fn segment_id(&self) -> Option { + match self { + Self::ServerMessageChunk { segment_id, .. } => Some(*segment_id), + Self::ServerMessage { .. } | Self::Ack | Self::Pong { .. } => None, + } + } +} + #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "snake_case")] pub(crate) struct ServerEnvelope { diff --git a/codex-rs/app-server/src/transport/remote_control/segment.rs b/codex-rs/app-server/src/transport/remote_control/segment.rs new file mode 100644 index 0000000000..ab0d23a881 --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/segment.rs @@ -0,0 +1,449 @@ +use super::protocol::ClientEnvelope; +use super::protocol::ClientEvent; +use super::protocol::ClientId; +use super::protocol::ServerEnvelope; +use super::protocol::ServerEvent; +use super::protocol::StreamId; +use base64::DecodeSliceError; +use base64::Engine; +use codex_app_server_protocol::JSONRPCMessage; +use std::collections::HashMap; +use std::io; +use std::io::ErrorKind; +use std::io::Write; +use tokio::time::Instant; +use tracing::warn; + +pub(super) const REMOTE_CONTROL_SEGMENT_TARGET_BYTES: usize = 100 * 1024; +pub(super) const REMOTE_CONTROL_SEGMENT_MAX_BYTES: usize = 150 * 1024; +pub(super) const REMOTE_CONTROL_REASSEMBLED_MAX_BYTES: usize = 100 * 1024 * 1024; +pub(super) const REMOTE_CONTROL_SEGMENT_COUNT_MAX: usize = 1024; +const REMOTE_CONTROL_SEGMENT_ASSEMBLY_MAX_COUNT: usize = 128; + +#[derive(Debug)] +struct ClientSegmentAssembly { + stream_id: StreamId, + metadata: ClientSegmentMetadata, + raw: Vec, + next_segment_id: usize, + last_chunk_seen_at: Instant, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct ClientSegmentMetadata { + seq_id: u64, + segment_count: usize, + message_size_bytes: usize, +} + +#[derive(Default)] +pub(super) struct ClientSegmentReassembler { + assemblies: HashMap, +} + +pub(super) enum ClientSegmentObservation { + Forward(Box), + Pending, + Dropped, +} + +impl ClientSegmentReassembler { + pub(super) fn observe(&mut self, envelope: ClientEnvelope) -> ClientSegmentObservation { + let ClientEvent::ClientMessageChunk { + segment_id, + segment_count, + message_size_bytes, + message_chunk_base64, + } = &envelope.event + else { + return ClientSegmentObservation::Forward(Box::new(envelope)); + }; + let segment_id = *segment_id; + let segment_count = *segment_count; + let message_size_bytes = *message_size_bytes; + + let Some(metadata) = ClientSegmentMetadata::from_envelope(&envelope) else { + warn!( + client_id = envelope.client_id.0.as_str(), + "dropping segmented remote-control client envelope without seq_id" + ); + return ClientSegmentObservation::Dropped; + }; + let Some(stream_id) = envelope.stream_id.clone() else { + warn!( + client_id = envelope.client_id.0.as_str(), + "dropping segmented remote-control client envelope without stream_id" + ); + return ClientSegmentObservation::Dropped; + }; + if self.should_ignore_chunk(&envelope.client_id, &stream_id, metadata.seq_id, segment_id) { + return ClientSegmentObservation::Dropped; + } + if segment_count == 0 + || segment_count > REMOTE_CONTROL_SEGMENT_COUNT_MAX + || segment_id >= segment_count + || message_size_bytes == 0 + || message_size_bytes > REMOTE_CONTROL_REASSEMBLED_MAX_BYTES + || message_chunk_base64.is_empty() + { + warn!( + client_id = envelope.client_id.0.as_str(), + "dropping invalid segmented remote-control client envelope" + ); + self.remove_assembly(&envelope.client_id, &stream_id); + return ClientSegmentObservation::Dropped; + } + + let now = Instant::now(); + match self.assemblies.get(&envelope.client_id) { + Some(assembly) if assembly.stream_id != stream_id => { + warn!( + client_id = envelope.client_id.0.as_str(), + "resetting segmented remote-control client envelope after stream change" + ); + self.assemblies.insert( + envelope.client_id.clone(), + ClientSegmentAssembly { + stream_id: stream_id.clone(), + metadata: metadata.clone(), + raw: Vec::new(), + next_segment_id: 0, + last_chunk_seen_at: now, + }, + ); + } + Some(_) => {} + None => { + self.evict_assemblies_if_full(); + self.assemblies.insert( + envelope.client_id.clone(), + ClientSegmentAssembly { + stream_id: stream_id.clone(), + metadata: metadata.clone(), + raw: Vec::new(), + next_segment_id: 0, + last_chunk_seen_at: now, + }, + ); + } + } + let result = { + let Some(assembly) = self.assemblies.get_mut(&envelope.client_id) else { + warn!( + client_id = envelope.client_id.0.as_str(), + "dropping segmented remote-control client envelope without assembly" + ); + return ClientSegmentObservation::Dropped; + }; + if metadata.seq_id < assembly.metadata.seq_id { + AssemblyUpdate::Ignore + } else if assembly.metadata != metadata { + warn!( + client_id = envelope.client_id.0.as_str(), + "resetting segmented remote-control client envelope after metadata mismatch" + ); + AssemblyUpdate::Drop + } else if segment_id < assembly.next_segment_id { + AssemblyUpdate::Pending + } else if segment_id != assembly.next_segment_id { + warn!( + client_id = envelope.client_id.0.as_str(), + "dropping out-of-order segmented remote-control client envelope" + ); + AssemblyUpdate::Drop + } else { + assembly.last_chunk_seen_at = now; + let chunk_start = assembly.raw.len(); + let decoded_chunk_len = base64::decoded_len_estimate(message_chunk_base64.len()); + let chunk_end = usize::min( + message_size_bytes, + chunk_start.saturating_add(decoded_chunk_len), + ); + assembly.raw.resize(chunk_end, 0); + match base64::engine::general_purpose::STANDARD.decode_slice( + message_chunk_base64.as_bytes(), + &mut assembly.raw[chunk_start..], + ) { + Ok(decoded_chunk_len) => { + assembly.raw.truncate(chunk_start + decoded_chunk_len); + assembly.next_segment_id += 1; + if assembly.next_segment_id < segment_count { + AssemblyUpdate::Pending + } else if assembly.raw.len() != message_size_bytes { + warn!( + client_id = envelope.client_id.0.as_str(), + "dropping reassembled remote-control client envelope with mismatched size" + ); + AssemblyUpdate::Drop + } else { + match serde_json::from_slice::(&assembly.raw) { + Ok(message) => AssemblyUpdate::Complete(message), + Err(err) => { + warn!( + client_id = envelope.client_id.0.as_str(), + "dropping invalid reassembled remote-control client envelope: {err}" + ); + AssemblyUpdate::Drop + } + } + } + } + Err(DecodeSliceError::OutputSliceTooSmall) => { + warn!( + client_id = envelope.client_id.0.as_str(), + "dropping segmented remote-control client envelope after size overflow" + ); + AssemblyUpdate::Drop + } + Err(err) => { + warn!( + client_id = envelope.client_id.0.as_str(), + "dropping segmented remote-control client envelope with invalid base64: {err}" + ); + AssemblyUpdate::Drop + } + } + } + }; + + match result { + AssemblyUpdate::Pending => ClientSegmentObservation::Pending, + AssemblyUpdate::Ignore => ClientSegmentObservation::Dropped, + AssemblyUpdate::Drop => { + self.remove_assembly(&envelope.client_id, &stream_id); + ClientSegmentObservation::Dropped + } + AssemblyUpdate::Complete(message) => { + self.remove_assembly(&envelope.client_id, &stream_id); + ClientSegmentObservation::Forward(Box::new(ClientEnvelope { + event: ClientEvent::ClientMessage { message }, + ..envelope + })) + } + } + } + + pub(super) fn invalidate_stream(&mut self, client_id: &ClientId, stream_id: &StreamId) { + self.remove_assembly(client_id, stream_id); + } + + pub(super) fn invalidate_client(&mut self, client_id: &ClientId) { + self.assemblies.remove(client_id); + } + + pub(super) fn should_ignore_chunk( + &self, + client_id: &ClientId, + stream_id: &StreamId, + seq_id: u64, + segment_id: usize, + ) -> bool { + self.assemblies.get(client_id).is_some_and(|assembly| { + assembly.stream_id == *stream_id + && (seq_id < assembly.metadata.seq_id + || (seq_id == assembly.metadata.seq_id + && segment_id < assembly.next_segment_id)) + }) + } + + fn remove_assembly(&mut self, client_id: &ClientId, stream_id: &StreamId) { + if self + .assemblies + .get(client_id) + .is_some_and(|assembly| &assembly.stream_id == stream_id) + { + self.assemblies.remove(client_id); + } + } + + fn evict_assemblies_if_full(&mut self) { + while self.assemblies.len() >= REMOTE_CONTROL_SEGMENT_ASSEMBLY_MAX_COUNT { + let Some(client_id) = self + .assemblies + .iter() + .min_by_key(|(_, assembly)| assembly.last_chunk_seen_at) + .map(|(client_id, _)| client_id.clone()) + else { + return; + }; + self.assemblies.remove(&client_id); + } + } +} + +enum AssemblyUpdate { + Pending, + Ignore, + Drop, + Complete(JSONRPCMessage), +} + +impl ClientSegmentMetadata { + fn from_envelope(envelope: &ClientEnvelope) -> Option { + let ClientEvent::ClientMessageChunk { + segment_count, + message_size_bytes, + .. + } = &envelope.event + else { + return None; + }; + Some(Self { + seq_id: envelope.seq_id?, + segment_count: *segment_count, + message_size_bytes: *message_size_bytes, + }) + } +} + +pub(super) fn split_server_envelope_for_transport( + envelope: ServerEnvelope, +) -> io::Result> { + if !matches!(envelope.event, ServerEvent::ServerMessage { .. }) { + return Ok(vec![envelope]); + } + + let envelope_size_bytes = serialized_len(&envelope)?; + if envelope_size_bytes <= REMOTE_CONTROL_SEGMENT_MAX_BYTES { + return Ok(vec![envelope]); + } + + let ServerEvent::ServerMessage { message } = envelope.event.clone() else { + unreachable!("server message variant checked above"); + }; + let raw = serde_json::to_vec(message.as_ref()).map_err(io::Error::other)?; + let message_size_bytes = raw.len(); + if message_size_bytes > REMOTE_CONTROL_REASSEMBLED_MAX_BYTES { + warn!("dropping remote-control server envelope that exceeds reassembled size limit"); + return Ok(Vec::new()); + } + + let minimal_segment_count = + usize::min(message_size_bytes.max(1), REMOTE_CONTROL_SEGMENT_COUNT_MAX); + let minimal_chunk = &raw[..usize::min(raw.len(), 1)]; + if serialized_chunk_len( + &envelope, + /*segment_id*/ 0, + minimal_segment_count, + message_size_bytes, + minimal_chunk, + )? > REMOTE_CONTROL_SEGMENT_MAX_BYTES + { + warn!("dropping remote-control server envelope that cannot fit within segment size limit"); + return Ok(Vec::new()); + } + + let mut segment_count = usize::max( + 2, + message_size_bytes.div_ceil(REMOTE_CONTROL_SEGMENT_TARGET_BYTES), + ); + loop { + let chunk_size = usize::max(1, message_size_bytes.div_ceil(segment_count)); + segment_count = message_size_bytes.div_ceil(chunk_size); + let segments_fit = raw + .chunks(chunk_size) + .enumerate() + .all(|(segment_id, chunk)| { + serialized_chunk_len( + &envelope, + segment_id, + segment_count, + message_size_bytes, + chunk, + ) + .is_ok_and(|size| size <= REMOTE_CONTROL_SEGMENT_MAX_BYTES) + }); + if segments_fit { + return raw + .chunks(chunk_size) + .enumerate() + .map(|(segment_id, chunk)| { + build_chunk_envelope( + &envelope, + segment_id, + segment_count, + message_size_bytes, + chunk, + ) + }) + .collect(); + } + if chunk_size == 1 { + warn!( + "dropping remote-control server envelope that cannot fit within segment size limit" + ); + return Ok(Vec::new()); + } + let next_segment_count = segment_count + 1; + let next_chunk_size = usize::max(1, message_size_bytes.div_ceil(next_segment_count)); + segment_count = if next_chunk_size == chunk_size { + message_size_bytes + } else { + next_segment_count + }; + } +} + +fn serialized_chunk_len( + envelope: &ServerEnvelope, + segment_id: usize, + segment_count: usize, + message_size_bytes: usize, + chunk: &[u8], +) -> io::Result { + serialized_len(&build_chunk_envelope( + envelope, + segment_id, + segment_count, + message_size_bytes, + chunk, + )?) +} + +#[derive(Default)] +struct CountingWriter { + len: usize, +} + +impl Write for CountingWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.len += buf.len(); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +fn serialized_len(value: &impl serde::Serialize) -> io::Result { + let mut writer = CountingWriter::default(); + serde_json::to_writer(&mut writer, value).map_err(io::Error::other)?; + Ok(writer.len) +} + +fn build_chunk_envelope( + envelope: &ServerEnvelope, + segment_id: usize, + segment_count: usize, + message_size_bytes: usize, + chunk: &[u8], +) -> io::Result { + if segment_count > REMOTE_CONTROL_SEGMENT_COUNT_MAX { + return Err(io::Error::new( + ErrorKind::InvalidData, + "remote-control segment count exceeds maximum", + )); + } + Ok(ServerEnvelope { + event: ServerEvent::ServerMessageChunk { + segment_id, + segment_count, + message_size_bytes, + message_chunk_base64: base64::engine::general_purpose::STANDARD.encode(chunk), + }, + client_id: envelope.client_id.clone(), + stream_id: envelope.stream_id.clone(), + seq_id: envelope.seq_id, + }) +} diff --git a/codex-rs/app-server/src/transport/remote_control/segment_tests.rs b/codex-rs/app-server/src/transport/remote_control/segment_tests.rs new file mode 100644 index 0000000000..dc15bdf8ba --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/segment_tests.rs @@ -0,0 +1,386 @@ +use super::protocol::ClientEnvelope; +use super::protocol::ClientEvent; +use super::protocol::ClientId; +use super::protocol::ServerEnvelope; +use super::protocol::ServerEvent; +use super::protocol::StreamId; +use super::segment::ClientSegmentObservation; +use super::segment::ClientSegmentReassembler; +use super::segment::REMOTE_CONTROL_SEGMENT_MAX_BYTES; +use super::segment::split_server_envelope_for_transport; +use crate::outgoing_message::OutgoingMessage; +use base64::Engine; +use codex_app_server_protocol::ConfigWarningNotification; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::ServerNotification; +use pretty_assertions::assert_eq; + +#[test] +fn reassembles_client_message_chunks() { + let message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + let raw = serde_json::to_vec(&message).expect("message should serialize"); + let split = raw.len() / 2; + let client_id = ClientId("client-1".to_string()); + let stream_id = Some(StreamId("stream-1".to_string())); + let mut reassembler = ClientSegmentReassembler::default(); + + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id.clone(), + stream_id.clone(), + /*seq_id*/ 7, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + )), + ClientSegmentObservation::Pending + )); + let reassembled = match reassembler.observe(chunk_envelope( + client_id.clone(), + stream_id, + /*seq_id*/ 7, + /*segment_id*/ 1, + /*segment_count*/ 2, + raw.len(), + &raw[split..], + )) { + ClientSegmentObservation::Forward(reassembled) => *reassembled, + ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => { + panic!("message should reassemble") + } + }; + assert_eq!(reassembled.client_id, client_id); + assert_eq!( + reassembled.stream_id, + Some(StreamId("stream-1".to_string())) + ); + assert_eq!(reassembled.seq_id, Some(7)); + assert_eq!(reassembled.cursor, None); + match reassembled.event { + ClientEvent::ClientMessage { + message: reassembled_message, + } => assert_eq!(reassembled_message, message), + other => panic!("expected client message, got {other:?}"), + } +} + +#[test] +fn splits_large_server_messages_into_wire_chunks() { + let envelope = ServerEnvelope { + event: ServerEvent::ServerMessage { + message: Box::new(OutgoingMessage::AppServerNotification( + ServerNotification::ConfigWarning(ConfigWarningNotification { + summary: "x".repeat(REMOTE_CONTROL_SEGMENT_MAX_BYTES), + details: None, + path: None, + range: None, + }), + )), + }, + client_id: ClientId("client-1".to_string()), + stream_id: StreamId("stream-1".to_string()), + seq_id: 9, + }; + + let segments = split_server_envelope_for_transport(envelope).expect("split should succeed"); + + assert!(segments.len() > 1); + assert!( + segments + .iter() + .all(|segment| matches!(segment.event, ServerEvent::ServerMessageChunk { .. })) + ); + assert!(segments.iter().all(|segment| segment.seq_id == 9)); + assert!(segments.iter().all(|segment| { + serde_json::to_vec(segment) + .expect("segment should serialize") + .len() + <= REMOTE_CONTROL_SEGMENT_MAX_BYTES + })); +} + +#[test] +fn invalidates_incomplete_stream_assemblies() { + let message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + let raw = serde_json::to_vec(&message).expect("message should serialize"); + let split = raw.len() / 2; + let client_id = ClientId("client-1".to_string()); + let stream_id = StreamId("stream-1".to_string()); + let mut reassembler = ClientSegmentReassembler::default(); + + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id.clone(), + Some(stream_id.clone()), + /*seq_id*/ 7, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + )), + ClientSegmentObservation::Pending + )); + reassembler.invalidate_stream(&client_id, &stream_id); + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id, + Some(stream_id), + /*seq_id*/ 7, + /*segment_id*/ 1, + /*segment_count*/ 2, + raw.len(), + &raw[split..], + )), + ClientSegmentObservation::Dropped + )); +} + +#[test] +fn resets_incomplete_client_assembly_when_stream_changes() { + let message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + let raw = serde_json::to_vec(&message).expect("message should serialize"); + let split = raw.len() / 2; + let client_id = ClientId("client-1".to_string()); + let first_stream_id = StreamId("stream-1".to_string()); + let second_stream_id = StreamId("stream-2".to_string()); + let mut reassembler = ClientSegmentReassembler::default(); + + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id.clone(), + Some(first_stream_id.clone()), + /*seq_id*/ 7, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + )), + ClientSegmentObservation::Pending + )); + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id.clone(), + Some(second_stream_id.clone()), + /*seq_id*/ 8, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + )), + ClientSegmentObservation::Pending + )); + let reassembled = match reassembler.observe(chunk_envelope( + client_id.clone(), + Some(second_stream_id), + /*seq_id*/ 8, + /*segment_id*/ 1, + /*segment_count*/ 2, + raw.len(), + &raw[split..], + )) { + ClientSegmentObservation::Forward(reassembled) => *reassembled, + ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => { + panic!("replacement stream should reassemble") + } + }; + assert_eq!( + reassembled.stream_id, + Some(StreamId("stream-2".to_string())) + ); + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id, + Some(first_stream_id), + /*seq_id*/ 7, + /*segment_id*/ 1, + /*segment_count*/ 2, + raw.len(), + &raw[split..], + )), + ClientSegmentObservation::Dropped + )); +} + +#[test] +fn ignores_stale_chunks_without_dropping_newer_assembly() { + let message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + let raw = serde_json::to_vec(&message).expect("message should serialize"); + let split = raw.len() / 2; + let client_id = ClientId("client-1".to_string()); + let stream_id = Some(StreamId("stream-1".to_string())); + let mut reassembler = ClientSegmentReassembler::default(); + + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id.clone(), + stream_id.clone(), + /*seq_id*/ 8, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + )), + ClientSegmentObservation::Pending + )); + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id.clone(), + stream_id.clone(), + /*seq_id*/ 7, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + )), + ClientSegmentObservation::Dropped + )); + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id, + stream_id, + /*seq_id*/ 8, + /*segment_id*/ 1, + /*segment_count*/ 2, + raw.len(), + &raw[split..], + )), + ClientSegmentObservation::Forward(_) + )); +} + +#[test] +fn ignores_invalid_stale_chunks_without_dropping_newer_assembly() { + let message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + let raw = serde_json::to_vec(&message).expect("message should serialize"); + let split = raw.len() / 2; + let client_id = ClientId("client-1".to_string()); + let stream_id = Some(StreamId("stream-1".to_string())); + let mut reassembler = ClientSegmentReassembler::default(); + + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id.clone(), + stream_id.clone(), + /*seq_id*/ 8, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + )), + ClientSegmentObservation::Pending + )); + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id.clone(), + stream_id.clone(), + /*seq_id*/ 7, + /*segment_id*/ 1, + /*segment_count*/ 2, + raw.len(), + b"", + )), + ClientSegmentObservation::Dropped + )); + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id, + stream_id, + /*seq_id*/ 8, + /*segment_id*/ 1, + /*segment_count*/ 2, + raw.len(), + &raw[split..], + )), + ClientSegmentObservation::Forward(_) + )); +} + +#[test] +fn ignores_invalid_duplicate_chunks_without_dropping_current_assembly() { + let message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + let raw = serde_json::to_vec(&message).expect("message should serialize"); + let split = raw.len() / 2; + let client_id = ClientId("client-1".to_string()); + let stream_id = Some(StreamId("stream-1".to_string())); + let mut reassembler = ClientSegmentReassembler::default(); + + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id.clone(), + stream_id.clone(), + /*seq_id*/ 8, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + )), + ClientSegmentObservation::Pending + )); + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id.clone(), + stream_id.clone(), + /*seq_id*/ 8, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + b"", + )), + ClientSegmentObservation::Dropped + )); + assert!(matches!( + reassembler.observe(chunk_envelope( + client_id, + stream_id, + /*seq_id*/ 8, + /*segment_id*/ 1, + /*segment_count*/ 2, + raw.len(), + &raw[split..], + )), + ClientSegmentObservation::Forward(_) + )); +} + +fn chunk_envelope( + client_id: ClientId, + stream_id: Option, + seq_id: u64, + segment_id: usize, + segment_count: usize, + message_size_bytes: usize, + chunk: &[u8], +) -> ClientEnvelope { + ClientEnvelope { + event: ClientEvent::ClientMessageChunk { + segment_id, + segment_count, + message_size_bytes, + message_chunk_base64: base64::engine::general_purpose::STANDARD.encode(chunk), + }, + client_id, + stream_id, + seq_id: Some(seq_id), + cursor: None, + } +} diff --git a/codex-rs/app-server/src/transport/remote_control/tests.rs b/codex-rs/app-server/src/transport/remote_control/tests.rs index 92ac9e431e..5fd3caa401 100644 --- a/codex-rs/app-server/src/transport/remote_control/tests.rs +++ b/codex-rs/app-server/src/transport/remote_control/tests.rs @@ -831,7 +831,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_backend_acks() { send_client_event( &mut first_websocket, ClientEnvelope { - event: ClientEvent::Ack, + event: ClientEvent::Ack { segment_id: None }, client_id: client_id.clone(), stream_id: Some(stream_id), seq_id: Some(1), diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs index df673079cd..f7b49b72ec 100644 --- a/codex-rs/app-server/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -15,6 +15,10 @@ use super::protocol::ClientId; use super::protocol::RemoteControlTarget; use super::protocol::ServerEnvelope; use super::protocol::StreamId; +use super::segment::ClientSegmentObservation; +use super::segment::ClientSegmentReassembler; +use super::segment::REMOTE_CONTROL_SEGMENT_MAX_BYTES; +use super::segment::split_server_envelope_for_transport; use axum::http::HeaderValue; use base64::Engine; use codex_app_server_protocol::RemoteControlConnectionStatus; @@ -49,7 +53,7 @@ use tracing::error; use tracing::info; use tracing::warn; -pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2"; +pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "3"; pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor"; const REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL: std::time::Duration = @@ -85,17 +89,29 @@ impl BoundedOutboundBuffer { self.used_tx.send_modify(|used| *used += 1); } - fn ack(&mut self, client_id: &ClientId, stream_id: &StreamId, acked_seq_id: u64) { + fn ack( + &mut self, + client_id: &ClientId, + stream_id: &StreamId, + acked_seq_id: u64, + acked_segment_id: Option, + ) { let key = (client_id.clone(), stream_id.clone()); let Some(buffer) = self.buffer_by_stream.get_mut(&key) else { return; }; - while let Some(server_envelope) = buffer.front() - && server_envelope.seq_id <= acked_seq_id - { - buffer.pop_front(); - self.used_tx.send_modify(|used| *used -= 1); - } + let acked_cursor = (acked_seq_id, acked_segment_id.unwrap_or(usize::MAX)); + buffer.retain(|server_envelope| { + let envelope_cursor = ( + server_envelope.seq_id, + server_envelope.event.segment_id().unwrap_or_default(), + ); + let is_acked = envelope_cursor <= acked_cursor; + if is_acked { + self.used_tx.send_modify(|used| *used -= 1); + } + !is_acked + }); if buffer.is_empty() { self.buffer_by_stream.remove(&key); } @@ -112,6 +128,88 @@ struct WebsocketState { outbound_buffer: BoundedOutboundBuffer, subscribe_cursor: Option, next_seq_id_by_stream: HashMap<(ClientId, StreamId), u64>, + last_completed_client_chunk_seq_id_by_stream: HashMap<(ClientId, Option), u64>, + client_segment_reassembler: ClientSegmentReassembler, +} + +impl WebsocketState { + fn observe_client_message( + &mut self, + client_envelope: ClientEnvelope, + wire_size_bytes: usize, + ) -> ClientSegmentObservation { + let client_message_key = Self::client_message_key(&client_envelope); + if let Some((key, seq_id)) = client_message_key.as_ref() + && self + .last_completed_client_chunk_seq_id_by_stream + .get(key) + .is_some_and(|last_seq_id| last_seq_id >= seq_id) + { + return ClientSegmentObservation::Dropped; + } + if let ( + Some((_, seq_id)), + Some(stream_id), + ClientEvent::ClientMessageChunk { segment_id, .. }, + ) = ( + client_message_key.as_ref(), + client_envelope.stream_id.as_ref(), + &client_envelope.event, + ) && self.client_segment_reassembler.should_ignore_chunk( + &client_envelope.client_id, + stream_id, + *seq_id, + *segment_id, + ) { + return ClientSegmentObservation::Dropped; + } + if client_message_key.is_some() && wire_size_bytes > REMOTE_CONTROL_SEGMENT_MAX_BYTES { + warn!( + client_id = client_envelope.client_id.0.as_str(), + "dropping oversized segmented remote-control client envelope" + ); + if let Some(stream_id) = client_envelope.stream_id.as_ref() { + self.client_segment_reassembler + .invalidate_stream(&client_envelope.client_id, stream_id); + } + return ClientSegmentObservation::Dropped; + } + + let observation = self.client_segment_reassembler.observe(client_envelope); + if matches!(observation, ClientSegmentObservation::Forward(_)) + && let Some((key, seq_id)) = client_message_key + { + self.last_completed_client_chunk_seq_id_by_stream + .insert(key, seq_id); + } + observation + } + + fn invalidate_client_message_stream(&mut self, client_id: &ClientId, stream_id: &StreamId) { + self.last_completed_client_chunk_seq_id_by_stream + .remove(&(client_id.clone(), Some(stream_id.clone()))); + } + + fn invalidate_client_message_client(&mut self, client_id: &ClientId) { + self.last_completed_client_chunk_seq_id_by_stream + .retain(|(cursor_client_id, _), _| cursor_client_id != client_id); + } + + fn client_message_key( + client_envelope: &ClientEnvelope, + ) -> Option<((ClientId, Option), u64)> { + let seq_id = match (&client_envelope.event, client_envelope.seq_id) { + (ClientEvent::ClientMessageChunk { .. }, Some(seq_id)) => seq_id, + _ => return None, + }; + Some(( + ( + client_envelope.client_id.clone(), + client_envelope.stream_id.clone(), + ), + seq_id, + )) + } } pub(crate) struct RemoteControlWebsocket { @@ -231,6 +329,8 @@ impl RemoteControlWebsocket { outbound_buffer, subscribe_cursor: None, next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), })), server_event_rx: Arc::new(Mutex::new(server_event_rx)), used_rx, @@ -556,7 +656,7 @@ impl RemoteControlWebsocket { } } }; - let (payload, write_complete_tx) = { + let (payloads, write_complete_tx) = { let mut state = state.lock().await; let seq_key = ( queued_server_envelope.client_id.clone(), @@ -573,29 +673,42 @@ impl RemoteControlWebsocket { seq_id, stream_id: queued_server_envelope.stream_id, }; - let payload = match serde_json::to_string(&server_envelope) { - Ok(payload) => payload, + let server_envelopes = match split_server_envelope_for_transport(server_envelope) { + Ok(server_envelopes) => server_envelopes, Err(err) => { - error!("failed to serialize remote-control server event: {err}"); + error!("failed to split remote-control server event: {err}"); continue; } }; + let mut payloads = Vec::with_capacity(server_envelopes.len()); + for server_envelope in server_envelopes { + let payload = match serde_json::to_string(&server_envelope) { + Ok(payload) => payload, + Err(err) => { + error!("failed to serialize remote-control server event: {err}"); + continue; + } + }; + state.outbound_buffer.insert(&server_envelope); + payloads.push(payload); + } state .next_seq_id_by_stream .insert(seq_key, seq_id.saturating_add(1)); - state.outbound_buffer.insert(&server_envelope); - (payload, queued_server_envelope.write_complete_tx) + (payloads, queued_server_envelope.write_complete_tx) }; - tokio::select! { - _ = shutdown_token.cancelled() => return Ok(()), - send_result = websocket_writer.send(tungstenite::Message::Text(payload.into())) => { - if let Err(err) = send_result { - return Err(io::Error::other(err)); + for payload in payloads { + tokio::select! { + _ = shutdown_token.cancelled() => return Ok(()), + send_result = websocket_writer.send(tungstenite::Message::Text(payload.into())) => { + if let Err(err) = send_result { + return Err(io::Error::other(err)); + } } } - }; + } if let Some(write_complete_tx) = write_complete_tx { let _ = write_complete_tx.send(()); } @@ -657,11 +770,30 @@ impl RemoteControlWebsocket { if client_tracker.close_client(&client_key).await.is_err() { return Ok(()); } + state + .lock() + .await + .client_segment_reassembler + .invalidate_stream(&client_key.0, &client_key.1); + state + .lock() + .await + .invalidate_client_message_stream(&client_key.0, &client_key.1); continue; } _ = idle_sweep_interval.tick() => { - if client_tracker.close_expired_clients().await.is_err() { - return Ok(()); + match client_tracker.close_expired_clients().await { + Ok(client_keys) => { + let mut websocket_state = state.lock().await; + for (client_id, stream_id) in client_keys { + websocket_state + .client_segment_reassembler + .invalidate_stream(&client_id, &stream_id); + websocket_state + .invalidate_client_message_stream(&client_id, &stream_id); + } + } + Err(_) => return Ok(()), } continue; } @@ -672,10 +804,11 @@ impl RemoteControlWebsocket { } } }; - let client_envelope = match incoming_message { + let (client_envelope, wire_size_bytes) = match incoming_message { Ok(tungstenite::Message::Text(text)) => { + let wire_size_bytes = text.len(); match serde_json::from_str::(&text) { - Ok(client_envelope) => client_envelope, + Ok(client_envelope) => (client_envelope, wire_size_bytes), Err(err) => { warn!("failed to deserialize remote-control client event: {err}"); continue; @@ -707,12 +840,21 @@ impl RemoteControlWebsocket { } }; + let observation = { + let mut websocket_state = state.lock().await; + websocket_state.observe_client_message(client_envelope, wire_size_bytes) + }; + let client_envelope = match observation { + ClientSegmentObservation::Forward(client_envelope) => *client_envelope, + ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => continue, + }; + { let mut websocket_state = state.lock().await; if let Some(cursor) = client_envelope.cursor.as_deref() { websocket_state.subscribe_cursor = Some(cursor.to_string()); } - if let ClientEvent::Ack = &client_envelope.event + if let ClientEvent::Ack { segment_id } = &client_envelope.event && let Some(acked_seq_id) = client_envelope.seq_id && let Some(stream_id) = client_envelope.stream_id.as_ref() { @@ -720,10 +862,18 @@ impl RemoteControlWebsocket { &client_envelope.client_id, stream_id, acked_seq_id, + *segment_id, ); } } + let closed_client = + matches!(&client_envelope.event, ClientEvent::ClientClosed).then(|| { + ( + client_envelope.client_id.clone(), + client_envelope.stream_id.clone(), + ) + }); if client_tracker .handle_message(client_envelope) .await @@ -731,6 +881,20 @@ impl RemoteControlWebsocket { { return Ok(()); } + if let Some((client_id, stream_id)) = closed_client { + let mut websocket_state = state.lock().await; + if let Some(stream_id) = stream_id { + websocket_state + .client_segment_reassembler + .invalidate_stream(&client_id, &stream_id); + websocket_state.invalidate_client_message_stream(&client_id, &stream_id); + } else { + websocket_state + .client_segment_reassembler + .invalidate_client(&client_id); + websocket_state.invalidate_client_message_client(&client_id); + } + } } } } @@ -1052,6 +1216,8 @@ mod tests { use chrono::Utc; use codex_app_server_protocol::AuthMode; use codex_app_server_protocol::ConfigWarningNotification; + use codex_app_server_protocol::JSONRPCMessage; + use codex_app_server_protocol::JSONRPCNotification; use codex_app_server_protocol::ServerNotification; use codex_config::types::AuthCredentialsStoreMode; use codex_core::test_support::auth_manager_from_auth; @@ -1603,6 +1769,8 @@ mod tests { outbound_buffer, subscribe_cursor: None, next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), })); let (_server_event_tx, server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY); let server_event_rx = Arc::new(Mutex::new(server_event_rx)); @@ -1639,6 +1807,8 @@ mod tests { outbound_buffer, subscribe_cursor: None, next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), })); let (server_event_tx, server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY); let server_event_rx = Arc::new(Mutex::new(server_event_rx)); @@ -1716,6 +1886,8 @@ mod tests { outbound_buffer, subscribe_cursor: None, next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), })); let (server_event_tx, _server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY); let (transport_event_tx, _transport_event_rx) = @@ -1771,7 +1943,9 @@ mod tests { "first-client-new-stream", )); - outbound_buffer.ack(&client_1, &stream_1, /*acked_seq_id*/ 3); + outbound_buffer.ack( + &client_1, &stream_1, /*acked_seq_id*/ 3, /*acked_segment_id*/ None, + ); let mut retained = outbound_buffer .server_envelopes() @@ -1814,7 +1988,9 @@ mod tests { &client_2, "stream-1", /*seq_id*/ 3, "second", )); - outbound_buffer.ack(&client_1, &stream_1, /*acked_seq_id*/ 1); + outbound_buffer.ack( + &client_1, &stream_1, /*acked_seq_id*/ 1, /*acked_segment_id*/ None, + ); let mut retained = outbound_buffer .server_envelopes() @@ -1834,6 +2010,390 @@ mod tests { assert_eq!(*used_rx.borrow(), 2); } + #[test] + fn outbound_buffer_advances_segmented_acks_by_wire_cursor() { + let (mut outbound_buffer, used_rx) = BoundedOutboundBuffer::new(); + let client_id = ClientId("client-1".to_string()); + let stream_id = StreamId("stream-1".to_string()); + + outbound_buffer.insert(&server_chunk_envelope( + &client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 0, + )); + outbound_buffer.insert(&server_chunk_envelope( + &client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 1, + )); + + outbound_buffer.ack( + &client_id, + &stream_id, + /*acked_seq_id*/ 4, + /*acked_segment_id*/ Some(1), + ); + + let retained = outbound_buffer + .server_envelopes() + .map(|server_envelope| server_envelope.event.segment_id()) + .collect::>(); + assert_eq!(retained, Vec::>::new()); + assert_eq!(*used_rx.borrow(), 0); + } + + #[test] + fn outbound_buffer_treats_segmentless_acks_as_seq_level_acks() { + let (mut outbound_buffer, used_rx) = BoundedOutboundBuffer::new(); + let client_id = ClientId("client-1".to_string()); + let stream_id = StreamId("stream-1".to_string()); + + outbound_buffer.insert(&server_chunk_envelope( + &client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 0, + )); + outbound_buffer.insert(&server_chunk_envelope( + &client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 1, + )); + + outbound_buffer.ack( + &client_id, &stream_id, /*acked_seq_id*/ 4, /*acked_segment_id*/ None, + ); + + let retained = outbound_buffer + .server_envelopes() + .map(|server_envelope| server_envelope.event.segment_id()) + .collect::>(); + assert_eq!(retained, Vec::>::new()); + assert_eq!(*used_rx.borrow(), 0); + } + + #[test] + fn websocket_state_drops_duplicate_client_chunks_while_pending() { + let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new(); + let mut state = WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), + }; + let first_chunk = client_chunk_envelope( + "client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0, + /*segment_count*/ 2, /*message_size_bytes*/ 2, b"x", + ); + let second_chunk = client_chunk_envelope( + "client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 1, + /*segment_count*/ 2, /*message_size_bytes*/ 2, b"y", + ); + + assert!(matches!( + observe_client_message(&mut state, first_chunk.clone()), + ClientSegmentObservation::Pending + )); + assert!(matches!( + observe_client_message(&mut state, first_chunk.clone()), + ClientSegmentObservation::Dropped + )); + assert!(matches!( + observe_client_message(&mut state, second_chunk), + ClientSegmentObservation::Dropped + )); + assert!(matches!( + observe_client_message(&mut state, first_chunk), + ClientSegmentObservation::Pending + )); + } + + #[test] + fn websocket_state_drops_replayed_client_chunks_after_completion() { + let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new(); + let mut state = WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), + }; + let message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + let raw = serde_json::to_vec(&message).expect("message should serialize"); + let split = raw.len() / 2; + let first_chunk = client_chunk_envelope( + "client-1", + "stream-1", + /*seq_id*/ 4, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + ); + let second_chunk = client_chunk_envelope( + "client-1", + "stream-1", + /*seq_id*/ 4, + /*segment_id*/ 1, + /*segment_count*/ 2, + raw.len(), + &raw[split..], + ); + + assert!(matches!( + observe_client_message(&mut state, first_chunk.clone()), + ClientSegmentObservation::Pending + )); + assert!(matches!( + observe_client_message(&mut state, second_chunk), + ClientSegmentObservation::Forward(_) + )); + assert!(matches!( + observe_client_message(&mut state, first_chunk), + ClientSegmentObservation::Dropped + )); + } + + #[test] + fn websocket_state_allows_replay_after_rejected_out_of_order_chunk() { + let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new(); + let mut state = WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), + }; + let first_chunk = client_chunk_envelope( + "client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0, + /*segment_count*/ 2, /*message_size_bytes*/ 2, b"x", + ); + let second_chunk = client_chunk_envelope( + "client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 1, + /*segment_count*/ 2, /*message_size_bytes*/ 2, b"y", + ); + + assert!(matches!( + observe_client_message(&mut state, second_chunk), + ClientSegmentObservation::Dropped + )); + assert!(matches!( + observe_client_message(&mut state, first_chunk), + ClientSegmentObservation::Pending + )); + } + + #[test] + fn websocket_state_allows_replay_after_later_chunk_drops() { + let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new(); + let mut state = WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), + }; + let first_chunk = client_chunk_envelope( + "client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0, + /*segment_count*/ 2, /*message_size_bytes*/ 2, b"x", + ); + let invalid_second_chunk = client_chunk_envelope( + "client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 1, + /*segment_count*/ 2, /*message_size_bytes*/ 2, b"", + ); + + assert!(matches!( + observe_client_message(&mut state, first_chunk.clone()), + ClientSegmentObservation::Pending + )); + assert!(matches!( + observe_client_message(&mut state, invalid_second_chunk), + ClientSegmentObservation::Dropped + )); + assert!(matches!( + observe_client_message(&mut state, first_chunk), + ClientSegmentObservation::Pending + )); + } + + #[test] + fn websocket_state_drops_oversized_client_chunk_frames() { + let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new(); + let mut state = WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), + }; + let chunk = client_chunk_envelope( + "client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0, + /*segment_count*/ 1, /*message_size_bytes*/ 1, b"x", + ); + + assert!(matches!( + state.observe_client_message(chunk, REMOTE_CONTROL_SEGMENT_MAX_BYTES + 1), + ClientSegmentObservation::Dropped + )); + } + + #[test] + fn websocket_state_ignores_oversized_stale_chunks_without_dropping_newer_assembly() { + let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new(); + let mut state = WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), + }; + let message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + let raw = serde_json::to_vec(&message).expect("message should serialize"); + let split = raw.len() / 2; + let first_newer_chunk = client_chunk_envelope( + "client-1", + "stream-1", + /*seq_id*/ 8, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + ); + let oversized_stale_chunk = client_chunk_envelope( + "client-1", + "stream-1", + /*seq_id*/ 7, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + ); + let second_newer_chunk = client_chunk_envelope( + "client-1", + "stream-1", + /*seq_id*/ 8, + /*segment_id*/ 1, + /*segment_count*/ 2, + raw.len(), + &raw[split..], + ); + + assert!(matches!( + observe_client_message(&mut state, first_newer_chunk), + ClientSegmentObservation::Pending + )); + assert!(matches!( + state.observe_client_message( + oversized_stale_chunk, + REMOTE_CONTROL_SEGMENT_MAX_BYTES + 1, + ), + ClientSegmentObservation::Dropped + )); + assert!(matches!( + observe_client_message(&mut state, second_newer_chunk), + ClientSegmentObservation::Forward(_) + )); + } + + #[test] + fn websocket_state_ignores_oversized_duplicate_chunks_without_dropping_current_assembly() { + let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new(); + let mut state = WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), + }; + let message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + let raw = serde_json::to_vec(&message).expect("message should serialize"); + let split = raw.len() / 2; + let first_chunk = client_chunk_envelope( + "client-1", + "stream-1", + /*seq_id*/ 8, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + ); + let oversized_duplicate_chunk = client_chunk_envelope( + "client-1", + "stream-1", + /*seq_id*/ 8, + /*segment_id*/ 0, + /*segment_count*/ 2, + raw.len(), + &raw[..split], + ); + let second_chunk = client_chunk_envelope( + "client-1", + "stream-1", + /*seq_id*/ 8, + /*segment_id*/ 1, + /*segment_count*/ 2, + raw.len(), + &raw[split..], + ); + + assert!(matches!( + observe_client_message(&mut state, first_chunk), + ClientSegmentObservation::Pending + )); + assert!(matches!( + state.observe_client_message( + oversized_duplicate_chunk, + REMOTE_CONTROL_SEGMENT_MAX_BYTES + 1, + ), + ClientSegmentObservation::Dropped + )); + assert!(matches!( + observe_client_message(&mut state, second_chunk), + ClientSegmentObservation::Forward(_) + )); + } + + #[test] + fn websocket_state_clears_chunk_cursor_when_stream_is_invalidated() { + let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new(); + let mut state = WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id_by_stream: HashMap::new(), + last_completed_client_chunk_seq_id_by_stream: HashMap::new(), + client_segment_reassembler: ClientSegmentReassembler::default(), + }; + let client_id = ClientId("client-1".to_string()); + let stream_id = StreamId("stream-1".to_string()); + + assert!(matches!( + observe_client_message( + &mut state, + client_chunk_envelope( + "client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0, + /*segment_count*/ 2, /*message_size_bytes*/ 2, b"x", + ) + ), + ClientSegmentObservation::Pending + )); + state.invalidate_client_message_stream(&client_id, &stream_id); + state + .client_segment_reassembler + .invalidate_stream(&client_id, &stream_id); + + assert!(matches!( + observe_client_message( + &mut state, + client_chunk_envelope( + "client-1", "stream-1", /*seq_id*/ 1, /*segment_id*/ 0, + /*segment_count*/ 2, /*message_size_bytes*/ 2, b"x", + ) + ), + ClientSegmentObservation::Pending + )); + } + fn server_envelope( client_id: &ClientId, stream_id: &str, @@ -1857,6 +2417,58 @@ mod tests { } } + fn server_chunk_envelope( + client_id: &ClientId, + stream_id: &str, + seq_id: u64, + segment_id: usize, + ) -> ServerEnvelope { + ServerEnvelope { + event: ServerEvent::ServerMessageChunk { + segment_id, + segment_count: 2, + message_size_bytes: 2, + message_chunk_base64: String::new(), + }, + client_id: client_id.clone(), + stream_id: StreamId(stream_id.to_string()), + seq_id, + } + } + + fn client_chunk_envelope( + client_id: &str, + stream_id: &str, + seq_id: u64, + segment_id: usize, + segment_count: usize, + message_size_bytes: usize, + chunk: &[u8], + ) -> ClientEnvelope { + ClientEnvelope { + event: ClientEvent::ClientMessageChunk { + segment_id, + segment_count, + message_size_bytes, + message_chunk_base64: base64::engine::general_purpose::STANDARD.encode(chunk), + }, + client_id: ClientId(client_id.to_string()), + stream_id: Some(StreamId(stream_id.to_string())), + seq_id: Some(seq_id), + cursor: None, + } + } + + fn observe_client_message( + state: &mut WebsocketState, + envelope: ClientEnvelope, + ) -> ClientSegmentObservation { + let wire_size_bytes = serde_json::to_vec(&envelope) + .expect("client envelope should serialize") + .len(); + state.observe_client_message(envelope, wire_size_bytes) + } + async fn accept_http_request(listener: &TcpListener) -> (TcpStream, String) { let (stream, _) = timeout(TEST_HTTP_ACCEPT_TIMEOUT, listener.accept()) .await