mirror of
https://github.com/openai/codex.git
synced 2026-06-01 19:02:59 +00:00
app-server: switch remote control to protocol v3 segmentation (#20341)
## Why Remote-control protocol v3 makes segmentation an explicit wire-level feature. The app-server transport needs to support that protocol directly so large messages can be chunked, acknowledged, replayed, and reassembled consistently. ## What changed - Bump the remote-control websocket protocol version from `2` to `3`. - Add explicit client/server chunk envelope variants plus chunk-aware acknowledgements. - Split oversized outbound server messages into bounded transport chunks. - Reassemble ordered inbound client chunks with bounded memory usage and stream/client invalidation handling. - Track inbound chunk cursors and outbound ack cursors as `(seq_id, segment_id)` so duplicate chunks and partial replays behave correctly. - Add focused coverage for chunk splitting, reassembly, duplicate suppression, and stream replacement behavior. ## Validation - Added targeted unit coverage for segmented message handling in `remote_control`. - Local validation is currently blocked before compilation because `packageproxy` does not serve the locked `rustls-webpki 0.103.13` dependency required by the workspace.
This commit is contained in:
committed by
GitHub
parent
af089fb21d
commit
972b819213
@@ -195,7 +195,7 @@ impl ClientTracker {
|
||||
})
|
||||
.await
|
||||
}
|
||||
ClientEvent::Ack => Ok(()),
|
||||
ClientEvent::ClientMessageChunk { .. } | ClientEvent::Ack { .. } => Ok(()),
|
||||
ClientEvent::Ping => {
|
||||
if let Some(client) = self.clients.get_mut(&client_key) {
|
||||
client.last_activity_at = Instant::now();
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod client_tracker;
|
||||
mod enroll;
|
||||
mod protocol;
|
||||
mod segment;
|
||||
mod websocket;
|
||||
|
||||
use crate::transport::remote_control::websocket::RemoteControlChannels;
|
||||
@@ -121,5 +122,7 @@ pub(crate) async fn start_remote_control(
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod segment_tests;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
@@ -47,10 +47,20 @@ pub enum ClientEvent {
|
||||
ClientMessage {
|
||||
message: JSONRPCMessage,
|
||||
},
|
||||
ClientMessageChunk {
|
||||
segment_id: usize,
|
||||
segment_count: usize,
|
||||
message_size_bytes: usize,
|
||||
message_chunk_base64: String,
|
||||
},
|
||||
/// Backend-generated acknowledgement for all server envelopes addressed to
|
||||
/// `client_id` and `stream_id` whose envelope `seq_id` is less than or equal
|
||||
/// to this ack's `seq_id`. This cursor is stream-scoped.
|
||||
Ack,
|
||||
/// to this ack's `seq_id`. Chunk acknowledgements carry `segment_id` so the
|
||||
/// sender can retain only the still-unacked wire chunks on reconnect.
|
||||
Ack {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
segment_id: Option<usize>,
|
||||
},
|
||||
Ping,
|
||||
ClientClosed,
|
||||
}
|
||||
@@ -85,6 +95,12 @@ pub enum ServerEvent {
|
||||
ServerMessage {
|
||||
message: Box<OutgoingMessage>,
|
||||
},
|
||||
ServerMessageChunk {
|
||||
segment_id: usize,
|
||||
segment_count: usize,
|
||||
message_size_bytes: usize,
|
||||
message_chunk_base64: String,
|
||||
},
|
||||
#[allow(dead_code)]
|
||||
Ack,
|
||||
Pong {
|
||||
@@ -92,6 +108,15 @@ pub enum ServerEvent {
|
||||
},
|
||||
}
|
||||
|
||||
impl ServerEvent {
|
||||
pub(crate) fn segment_id(&self) -> Option<usize> {
|
||||
match self {
|
||||
Self::ServerMessageChunk { segment_id, .. } => Some(*segment_id),
|
||||
Self::ServerMessage { .. } | Self::Ack | Self::Pong { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) struct ServerEnvelope {
|
||||
|
||||
449
codex-rs/app-server/src/transport/remote_control/segment.rs
Normal file
449
codex-rs/app-server/src/transport/remote_control/segment.rs
Normal file
@@ -0,0 +1,449 @@
|
||||
use super::protocol::ClientEnvelope;
|
||||
use super::protocol::ClientEvent;
|
||||
use super::protocol::ClientId;
|
||||
use super::protocol::ServerEnvelope;
|
||||
use super::protocol::ServerEvent;
|
||||
use super::protocol::StreamId;
|
||||
use base64::DecodeSliceError;
|
||||
use base64::Engine;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Write;
|
||||
use tokio::time::Instant;
|
||||
use tracing::warn;
|
||||
|
||||
pub(super) const REMOTE_CONTROL_SEGMENT_TARGET_BYTES: usize = 100 * 1024;
|
||||
pub(super) const REMOTE_CONTROL_SEGMENT_MAX_BYTES: usize = 150 * 1024;
|
||||
pub(super) const REMOTE_CONTROL_REASSEMBLED_MAX_BYTES: usize = 100 * 1024 * 1024;
|
||||
pub(super) const REMOTE_CONTROL_SEGMENT_COUNT_MAX: usize = 1024;
|
||||
const REMOTE_CONTROL_SEGMENT_ASSEMBLY_MAX_COUNT: usize = 128;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClientSegmentAssembly {
|
||||
stream_id: StreamId,
|
||||
metadata: ClientSegmentMetadata,
|
||||
raw: Vec<u8>,
|
||||
next_segment_id: usize,
|
||||
last_chunk_seen_at: Instant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct ClientSegmentMetadata {
|
||||
seq_id: u64,
|
||||
segment_count: usize,
|
||||
message_size_bytes: usize,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(super) struct ClientSegmentReassembler {
|
||||
assemblies: HashMap<ClientId, ClientSegmentAssembly>,
|
||||
}
|
||||
|
||||
pub(super) enum ClientSegmentObservation {
|
||||
Forward(Box<ClientEnvelope>),
|
||||
Pending,
|
||||
Dropped,
|
||||
}
|
||||
|
||||
impl ClientSegmentReassembler {
|
||||
pub(super) fn observe(&mut self, envelope: ClientEnvelope) -> ClientSegmentObservation {
|
||||
let ClientEvent::ClientMessageChunk {
|
||||
segment_id,
|
||||
segment_count,
|
||||
message_size_bytes,
|
||||
message_chunk_base64,
|
||||
} = &envelope.event
|
||||
else {
|
||||
return ClientSegmentObservation::Forward(Box::new(envelope));
|
||||
};
|
||||
let segment_id = *segment_id;
|
||||
let segment_count = *segment_count;
|
||||
let message_size_bytes = *message_size_bytes;
|
||||
|
||||
let Some(metadata) = ClientSegmentMetadata::from_envelope(&envelope) else {
|
||||
warn!(
|
||||
client_id = envelope.client_id.0.as_str(),
|
||||
"dropping segmented remote-control client envelope without seq_id"
|
||||
);
|
||||
return ClientSegmentObservation::Dropped;
|
||||
};
|
||||
let Some(stream_id) = envelope.stream_id.clone() else {
|
||||
warn!(
|
||||
client_id = envelope.client_id.0.as_str(),
|
||||
"dropping segmented remote-control client envelope without stream_id"
|
||||
);
|
||||
return ClientSegmentObservation::Dropped;
|
||||
};
|
||||
if self.should_ignore_chunk(&envelope.client_id, &stream_id, metadata.seq_id, segment_id) {
|
||||
return ClientSegmentObservation::Dropped;
|
||||
}
|
||||
if segment_count == 0
|
||||
|| segment_count > REMOTE_CONTROL_SEGMENT_COUNT_MAX
|
||||
|| segment_id >= segment_count
|
||||
|| message_size_bytes == 0
|
||||
|| message_size_bytes > REMOTE_CONTROL_REASSEMBLED_MAX_BYTES
|
||||
|| message_chunk_base64.is_empty()
|
||||
{
|
||||
warn!(
|
||||
client_id = envelope.client_id.0.as_str(),
|
||||
"dropping invalid segmented remote-control client envelope"
|
||||
);
|
||||
self.remove_assembly(&envelope.client_id, &stream_id);
|
||||
return ClientSegmentObservation::Dropped;
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
match self.assemblies.get(&envelope.client_id) {
|
||||
Some(assembly) if assembly.stream_id != stream_id => {
|
||||
warn!(
|
||||
client_id = envelope.client_id.0.as_str(),
|
||||
"resetting segmented remote-control client envelope after stream change"
|
||||
);
|
||||
self.assemblies.insert(
|
||||
envelope.client_id.clone(),
|
||||
ClientSegmentAssembly {
|
||||
stream_id: stream_id.clone(),
|
||||
metadata: metadata.clone(),
|
||||
raw: Vec::new(),
|
||||
next_segment_id: 0,
|
||||
last_chunk_seen_at: now,
|
||||
},
|
||||
);
|
||||
}
|
||||
Some(_) => {}
|
||||
None => {
|
||||
self.evict_assemblies_if_full();
|
||||
self.assemblies.insert(
|
||||
envelope.client_id.clone(),
|
||||
ClientSegmentAssembly {
|
||||
stream_id: stream_id.clone(),
|
||||
metadata: metadata.clone(),
|
||||
raw: Vec::new(),
|
||||
next_segment_id: 0,
|
||||
last_chunk_seen_at: now,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
let result = {
|
||||
let Some(assembly) = self.assemblies.get_mut(&envelope.client_id) else {
|
||||
warn!(
|
||||
client_id = envelope.client_id.0.as_str(),
|
||||
"dropping segmented remote-control client envelope without assembly"
|
||||
);
|
||||
return ClientSegmentObservation::Dropped;
|
||||
};
|
||||
if metadata.seq_id < assembly.metadata.seq_id {
|
||||
AssemblyUpdate::Ignore
|
||||
} else if assembly.metadata != metadata {
|
||||
warn!(
|
||||
client_id = envelope.client_id.0.as_str(),
|
||||
"resetting segmented remote-control client envelope after metadata mismatch"
|
||||
);
|
||||
AssemblyUpdate::Drop
|
||||
} else if segment_id < assembly.next_segment_id {
|
||||
AssemblyUpdate::Pending
|
||||
} else if segment_id != assembly.next_segment_id {
|
||||
warn!(
|
||||
client_id = envelope.client_id.0.as_str(),
|
||||
"dropping out-of-order segmented remote-control client envelope"
|
||||
);
|
||||
AssemblyUpdate::Drop
|
||||
} else {
|
||||
assembly.last_chunk_seen_at = now;
|
||||
let chunk_start = assembly.raw.len();
|
||||
let decoded_chunk_len = base64::decoded_len_estimate(message_chunk_base64.len());
|
||||
let chunk_end = usize::min(
|
||||
message_size_bytes,
|
||||
chunk_start.saturating_add(decoded_chunk_len),
|
||||
);
|
||||
assembly.raw.resize(chunk_end, 0);
|
||||
match base64::engine::general_purpose::STANDARD.decode_slice(
|
||||
message_chunk_base64.as_bytes(),
|
||||
&mut assembly.raw[chunk_start..],
|
||||
) {
|
||||
Ok(decoded_chunk_len) => {
|
||||
assembly.raw.truncate(chunk_start + decoded_chunk_len);
|
||||
assembly.next_segment_id += 1;
|
||||
if assembly.next_segment_id < segment_count {
|
||||
AssemblyUpdate::Pending
|
||||
} else if assembly.raw.len() != message_size_bytes {
|
||||
warn!(
|
||||
client_id = envelope.client_id.0.as_str(),
|
||||
"dropping reassembled remote-control client envelope with mismatched size"
|
||||
);
|
||||
AssemblyUpdate::Drop
|
||||
} else {
|
||||
match serde_json::from_slice::<JSONRPCMessage>(&assembly.raw) {
|
||||
Ok(message) => AssemblyUpdate::Complete(message),
|
||||
Err(err) => {
|
||||
warn!(
|
||||
client_id = envelope.client_id.0.as_str(),
|
||||
"dropping invalid reassembled remote-control client envelope: {err}"
|
||||
);
|
||||
AssemblyUpdate::Drop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(DecodeSliceError::OutputSliceTooSmall) => {
|
||||
warn!(
|
||||
client_id = envelope.client_id.0.as_str(),
|
||||
"dropping segmented remote-control client envelope after size overflow"
|
||||
);
|
||||
AssemblyUpdate::Drop
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
client_id = envelope.client_id.0.as_str(),
|
||||
"dropping segmented remote-control client envelope with invalid base64: {err}"
|
||||
);
|
||||
AssemblyUpdate::Drop
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
AssemblyUpdate::Pending => ClientSegmentObservation::Pending,
|
||||
AssemblyUpdate::Ignore => ClientSegmentObservation::Dropped,
|
||||
AssemblyUpdate::Drop => {
|
||||
self.remove_assembly(&envelope.client_id, &stream_id);
|
||||
ClientSegmentObservation::Dropped
|
||||
}
|
||||
AssemblyUpdate::Complete(message) => {
|
||||
self.remove_assembly(&envelope.client_id, &stream_id);
|
||||
ClientSegmentObservation::Forward(Box::new(ClientEnvelope {
|
||||
event: ClientEvent::ClientMessage { message },
|
||||
..envelope
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn invalidate_stream(&mut self, client_id: &ClientId, stream_id: &StreamId) {
|
||||
self.remove_assembly(client_id, stream_id);
|
||||
}
|
||||
|
||||
pub(super) fn invalidate_client(&mut self, client_id: &ClientId) {
|
||||
self.assemblies.remove(client_id);
|
||||
}
|
||||
|
||||
pub(super) fn should_ignore_chunk(
|
||||
&self,
|
||||
client_id: &ClientId,
|
||||
stream_id: &StreamId,
|
||||
seq_id: u64,
|
||||
segment_id: usize,
|
||||
) -> bool {
|
||||
self.assemblies.get(client_id).is_some_and(|assembly| {
|
||||
assembly.stream_id == *stream_id
|
||||
&& (seq_id < assembly.metadata.seq_id
|
||||
|| (seq_id == assembly.metadata.seq_id
|
||||
&& segment_id < assembly.next_segment_id))
|
||||
})
|
||||
}
|
||||
|
||||
fn remove_assembly(&mut self, client_id: &ClientId, stream_id: &StreamId) {
|
||||
if self
|
||||
.assemblies
|
||||
.get(client_id)
|
||||
.is_some_and(|assembly| &assembly.stream_id == stream_id)
|
||||
{
|
||||
self.assemblies.remove(client_id);
|
||||
}
|
||||
}
|
||||
|
||||
fn evict_assemblies_if_full(&mut self) {
|
||||
while self.assemblies.len() >= REMOTE_CONTROL_SEGMENT_ASSEMBLY_MAX_COUNT {
|
||||
let Some(client_id) = self
|
||||
.assemblies
|
||||
.iter()
|
||||
.min_by_key(|(_, assembly)| assembly.last_chunk_seen_at)
|
||||
.map(|(client_id, _)| client_id.clone())
|
||||
else {
|
||||
return;
|
||||
};
|
||||
self.assemblies.remove(&client_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum AssemblyUpdate {
|
||||
Pending,
|
||||
Ignore,
|
||||
Drop,
|
||||
Complete(JSONRPCMessage),
|
||||
}
|
||||
|
||||
impl ClientSegmentMetadata {
|
||||
fn from_envelope(envelope: &ClientEnvelope) -> Option<Self> {
|
||||
let ClientEvent::ClientMessageChunk {
|
||||
segment_count,
|
||||
message_size_bytes,
|
||||
..
|
||||
} = &envelope.event
|
||||
else {
|
||||
return None;
|
||||
};
|
||||
Some(Self {
|
||||
seq_id: envelope.seq_id?,
|
||||
segment_count: *segment_count,
|
||||
message_size_bytes: *message_size_bytes,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn split_server_envelope_for_transport(
|
||||
envelope: ServerEnvelope,
|
||||
) -> io::Result<Vec<ServerEnvelope>> {
|
||||
if !matches!(envelope.event, ServerEvent::ServerMessage { .. }) {
|
||||
return Ok(vec![envelope]);
|
||||
}
|
||||
|
||||
let envelope_size_bytes = serialized_len(&envelope)?;
|
||||
if envelope_size_bytes <= REMOTE_CONTROL_SEGMENT_MAX_BYTES {
|
||||
return Ok(vec![envelope]);
|
||||
}
|
||||
|
||||
let ServerEvent::ServerMessage { message } = envelope.event.clone() else {
|
||||
unreachable!("server message variant checked above");
|
||||
};
|
||||
let raw = serde_json::to_vec(message.as_ref()).map_err(io::Error::other)?;
|
||||
let message_size_bytes = raw.len();
|
||||
if message_size_bytes > REMOTE_CONTROL_REASSEMBLED_MAX_BYTES {
|
||||
warn!("dropping remote-control server envelope that exceeds reassembled size limit");
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let minimal_segment_count =
|
||||
usize::min(message_size_bytes.max(1), REMOTE_CONTROL_SEGMENT_COUNT_MAX);
|
||||
let minimal_chunk = &raw[..usize::min(raw.len(), 1)];
|
||||
if serialized_chunk_len(
|
||||
&envelope,
|
||||
/*segment_id*/ 0,
|
||||
minimal_segment_count,
|
||||
message_size_bytes,
|
||||
minimal_chunk,
|
||||
)? > REMOTE_CONTROL_SEGMENT_MAX_BYTES
|
||||
{
|
||||
warn!("dropping remote-control server envelope that cannot fit within segment size limit");
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut segment_count = usize::max(
|
||||
2,
|
||||
message_size_bytes.div_ceil(REMOTE_CONTROL_SEGMENT_TARGET_BYTES),
|
||||
);
|
||||
loop {
|
||||
let chunk_size = usize::max(1, message_size_bytes.div_ceil(segment_count));
|
||||
segment_count = message_size_bytes.div_ceil(chunk_size);
|
||||
let segments_fit = raw
|
||||
.chunks(chunk_size)
|
||||
.enumerate()
|
||||
.all(|(segment_id, chunk)| {
|
||||
serialized_chunk_len(
|
||||
&envelope,
|
||||
segment_id,
|
||||
segment_count,
|
||||
message_size_bytes,
|
||||
chunk,
|
||||
)
|
||||
.is_ok_and(|size| size <= REMOTE_CONTROL_SEGMENT_MAX_BYTES)
|
||||
});
|
||||
if segments_fit {
|
||||
return raw
|
||||
.chunks(chunk_size)
|
||||
.enumerate()
|
||||
.map(|(segment_id, chunk)| {
|
||||
build_chunk_envelope(
|
||||
&envelope,
|
||||
segment_id,
|
||||
segment_count,
|
||||
message_size_bytes,
|
||||
chunk,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
if chunk_size == 1 {
|
||||
warn!(
|
||||
"dropping remote-control server envelope that cannot fit within segment size limit"
|
||||
);
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let next_segment_count = segment_count + 1;
|
||||
let next_chunk_size = usize::max(1, message_size_bytes.div_ceil(next_segment_count));
|
||||
segment_count = if next_chunk_size == chunk_size {
|
||||
message_size_bytes
|
||||
} else {
|
||||
next_segment_count
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn serialized_chunk_len(
|
||||
envelope: &ServerEnvelope,
|
||||
segment_id: usize,
|
||||
segment_count: usize,
|
||||
message_size_bytes: usize,
|
||||
chunk: &[u8],
|
||||
) -> io::Result<usize> {
|
||||
serialized_len(&build_chunk_envelope(
|
||||
envelope,
|
||||
segment_id,
|
||||
segment_count,
|
||||
message_size_bytes,
|
||||
chunk,
|
||||
)?)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct CountingWriter {
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl Write for CountingWriter {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.len += buf.len();
|
||||
Ok(buf.len())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn serialized_len(value: &impl serde::Serialize) -> io::Result<usize> {
|
||||
let mut writer = CountingWriter::default();
|
||||
serde_json::to_writer(&mut writer, value).map_err(io::Error::other)?;
|
||||
Ok(writer.len)
|
||||
}
|
||||
|
||||
fn build_chunk_envelope(
|
||||
envelope: &ServerEnvelope,
|
||||
segment_id: usize,
|
||||
segment_count: usize,
|
||||
message_size_bytes: usize,
|
||||
chunk: &[u8],
|
||||
) -> io::Result<ServerEnvelope> {
|
||||
if segment_count > REMOTE_CONTROL_SEGMENT_COUNT_MAX {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"remote-control segment count exceeds maximum",
|
||||
));
|
||||
}
|
||||
Ok(ServerEnvelope {
|
||||
event: ServerEvent::ServerMessageChunk {
|
||||
segment_id,
|
||||
segment_count,
|
||||
message_size_bytes,
|
||||
message_chunk_base64: base64::engine::general_purpose::STANDARD.encode(chunk),
|
||||
},
|
||||
client_id: envelope.client_id.clone(),
|
||||
stream_id: envelope.stream_id.clone(),
|
||||
seq_id: envelope.seq_id,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,386 @@
|
||||
use super::protocol::ClientEnvelope;
|
||||
use super::protocol::ClientEvent;
|
||||
use super::protocol::ClientId;
|
||||
use super::protocol::ServerEnvelope;
|
||||
use super::protocol::ServerEvent;
|
||||
use super::protocol::StreamId;
|
||||
use super::segment::ClientSegmentObservation;
|
||||
use super::segment::ClientSegmentReassembler;
|
||||
use super::segment::REMOTE_CONTROL_SEGMENT_MAX_BYTES;
|
||||
use super::segment::split_server_envelope_for_transport;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use base64::Engine;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn reassembles_client_message_chunks() {
|
||||
let message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let raw = serde_json::to_vec(&message).expect("message should serialize");
|
||||
let split = raw.len() / 2;
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
let stream_id = Some(StreamId("stream-1".to_string()));
|
||||
let mut reassembler = ClientSegmentReassembler::default();
|
||||
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
stream_id.clone(),
|
||||
/*seq_id*/ 7,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
)),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
let reassembled = match reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
stream_id,
|
||||
/*seq_id*/ 7,
|
||||
/*segment_id*/ 1,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[split..],
|
||||
)) {
|
||||
ClientSegmentObservation::Forward(reassembled) => *reassembled,
|
||||
ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => {
|
||||
panic!("message should reassemble")
|
||||
}
|
||||
};
|
||||
assert_eq!(reassembled.client_id, client_id);
|
||||
assert_eq!(
|
||||
reassembled.stream_id,
|
||||
Some(StreamId("stream-1".to_string()))
|
||||
);
|
||||
assert_eq!(reassembled.seq_id, Some(7));
|
||||
assert_eq!(reassembled.cursor, None);
|
||||
match reassembled.event {
|
||||
ClientEvent::ClientMessage {
|
||||
message: reassembled_message,
|
||||
} => assert_eq!(reassembled_message, message),
|
||||
other => panic!("expected client message, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn splits_large_server_messages_into_wire_chunks() {
|
||||
let envelope = ServerEnvelope {
|
||||
event: ServerEvent::ServerMessage {
|
||||
message: Box::new(OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::ConfigWarning(ConfigWarningNotification {
|
||||
summary: "x".repeat(REMOTE_CONTROL_SEGMENT_MAX_BYTES),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
}),
|
||||
)),
|
||||
},
|
||||
client_id: ClientId("client-1".to_string()),
|
||||
stream_id: StreamId("stream-1".to_string()),
|
||||
seq_id: 9,
|
||||
};
|
||||
|
||||
let segments = split_server_envelope_for_transport(envelope).expect("split should succeed");
|
||||
|
||||
assert!(segments.len() > 1);
|
||||
assert!(
|
||||
segments
|
||||
.iter()
|
||||
.all(|segment| matches!(segment.event, ServerEvent::ServerMessageChunk { .. }))
|
||||
);
|
||||
assert!(segments.iter().all(|segment| segment.seq_id == 9));
|
||||
assert!(segments.iter().all(|segment| {
|
||||
serde_json::to_vec(segment)
|
||||
.expect("segment should serialize")
|
||||
.len()
|
||||
<= REMOTE_CONTROL_SEGMENT_MAX_BYTES
|
||||
}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalidates_incomplete_stream_assemblies() {
|
||||
let message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let raw = serde_json::to_vec(&message).expect("message should serialize");
|
||||
let split = raw.len() / 2;
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
let stream_id = StreamId("stream-1".to_string());
|
||||
let mut reassembler = ClientSegmentReassembler::default();
|
||||
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
Some(stream_id.clone()),
|
||||
/*seq_id*/ 7,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
)),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
reassembler.invalidate_stream(&client_id, &stream_id);
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id,
|
||||
Some(stream_id),
|
||||
/*seq_id*/ 7,
|
||||
/*segment_id*/ 1,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[split..],
|
||||
)),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resets_incomplete_client_assembly_when_stream_changes() {
|
||||
let message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let raw = serde_json::to_vec(&message).expect("message should serialize");
|
||||
let split = raw.len() / 2;
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
let first_stream_id = StreamId("stream-1".to_string());
|
||||
let second_stream_id = StreamId("stream-2".to_string());
|
||||
let mut reassembler = ClientSegmentReassembler::default();
|
||||
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
Some(first_stream_id.clone()),
|
||||
/*seq_id*/ 7,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
)),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
Some(second_stream_id.clone()),
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
)),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
let reassembled = match reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
Some(second_stream_id),
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 1,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[split..],
|
||||
)) {
|
||||
ClientSegmentObservation::Forward(reassembled) => *reassembled,
|
||||
ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => {
|
||||
panic!("replacement stream should reassemble")
|
||||
}
|
||||
};
|
||||
assert_eq!(
|
||||
reassembled.stream_id,
|
||||
Some(StreamId("stream-2".to_string()))
|
||||
);
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id,
|
||||
Some(first_stream_id),
|
||||
/*seq_id*/ 7,
|
||||
/*segment_id*/ 1,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[split..],
|
||||
)),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignores_stale_chunks_without_dropping_newer_assembly() {
|
||||
let message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let raw = serde_json::to_vec(&message).expect("message should serialize");
|
||||
let split = raw.len() / 2;
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
let stream_id = Some(StreamId("stream-1".to_string()));
|
||||
let mut reassembler = ClientSegmentReassembler::default();
|
||||
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
stream_id.clone(),
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
)),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
stream_id.clone(),
|
||||
/*seq_id*/ 7,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
)),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id,
|
||||
stream_id,
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 1,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[split..],
|
||||
)),
|
||||
ClientSegmentObservation::Forward(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignores_invalid_stale_chunks_without_dropping_newer_assembly() {
|
||||
let message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let raw = serde_json::to_vec(&message).expect("message should serialize");
|
||||
let split = raw.len() / 2;
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
let stream_id = Some(StreamId("stream-1".to_string()));
|
||||
let mut reassembler = ClientSegmentReassembler::default();
|
||||
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
stream_id.clone(),
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
)),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
stream_id.clone(),
|
||||
/*seq_id*/ 7,
|
||||
/*segment_id*/ 1,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
b"",
|
||||
)),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id,
|
||||
stream_id,
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 1,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[split..],
|
||||
)),
|
||||
ClientSegmentObservation::Forward(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignores_invalid_duplicate_chunks_without_dropping_current_assembly() {
|
||||
let message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let raw = serde_json::to_vec(&message).expect("message should serialize");
|
||||
let split = raw.len() / 2;
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
let stream_id = Some(StreamId("stream-1".to_string()));
|
||||
let mut reassembler = ClientSegmentReassembler::default();
|
||||
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
stream_id.clone(),
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
)),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id.clone(),
|
||||
stream_id.clone(),
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
b"",
|
||||
)),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
assert!(matches!(
|
||||
reassembler.observe(chunk_envelope(
|
||||
client_id,
|
||||
stream_id,
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 1,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[split..],
|
||||
)),
|
||||
ClientSegmentObservation::Forward(_)
|
||||
));
|
||||
}
|
||||
|
||||
fn chunk_envelope(
|
||||
client_id: ClientId,
|
||||
stream_id: Option<StreamId>,
|
||||
seq_id: u64,
|
||||
segment_id: usize,
|
||||
segment_count: usize,
|
||||
message_size_bytes: usize,
|
||||
chunk: &[u8],
|
||||
) -> ClientEnvelope {
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::ClientMessageChunk {
|
||||
segment_id,
|
||||
segment_count,
|
||||
message_size_bytes,
|
||||
message_chunk_base64: base64::engine::general_purpose::STANDARD.encode(chunk),
|
||||
},
|
||||
client_id,
|
||||
stream_id,
|
||||
seq_id: Some(seq_id),
|
||||
cursor: None,
|
||||
}
|
||||
}
|
||||
@@ -831,7 +831,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_backend_acks() {
|
||||
send_client_event(
|
||||
&mut first_websocket,
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::Ack,
|
||||
event: ClientEvent::Ack { segment_id: None },
|
||||
client_id: client_id.clone(),
|
||||
stream_id: Some(stream_id),
|
||||
seq_id: Some(1),
|
||||
|
||||
@@ -15,6 +15,10 @@ use super::protocol::ClientId;
|
||||
use super::protocol::RemoteControlTarget;
|
||||
use super::protocol::ServerEnvelope;
|
||||
use super::protocol::StreamId;
|
||||
use super::segment::ClientSegmentObservation;
|
||||
use super::segment::ClientSegmentReassembler;
|
||||
use super::segment::REMOTE_CONTROL_SEGMENT_MAX_BYTES;
|
||||
use super::segment::split_server_envelope_for_transport;
|
||||
use axum::http::HeaderValue;
|
||||
use base64::Engine;
|
||||
use codex_app_server_protocol::RemoteControlConnectionStatus;
|
||||
@@ -49,7 +53,7 @@ use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2";
|
||||
pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "3";
|
||||
pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
|
||||
const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor";
|
||||
const REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL: std::time::Duration =
|
||||
@@ -85,17 +89,29 @@ impl BoundedOutboundBuffer {
|
||||
self.used_tx.send_modify(|used| *used += 1);
|
||||
}
|
||||
|
||||
fn ack(&mut self, client_id: &ClientId, stream_id: &StreamId, acked_seq_id: u64) {
|
||||
fn ack(
|
||||
&mut self,
|
||||
client_id: &ClientId,
|
||||
stream_id: &StreamId,
|
||||
acked_seq_id: u64,
|
||||
acked_segment_id: Option<usize>,
|
||||
) {
|
||||
let key = (client_id.clone(), stream_id.clone());
|
||||
let Some(buffer) = self.buffer_by_stream.get_mut(&key) else {
|
||||
return;
|
||||
};
|
||||
while let Some(server_envelope) = buffer.front()
|
||||
&& server_envelope.seq_id <= acked_seq_id
|
||||
{
|
||||
buffer.pop_front();
|
||||
self.used_tx.send_modify(|used| *used -= 1);
|
||||
}
|
||||
let acked_cursor = (acked_seq_id, acked_segment_id.unwrap_or(usize::MAX));
|
||||
buffer.retain(|server_envelope| {
|
||||
let envelope_cursor = (
|
||||
server_envelope.seq_id,
|
||||
server_envelope.event.segment_id().unwrap_or_default(),
|
||||
);
|
||||
let is_acked = envelope_cursor <= acked_cursor;
|
||||
if is_acked {
|
||||
self.used_tx.send_modify(|used| *used -= 1);
|
||||
}
|
||||
!is_acked
|
||||
});
|
||||
if buffer.is_empty() {
|
||||
self.buffer_by_stream.remove(&key);
|
||||
}
|
||||
@@ -112,6 +128,88 @@ struct WebsocketState {
|
||||
outbound_buffer: BoundedOutboundBuffer,
|
||||
subscribe_cursor: Option<String>,
|
||||
next_seq_id_by_stream: HashMap<(ClientId, StreamId), u64>,
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap<(ClientId, Option<StreamId>), u64>,
|
||||
client_segment_reassembler: ClientSegmentReassembler,
|
||||
}
|
||||
|
||||
impl WebsocketState {
|
||||
fn observe_client_message(
|
||||
&mut self,
|
||||
client_envelope: ClientEnvelope,
|
||||
wire_size_bytes: usize,
|
||||
) -> ClientSegmentObservation {
|
||||
let client_message_key = Self::client_message_key(&client_envelope);
|
||||
if let Some((key, seq_id)) = client_message_key.as_ref()
|
||||
&& self
|
||||
.last_completed_client_chunk_seq_id_by_stream
|
||||
.get(key)
|
||||
.is_some_and(|last_seq_id| last_seq_id >= seq_id)
|
||||
{
|
||||
return ClientSegmentObservation::Dropped;
|
||||
}
|
||||
if let (
|
||||
Some((_, seq_id)),
|
||||
Some(stream_id),
|
||||
ClientEvent::ClientMessageChunk { segment_id, .. },
|
||||
) = (
|
||||
client_message_key.as_ref(),
|
||||
client_envelope.stream_id.as_ref(),
|
||||
&client_envelope.event,
|
||||
) && self.client_segment_reassembler.should_ignore_chunk(
|
||||
&client_envelope.client_id,
|
||||
stream_id,
|
||||
*seq_id,
|
||||
*segment_id,
|
||||
) {
|
||||
return ClientSegmentObservation::Dropped;
|
||||
}
|
||||
if client_message_key.is_some() && wire_size_bytes > REMOTE_CONTROL_SEGMENT_MAX_BYTES {
|
||||
warn!(
|
||||
client_id = client_envelope.client_id.0.as_str(),
|
||||
"dropping oversized segmented remote-control client envelope"
|
||||
);
|
||||
if let Some(stream_id) = client_envelope.stream_id.as_ref() {
|
||||
self.client_segment_reassembler
|
||||
.invalidate_stream(&client_envelope.client_id, stream_id);
|
||||
}
|
||||
return ClientSegmentObservation::Dropped;
|
||||
}
|
||||
|
||||
let observation = self.client_segment_reassembler.observe(client_envelope);
|
||||
if matches!(observation, ClientSegmentObservation::Forward(_))
|
||||
&& let Some((key, seq_id)) = client_message_key
|
||||
{
|
||||
self.last_completed_client_chunk_seq_id_by_stream
|
||||
.insert(key, seq_id);
|
||||
}
|
||||
observation
|
||||
}
|
||||
|
||||
fn invalidate_client_message_stream(&mut self, client_id: &ClientId, stream_id: &StreamId) {
|
||||
self.last_completed_client_chunk_seq_id_by_stream
|
||||
.remove(&(client_id.clone(), Some(stream_id.clone())));
|
||||
}
|
||||
|
||||
fn invalidate_client_message_client(&mut self, client_id: &ClientId) {
|
||||
self.last_completed_client_chunk_seq_id_by_stream
|
||||
.retain(|(cursor_client_id, _), _| cursor_client_id != client_id);
|
||||
}
|
||||
|
||||
fn client_message_key(
|
||||
client_envelope: &ClientEnvelope,
|
||||
) -> Option<((ClientId, Option<StreamId>), u64)> {
|
||||
let seq_id = match (&client_envelope.event, client_envelope.seq_id) {
|
||||
(ClientEvent::ClientMessageChunk { .. }, Some(seq_id)) => seq_id,
|
||||
_ => return None,
|
||||
};
|
||||
Some((
|
||||
(
|
||||
client_envelope.client_id.clone(),
|
||||
client_envelope.stream_id.clone(),
|
||||
),
|
||||
seq_id,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct RemoteControlWebsocket {
|
||||
@@ -231,6 +329,8 @@ impl RemoteControlWebsocket {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
})),
|
||||
server_event_rx: Arc::new(Mutex::new(server_event_rx)),
|
||||
used_rx,
|
||||
@@ -556,7 +656,7 @@ impl RemoteControlWebsocket {
|
||||
}
|
||||
}
|
||||
};
|
||||
let (payload, write_complete_tx) = {
|
||||
let (payloads, write_complete_tx) = {
|
||||
let mut state = state.lock().await;
|
||||
let seq_key = (
|
||||
queued_server_envelope.client_id.clone(),
|
||||
@@ -573,29 +673,42 @@ impl RemoteControlWebsocket {
|
||||
seq_id,
|
||||
stream_id: queued_server_envelope.stream_id,
|
||||
};
|
||||
let payload = match serde_json::to_string(&server_envelope) {
|
||||
Ok(payload) => payload,
|
||||
let server_envelopes = match split_server_envelope_for_transport(server_envelope) {
|
||||
Ok(server_envelopes) => server_envelopes,
|
||||
Err(err) => {
|
||||
error!("failed to serialize remote-control server event: {err}");
|
||||
error!("failed to split remote-control server event: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let mut payloads = Vec::with_capacity(server_envelopes.len());
|
||||
for server_envelope in server_envelopes {
|
||||
let payload = match serde_json::to_string(&server_envelope) {
|
||||
Ok(payload) => payload,
|
||||
Err(err) => {
|
||||
error!("failed to serialize remote-control server event: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
state.outbound_buffer.insert(&server_envelope);
|
||||
payloads.push(payload);
|
||||
}
|
||||
state
|
||||
.next_seq_id_by_stream
|
||||
.insert(seq_key, seq_id.saturating_add(1));
|
||||
state.outbound_buffer.insert(&server_envelope);
|
||||
|
||||
(payload, queued_server_envelope.write_complete_tx)
|
||||
(payloads, queued_server_envelope.write_complete_tx)
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
_ = shutdown_token.cancelled() => return Ok(()),
|
||||
send_result = websocket_writer.send(tungstenite::Message::Text(payload.into())) => {
|
||||
if let Err(err) = send_result {
|
||||
return Err(io::Error::other(err));
|
||||
for payload in payloads {
|
||||
tokio::select! {
|
||||
_ = shutdown_token.cancelled() => return Ok(()),
|
||||
send_result = websocket_writer.send(tungstenite::Message::Text(payload.into())) => {
|
||||
if let Err(err) = send_result {
|
||||
return Err(io::Error::other(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
if let Some(write_complete_tx) = write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
@@ -657,11 +770,30 @@ impl RemoteControlWebsocket {
|
||||
if client_tracker.close_client(&client_key).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
state
|
||||
.lock()
|
||||
.await
|
||||
.client_segment_reassembler
|
||||
.invalidate_stream(&client_key.0, &client_key.1);
|
||||
state
|
||||
.lock()
|
||||
.await
|
||||
.invalidate_client_message_stream(&client_key.0, &client_key.1);
|
||||
continue;
|
||||
}
|
||||
_ = idle_sweep_interval.tick() => {
|
||||
if client_tracker.close_expired_clients().await.is_err() {
|
||||
return Ok(());
|
||||
match client_tracker.close_expired_clients().await {
|
||||
Ok(client_keys) => {
|
||||
let mut websocket_state = state.lock().await;
|
||||
for (client_id, stream_id) in client_keys {
|
||||
websocket_state
|
||||
.client_segment_reassembler
|
||||
.invalidate_stream(&client_id, &stream_id);
|
||||
websocket_state
|
||||
.invalidate_client_message_stream(&client_id, &stream_id);
|
||||
}
|
||||
}
|
||||
Err(_) => return Ok(()),
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -672,10 +804,11 @@ impl RemoteControlWebsocket {
|
||||
}
|
||||
}
|
||||
};
|
||||
let client_envelope = match incoming_message {
|
||||
let (client_envelope, wire_size_bytes) = match incoming_message {
|
||||
Ok(tungstenite::Message::Text(text)) => {
|
||||
let wire_size_bytes = text.len();
|
||||
match serde_json::from_str::<ClientEnvelope>(&text) {
|
||||
Ok(client_envelope) => client_envelope,
|
||||
Ok(client_envelope) => (client_envelope, wire_size_bytes),
|
||||
Err(err) => {
|
||||
warn!("failed to deserialize remote-control client event: {err}");
|
||||
continue;
|
||||
@@ -707,12 +840,21 @@ impl RemoteControlWebsocket {
|
||||
}
|
||||
};
|
||||
|
||||
let observation = {
|
||||
let mut websocket_state = state.lock().await;
|
||||
websocket_state.observe_client_message(client_envelope, wire_size_bytes)
|
||||
};
|
||||
let client_envelope = match observation {
|
||||
ClientSegmentObservation::Forward(client_envelope) => *client_envelope,
|
||||
ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => continue,
|
||||
};
|
||||
|
||||
{
|
||||
let mut websocket_state = state.lock().await;
|
||||
if let Some(cursor) = client_envelope.cursor.as_deref() {
|
||||
websocket_state.subscribe_cursor = Some(cursor.to_string());
|
||||
}
|
||||
if let ClientEvent::Ack = &client_envelope.event
|
||||
if let ClientEvent::Ack { segment_id } = &client_envelope.event
|
||||
&& let Some(acked_seq_id) = client_envelope.seq_id
|
||||
&& let Some(stream_id) = client_envelope.stream_id.as_ref()
|
||||
{
|
||||
@@ -720,10 +862,18 @@ impl RemoteControlWebsocket {
|
||||
&client_envelope.client_id,
|
||||
stream_id,
|
||||
acked_seq_id,
|
||||
*segment_id,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let closed_client =
|
||||
matches!(&client_envelope.event, ClientEvent::ClientClosed).then(|| {
|
||||
(
|
||||
client_envelope.client_id.clone(),
|
||||
client_envelope.stream_id.clone(),
|
||||
)
|
||||
});
|
||||
if client_tracker
|
||||
.handle_message(client_envelope)
|
||||
.await
|
||||
@@ -731,6 +881,20 @@ impl RemoteControlWebsocket {
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
if let Some((client_id, stream_id)) = closed_client {
|
||||
let mut websocket_state = state.lock().await;
|
||||
if let Some(stream_id) = stream_id {
|
||||
websocket_state
|
||||
.client_segment_reassembler
|
||||
.invalidate_stream(&client_id, &stream_id);
|
||||
websocket_state.invalidate_client_message_stream(&client_id, &stream_id);
|
||||
} else {
|
||||
websocket_state
|
||||
.client_segment_reassembler
|
||||
.invalidate_client(&client_id);
|
||||
websocket_state.invalidate_client_message_client(&client_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1052,6 +1216,8 @@ mod tests {
|
||||
use chrono::Utc;
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_config::types::AuthCredentialsStoreMode;
|
||||
use codex_core::test_support::auth_manager_from_auth;
|
||||
@@ -1603,6 +1769,8 @@ mod tests {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
}));
|
||||
let (_server_event_tx, server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY);
|
||||
let server_event_rx = Arc::new(Mutex::new(server_event_rx));
|
||||
@@ -1639,6 +1807,8 @@ mod tests {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
}));
|
||||
let (server_event_tx, server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY);
|
||||
let server_event_rx = Arc::new(Mutex::new(server_event_rx));
|
||||
@@ -1716,6 +1886,8 @@ mod tests {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
}));
|
||||
let (server_event_tx, _server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY);
|
||||
let (transport_event_tx, _transport_event_rx) =
|
||||
@@ -1771,7 +1943,9 @@ mod tests {
|
||||
"first-client-new-stream",
|
||||
));
|
||||
|
||||
outbound_buffer.ack(&client_1, &stream_1, /*acked_seq_id*/ 3);
|
||||
outbound_buffer.ack(
|
||||
&client_1, &stream_1, /*acked_seq_id*/ 3, /*acked_segment_id*/ None,
|
||||
);
|
||||
|
||||
let mut retained = outbound_buffer
|
||||
.server_envelopes()
|
||||
@@ -1814,7 +1988,9 @@ mod tests {
|
||||
&client_2, "stream-1", /*seq_id*/ 3, "second",
|
||||
));
|
||||
|
||||
outbound_buffer.ack(&client_1, &stream_1, /*acked_seq_id*/ 1);
|
||||
outbound_buffer.ack(
|
||||
&client_1, &stream_1, /*acked_seq_id*/ 1, /*acked_segment_id*/ None,
|
||||
);
|
||||
|
||||
let mut retained = outbound_buffer
|
||||
.server_envelopes()
|
||||
@@ -1834,6 +2010,390 @@ mod tests {
|
||||
assert_eq!(*used_rx.borrow(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn outbound_buffer_advances_segmented_acks_by_wire_cursor() {
|
||||
let (mut outbound_buffer, used_rx) = BoundedOutboundBuffer::new();
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
let stream_id = StreamId("stream-1".to_string());
|
||||
|
||||
outbound_buffer.insert(&server_chunk_envelope(
|
||||
&client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
|
||||
));
|
||||
outbound_buffer.insert(&server_chunk_envelope(
|
||||
&client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 1,
|
||||
));
|
||||
|
||||
outbound_buffer.ack(
|
||||
&client_id,
|
||||
&stream_id,
|
||||
/*acked_seq_id*/ 4,
|
||||
/*acked_segment_id*/ Some(1),
|
||||
);
|
||||
|
||||
let retained = outbound_buffer
|
||||
.server_envelopes()
|
||||
.map(|server_envelope| server_envelope.event.segment_id())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(retained, Vec::<Option<usize>>::new());
|
||||
assert_eq!(*used_rx.borrow(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn outbound_buffer_treats_segmentless_acks_as_seq_level_acks() {
|
||||
let (mut outbound_buffer, used_rx) = BoundedOutboundBuffer::new();
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
let stream_id = StreamId("stream-1".to_string());
|
||||
|
||||
outbound_buffer.insert(&server_chunk_envelope(
|
||||
&client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
|
||||
));
|
||||
outbound_buffer.insert(&server_chunk_envelope(
|
||||
&client_id, "stream-1", /*seq_id*/ 4, /*segment_id*/ 1,
|
||||
));
|
||||
|
||||
outbound_buffer.ack(
|
||||
&client_id, &stream_id, /*acked_seq_id*/ 4, /*acked_segment_id*/ None,
|
||||
);
|
||||
|
||||
let retained = outbound_buffer
|
||||
.server_envelopes()
|
||||
.map(|server_envelope| server_envelope.event.segment_id())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(retained, Vec::<Option<usize>>::new());
|
||||
assert_eq!(*used_rx.borrow(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_state_drops_duplicate_client_chunks_while_pending() {
|
||||
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
|
||||
let mut state = WebsocketState {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
};
|
||||
let first_chunk = client_chunk_envelope(
|
||||
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
|
||||
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
|
||||
);
|
||||
let second_chunk = client_chunk_envelope(
|
||||
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 1,
|
||||
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"y",
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, first_chunk.clone()),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, first_chunk.clone()),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, second_chunk),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, first_chunk),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_state_drops_replayed_client_chunks_after_completion() {
|
||||
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
|
||||
let mut state = WebsocketState {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
};
|
||||
let message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let raw = serde_json::to_vec(&message).expect("message should serialize");
|
||||
let split = raw.len() / 2;
|
||||
let first_chunk = client_chunk_envelope(
|
||||
"client-1",
|
||||
"stream-1",
|
||||
/*seq_id*/ 4,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
);
|
||||
let second_chunk = client_chunk_envelope(
|
||||
"client-1",
|
||||
"stream-1",
|
||||
/*seq_id*/ 4,
|
||||
/*segment_id*/ 1,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[split..],
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, first_chunk.clone()),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, second_chunk),
|
||||
ClientSegmentObservation::Forward(_)
|
||||
));
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, first_chunk),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_state_allows_replay_after_rejected_out_of_order_chunk() {
|
||||
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
|
||||
let mut state = WebsocketState {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
};
|
||||
let first_chunk = client_chunk_envelope(
|
||||
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
|
||||
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
|
||||
);
|
||||
let second_chunk = client_chunk_envelope(
|
||||
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 1,
|
||||
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"y",
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, second_chunk),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, first_chunk),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_state_allows_replay_after_later_chunk_drops() {
|
||||
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
|
||||
let mut state = WebsocketState {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
};
|
||||
let first_chunk = client_chunk_envelope(
|
||||
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
|
||||
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
|
||||
);
|
||||
let invalid_second_chunk = client_chunk_envelope(
|
||||
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 1,
|
||||
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"",
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, first_chunk.clone()),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, invalid_second_chunk),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, first_chunk),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_state_drops_oversized_client_chunk_frames() {
|
||||
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
|
||||
let mut state = WebsocketState {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
};
|
||||
let chunk = client_chunk_envelope(
|
||||
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
|
||||
/*segment_count*/ 1, /*message_size_bytes*/ 1, b"x",
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
state.observe_client_message(chunk, REMOTE_CONTROL_SEGMENT_MAX_BYTES + 1),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_state_ignores_oversized_stale_chunks_without_dropping_newer_assembly() {
|
||||
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
|
||||
let mut state = WebsocketState {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
};
|
||||
let message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let raw = serde_json::to_vec(&message).expect("message should serialize");
|
||||
let split = raw.len() / 2;
|
||||
let first_newer_chunk = client_chunk_envelope(
|
||||
"client-1",
|
||||
"stream-1",
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
);
|
||||
let oversized_stale_chunk = client_chunk_envelope(
|
||||
"client-1",
|
||||
"stream-1",
|
||||
/*seq_id*/ 7,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
);
|
||||
let second_newer_chunk = client_chunk_envelope(
|
||||
"client-1",
|
||||
"stream-1",
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 1,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[split..],
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, first_newer_chunk),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
assert!(matches!(
|
||||
state.observe_client_message(
|
||||
oversized_stale_chunk,
|
||||
REMOTE_CONTROL_SEGMENT_MAX_BYTES + 1,
|
||||
),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, second_newer_chunk),
|
||||
ClientSegmentObservation::Forward(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_state_ignores_oversized_duplicate_chunks_without_dropping_current_assembly() {
|
||||
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
|
||||
let mut state = WebsocketState {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
};
|
||||
let message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let raw = serde_json::to_vec(&message).expect("message should serialize");
|
||||
let split = raw.len() / 2;
|
||||
let first_chunk = client_chunk_envelope(
|
||||
"client-1",
|
||||
"stream-1",
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
);
|
||||
let oversized_duplicate_chunk = client_chunk_envelope(
|
||||
"client-1",
|
||||
"stream-1",
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 0,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[..split],
|
||||
);
|
||||
let second_chunk = client_chunk_envelope(
|
||||
"client-1",
|
||||
"stream-1",
|
||||
/*seq_id*/ 8,
|
||||
/*segment_id*/ 1,
|
||||
/*segment_count*/ 2,
|
||||
raw.len(),
|
||||
&raw[split..],
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, first_chunk),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
assert!(matches!(
|
||||
state.observe_client_message(
|
||||
oversized_duplicate_chunk,
|
||||
REMOTE_CONTROL_SEGMENT_MAX_BYTES + 1,
|
||||
),
|
||||
ClientSegmentObservation::Dropped
|
||||
));
|
||||
assert!(matches!(
|
||||
observe_client_message(&mut state, second_chunk),
|
||||
ClientSegmentObservation::Forward(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_state_clears_chunk_cursor_when_stream_is_invalidated() {
|
||||
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
|
||||
let mut state = WebsocketState {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id_by_stream: HashMap::new(),
|
||||
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
|
||||
client_segment_reassembler: ClientSegmentReassembler::default(),
|
||||
};
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
let stream_id = StreamId("stream-1".to_string());
|
||||
|
||||
assert!(matches!(
|
||||
observe_client_message(
|
||||
&mut state,
|
||||
client_chunk_envelope(
|
||||
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
|
||||
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
|
||||
)
|
||||
),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
state.invalidate_client_message_stream(&client_id, &stream_id);
|
||||
state
|
||||
.client_segment_reassembler
|
||||
.invalidate_stream(&client_id, &stream_id);
|
||||
|
||||
assert!(matches!(
|
||||
observe_client_message(
|
||||
&mut state,
|
||||
client_chunk_envelope(
|
||||
"client-1", "stream-1", /*seq_id*/ 1, /*segment_id*/ 0,
|
||||
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
|
||||
)
|
||||
),
|
||||
ClientSegmentObservation::Pending
|
||||
));
|
||||
}
|
||||
|
||||
fn server_envelope(
|
||||
client_id: &ClientId,
|
||||
stream_id: &str,
|
||||
@@ -1857,6 +2417,58 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn server_chunk_envelope(
|
||||
client_id: &ClientId,
|
||||
stream_id: &str,
|
||||
seq_id: u64,
|
||||
segment_id: usize,
|
||||
) -> ServerEnvelope {
|
||||
ServerEnvelope {
|
||||
event: ServerEvent::ServerMessageChunk {
|
||||
segment_id,
|
||||
segment_count: 2,
|
||||
message_size_bytes: 2,
|
||||
message_chunk_base64: String::new(),
|
||||
},
|
||||
client_id: client_id.clone(),
|
||||
stream_id: StreamId(stream_id.to_string()),
|
||||
seq_id,
|
||||
}
|
||||
}
|
||||
|
||||
fn client_chunk_envelope(
|
||||
client_id: &str,
|
||||
stream_id: &str,
|
||||
seq_id: u64,
|
||||
segment_id: usize,
|
||||
segment_count: usize,
|
||||
message_size_bytes: usize,
|
||||
chunk: &[u8],
|
||||
) -> ClientEnvelope {
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::ClientMessageChunk {
|
||||
segment_id,
|
||||
segment_count,
|
||||
message_size_bytes,
|
||||
message_chunk_base64: base64::engine::general_purpose::STANDARD.encode(chunk),
|
||||
},
|
||||
client_id: ClientId(client_id.to_string()),
|
||||
stream_id: Some(StreamId(stream_id.to_string())),
|
||||
seq_id: Some(seq_id),
|
||||
cursor: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn observe_client_message(
|
||||
state: &mut WebsocketState,
|
||||
envelope: ClientEnvelope,
|
||||
) -> ClientSegmentObservation {
|
||||
let wire_size_bytes = serde_json::to_vec(&envelope)
|
||||
.expect("client envelope should serialize")
|
||||
.len();
|
||||
state.observe_client_message(envelope, wire_size_bytes)
|
||||
}
|
||||
|
||||
async fn accept_http_request(listener: &TcpListener) -> (TcpStream, String) {
|
||||
let (stream, _) = timeout(TEST_HTTP_ACCEPT_TIMEOUT, listener.accept())
|
||||
.await
|
||||
|
||||
Reference in New Issue
Block a user