Compare commits

...

1 Commits

Author SHA1 Message Date
Ruslan Nigmatullin
0d15f90832 remote-control: harden stream sequencing 2026-05-07 11:09:51 -07:00
5 changed files with 263 additions and 82 deletions

View File

@@ -26,6 +26,11 @@ pub(crate) const REMOTE_CONTROL_IDLE_SWEEP_INTERVAL: Duration = Duration::from_s
#[derive(Debug)]
pub(crate) struct Stopped;
/// What `ClientTracker::handle_message` reports back when it returns `Ok`.
///
/// Lets the caller tell apart envelopes that left the stream alone from envelopes that caused
/// the tracker to close a stream, so per-stream state kept outside the tracker can be dropped.
pub(crate) enum ClientMessageOutcome {
/// The envelope was accepted, ignored, or dropped as a duplicate; no stream closure is
/// reported to the caller.
Handled,
/// `close_client` was invoked for this `(client_id, stream_id)` pair, either because the
/// client sent `ClientClosed` or because an inbound `seq_id` skipped past the next expected
/// value.
StreamClosed(ClientId, StreamId),
}
struct ClientState {
connection_id: ConnectionId,
disconnect_token: CancellationToken,
@@ -86,7 +91,7 @@ impl ClientTracker {
pub(crate) async fn handle_message(
&mut self,
client_envelope: ClientEnvelope,
) -> Result<(), Stopped> {
) -> Result<ClientMessageOutcome, Stopped> {
let ClientEnvelope {
client_id,
event,
@@ -117,19 +122,28 @@ impl ClientTracker {
}),
};
if stream_id.0.is_empty() {
return Ok(());
return Ok(ClientMessageOutcome::Handled);
}
let client_key = (client_id.clone(), stream_id.clone());
match event {
ClientEvent::ClientMessage { message } => {
if let Some(seq_id) = seq_id
&& let Some(client) = self.clients.get(&client_key)
&& client
.last_inbound_seq_id
.is_some_and(|last_seq_id| last_seq_id >= seq_id)
&& !is_initialize
{
return Ok(());
if client
.last_inbound_seq_id
.is_some_and(|last_seq_id| last_seq_id >= seq_id)
{
return Ok(ClientMessageOutcome::Handled);
}
if client
.last_inbound_seq_id
.is_some_and(|last_seq_id| seq_id > last_seq_id.saturating_add(1))
{
self.close_client(&client_key).await?;
return Ok(ClientMessageOutcome::StreamClosed(client_id, stream_id));
}
}
if is_initialize && self.clients.contains_key(&client_key) {
@@ -148,11 +162,11 @@ impl ClientTracker {
message,
})
.await?;
return Ok(());
return Ok(ClientMessageOutcome::Handled);
}
if !is_initialize {
return Ok(());
return Ok(ClientMessageOutcome::Handled);
}
let connection_id = next_connection_id();
@@ -193,14 +207,17 @@ impl ClientTracker {
connection_id,
message,
})
.await
.await?;
Ok(ClientMessageOutcome::Handled)
}
ClientEvent::ClientMessageChunk { .. } | ClientEvent::Ack { .. } => {
Ok(ClientMessageOutcome::Handled)
}
ClientEvent::ClientMessageChunk { .. } | ClientEvent::Ack { .. } => Ok(()),
ClientEvent::Ping => {
if let Some(client) = self.clients.get_mut(&client_key) {
client.last_activity_at = Instant::now();
let _ = client.status_tx.send(PongStatus::Active);
return Ok(());
return Ok(ClientMessageOutcome::Handled);
}
let server_event_tx = self.server_event_tx.clone();
@@ -215,9 +232,12 @@ impl ClientTracker {
};
let _ = server_event_tx.send(server_envelope).await;
});
Ok(())
Ok(ClientMessageOutcome::Handled)
}
ClientEvent::ClientClosed => {
self.close_client(&client_key).await?;
Ok(ClientMessageOutcome::StreamClosed(client_id, stream_id))
}
ClientEvent::ClientClosed => self.close_client(&client_key).await,
}
}
@@ -523,6 +543,59 @@ mod tests {
assert_ne!(first_connection_id, second_connection_id);
}
/// A client message whose `seq_id` skips ahead on one stream must close that stream alone:
/// the outcome names the gapped stream, the sibling stream stays registered, and a
/// `ConnectionClosed` transport event is emitted.
#[tokio::test]
async fn client_message_seq_gap_closes_only_affected_stream() {
    let shutdown_token = CancellationToken::new();
    let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
    let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
    let mut client_tracker =
        ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token);

    // Open two streams for the same client, draining the two transport events each
    // initialize produces so the channel only holds the close event later on.
    for stream_id in ["stream-1", "stream-2"] {
        let initialize = initialize_envelope_with_stream_id("client-1", Some(stream_id));
        client_tracker
            .handle_message(initialize)
            .await
            .expect("initialize should open client");
        let _ = transport_event_rx.recv().await.expect("open event");
        let _ = transport_event_rx.recv().await.expect("initialize event");
    }

    // Send a notification on stream-1 carrying a seq_id the tracker treats as a gap.
    let gapped_notification = ClientEnvelope {
        event: ClientEvent::ClientMessage {
            message: JSONRPCMessage::Notification(
                codex_app_server_protocol::JSONRPCNotification {
                    method: "initialized".to_string(),
                    params: None,
                },
            ),
        },
        client_id: ClientId("client-1".to_string()),
        stream_id: Some(StreamId("stream-1".to_string())),
        seq_id: Some(2),
        cursor: None,
    };
    let outcome = client_tracker
        .handle_message(gapped_notification)
        .await
        .expect("gap should be handled");

    let reported_gapped_stream = matches!(
        outcome,
        ClientMessageOutcome::StreamClosed(closed_client, closed_stream)
            if closed_client == ClientId("client-1".to_string())
                && closed_stream == StreamId("stream-1".to_string())
    );
    assert!(reported_gapped_stream);

    // The sibling stream of the same client must still be tracked.
    let sibling_key = (
        ClientId("client-1".to_string()),
        StreamId("stream-2".to_string()),
    );
    assert!(client_tracker.clients.contains_key(&sibling_key));

    match transport_event_rx.recv().await.expect("close event") {
        TransportEvent::ConnectionClosed { .. } => {}
        other => panic!("expected connection closed, got {other:?}"),
    }
}
#[tokio::test]
async fn legacy_initialize_without_stream_id_resets_inbound_seq_id() {
let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);

View File

@@ -45,6 +45,7 @@ pub(super) enum ClientSegmentObservation {
Forward(Box<ClientEnvelope>),
Pending,
Dropped,
ResetStream(ClientId, StreamId),
}
impl ClientSegmentReassembler {
@@ -148,9 +149,9 @@ impl ClientSegmentReassembler {
} else if segment_id != assembly.next_segment_id {
warn!(
client_id = envelope.client_id.0.as_str(),
"dropping out-of-order segmented remote-control client envelope"
"resetting segmented remote-control client stream after segment gap"
);
AssemblyUpdate::Drop
AssemblyUpdate::ResetStream
} else {
assembly.last_chunk_seen_at = now;
let chunk_start = assembly.raw.len();
@@ -213,6 +214,10 @@ impl ClientSegmentReassembler {
self.remove_assembly(&envelope.client_id, &stream_id);
ClientSegmentObservation::Dropped
}
AssemblyUpdate::ResetStream => {
self.remove_assembly(&envelope.client_id, &stream_id);
ClientSegmentObservation::ResetStream(envelope.client_id, stream_id)
}
AssemblyUpdate::Complete(message) => {
self.remove_assembly(&envelope.client_id, &stream_id);
ClientSegmentObservation::Forward(Box::new(ClientEnvelope {
@@ -227,10 +232,6 @@ impl ClientSegmentReassembler {
self.remove_assembly(client_id, stream_id);
}
pub(super) fn invalidate_client(&mut self, client_id: &ClientId) {
self.assemblies.remove(client_id);
}
pub(super) fn should_ignore_chunk(
&self,
client_id: &ClientId,
@@ -275,6 +276,7 @@ enum AssemblyUpdate {
Pending,
Ignore,
Drop,
ResetStream,
Complete(JSONRPCMessage),
}

View File

@@ -50,7 +50,9 @@ fn reassembles_client_message_chunks() {
&raw[split..],
)) {
ClientSegmentObservation::Forward(reassembled) => *reassembled,
ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => {
ClientSegmentObservation::Pending
| ClientSegmentObservation::Dropped
| ClientSegmentObservation::ResetStream(_, _) => {
panic!("message should reassemble")
}
};
@@ -139,7 +141,46 @@ fn invalidates_incomplete_stream_assemblies() {
raw.len(),
&raw[split..],
)),
ClientSegmentObservation::Dropped
ClientSegmentObservation::ResetStream(_, _)
));
}
/// Feeding segment 0 and then segment 2 of a three-segment message (skipping segment 1)
/// must make the reassembler report `ResetStream` for that exact client and stream.
#[test]
fn resets_stream_after_segment_gap() {
    let client_id = ClientId("client-1".to_string());
    let stream_id = StreamId("stream-1".to_string());
    let message = JSONRPCMessage::Notification(JSONRPCNotification {
        method: "initialized".to_string(),
        params: None,
    });
    let raw = serde_json::to_vec(&message).expect("message should serialize");
    let split = raw.len() / 2;
    let (head, tail) = (&raw[..split], &raw[split..]);
    let mut reassembler = ClientSegmentReassembler::default();

    // The first segment arrives in order and is simply buffered.
    let after_first_segment = reassembler.observe(chunk_envelope(
        client_id.clone(),
        Some(stream_id.clone()),
        /*seq_id*/ 7,
        /*segment_id*/ 0,
        /*segment_count*/ 3,
        raw.len(),
        head,
    ));
    assert!(matches!(
        after_first_segment,
        ClientSegmentObservation::Pending
    ));

    // Segment 1 is skipped: segment 2 arrives while segment 1 is still expected.
    let after_gap = reassembler.observe(chunk_envelope(
        client_id.clone(),
        Some(stream_id.clone()),
        /*seq_id*/ 7,
        /*segment_id*/ 2,
        /*segment_count*/ 3,
        raw.len(),
        tail,
    ));
    let reset_targets_same_stream = matches!(
        after_gap,
        ClientSegmentObservation::ResetStream(reset_client_id, reset_stream_id)
            if reset_client_id == client_id && reset_stream_id == stream_id
    );
    assert!(reset_targets_same_stream);
}
@@ -190,7 +231,9 @@ fn resets_incomplete_client_assembly_when_stream_changes() {
&raw[split..],
)) {
ClientSegmentObservation::Forward(reassembled) => *reassembled,
ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => {
ClientSegmentObservation::Pending
| ClientSegmentObservation::Dropped
| ClientSegmentObservation::ResetStream(_, _) => {
panic!("replacement stream should reassemble")
}
};
@@ -208,7 +251,7 @@ fn resets_incomplete_client_assembly_when_stream_changes() {
raw.len(),
&raw[split..],
)),
ClientSegmentObservation::Dropped
ClientSegmentObservation::ResetStream(_, _)
));
}

View File

@@ -1,4 +1,5 @@
use crate::transport::TransportEvent;
use crate::transport::remote_control::client_tracker::ClientMessageOutcome;
use crate::transport::remote_control::client_tracker::ClientTracker;
use crate::transport::remote_control::client_tracker::REMOTE_CONTROL_IDLE_SWEEP_INTERVAL;
use crate::transport::remote_control::enroll::RemoteControlConnectionAuth;
@@ -140,12 +141,16 @@ impl WebsocketState {
) -> ClientSegmentObservation {
let client_message_key = Self::client_message_key(&client_envelope);
if let Some((key, seq_id)) = client_message_key.as_ref()
&& self
.last_completed_client_chunk_seq_id_by_stream
.get(key)
.is_some_and(|last_seq_id| last_seq_id >= seq_id)
&& let Some(last_seq_id) = self.last_completed_client_chunk_seq_id_by_stream.get(key)
{
return ClientSegmentObservation::Dropped;
if last_seq_id >= seq_id {
return ClientSegmentObservation::Dropped;
}
if *seq_id > last_seq_id.saturating_add(1)
&& let Some(stream_id) = client_envelope.stream_id.clone()
{
return ClientSegmentObservation::ResetStream(client_envelope.client_id, stream_id);
}
}
if let (
Some((_, seq_id)),
@@ -190,11 +195,6 @@ impl WebsocketState {
.remove(&(client_id.clone(), Some(stream_id.clone())));
}
fn invalidate_client_message_client(&mut self, client_id: &ClientId) {
self.last_completed_client_chunk_seq_id_by_stream
.retain(|(cursor_client_id, _), _| cursor_client_id != client_id);
}
fn client_message_key(
client_envelope: &ClientEnvelope,
) -> Option<((ClientId, Option<StreamId>), u64)> {
@@ -210,6 +210,32 @@ impl WebsocketState {
seq_id,
))
}
/// Serializes a batch of server envelopes and, only when the whole batch serializes, inserts
/// the envelopes into `outbound_buffer` and stores `seq_id + 1` (saturating) as the next
/// sequence id for `seq_key`.
///
/// Returns the JSON payloads in input order. Returns `None`, leaving `outbound_buffer` and
/// `next_seq_id_by_stream` untouched, when the batch is empty or when any envelope fails to
/// serialize (the failure is logged).
fn stage_server_envelopes(
    &mut self,
    seq_key: (ClientId, StreamId),
    seq_id: u64,
    server_envelopes: Vec<ServerEnvelope>,
) -> Option<Vec<String>> {
    if server_envelopes.is_empty() {
        return None;
    }

    let mut wire_payloads = Vec::with_capacity(server_envelopes.len());
    for envelope in &server_envelopes {
        match serde_json::to_string(envelope) {
            Ok(json) => wire_payloads.push(json),
            Err(err) => {
                error!("failed to serialize remote-control server event: {err}");
                return None;
            }
        }
    }

    // State is mutated only after every envelope serialized, so a failed batch never
    // consumes a sequence id.
    for envelope in &server_envelopes {
        self.outbound_buffer.insert(envelope);
    }
    self.next_seq_id_by_stream
        .insert(seq_key, seq_id.saturating_add(1));
    Some(wire_payloads)
}
}
pub(crate) struct RemoteControlWebsocket {
@@ -680,21 +706,11 @@ impl RemoteControlWebsocket {
continue;
}
};
let mut payloads = Vec::with_capacity(server_envelopes.len());
for server_envelope in server_envelopes {
let payload = match serde_json::to_string(&server_envelope) {
Ok(payload) => payload,
Err(err) => {
error!("failed to serialize remote-control server event: {err}");
continue;
}
};
state.outbound_buffer.insert(&server_envelope);
payloads.push(payload);
}
state
.next_seq_id_by_stream
.insert(seq_key, seq_id.saturating_add(1));
let Some(payloads) =
state.stage_server_envelopes(seq_key, seq_id, server_envelopes)
else {
continue;
};
(payloads, queued_server_envelope.write_complete_tx)
};
@@ -847,6 +863,21 @@ impl RemoteControlWebsocket {
let client_envelope = match observation {
ClientSegmentObservation::Forward(client_envelope) => *client_envelope,
ClientSegmentObservation::Pending | ClientSegmentObservation::Dropped => continue,
ClientSegmentObservation::ResetStream(client_id, stream_id) => {
if client_tracker
.close_client(&(client_id.clone(), stream_id.clone()))
.await
.is_err()
{
return Ok(());
}
let mut websocket_state = state.lock().await;
websocket_state
.client_segment_reassembler
.invalidate_stream(&client_id, &stream_id);
websocket_state.invalidate_client_message_stream(&client_id, &stream_id);
continue;
}
};
{
@@ -867,33 +898,16 @@ impl RemoteControlWebsocket {
}
}
let closed_client =
matches!(&client_envelope.event, ClientEvent::ClientClosed).then(|| {
(
client_envelope.client_id.clone(),
client_envelope.stream_id.clone(),
)
});
if client_tracker
.handle_message(client_envelope)
.await
.is_err()
{
return Ok(());
}
if let Some((client_id, stream_id)) = closed_client {
let outcome = match client_tracker.handle_message(client_envelope).await {
Ok(outcome) => outcome,
Err(_) => return Ok(()),
};
if let ClientMessageOutcome::StreamClosed(client_id, stream_id) = outcome {
let mut websocket_state = state.lock().await;
if let Some(stream_id) = stream_id {
websocket_state
.client_segment_reassembler
.invalidate_stream(&client_id, &stream_id);
websocket_state.invalidate_client_message_stream(&client_id, &stream_id);
} else {
websocket_state
.client_segment_reassembler
.invalidate_client(&client_id);
websocket_state.invalidate_client_message_client(&client_id);
}
websocket_state
.client_segment_reassembler
.invalidate_stream(&client_id, &stream_id);
websocket_state.invalidate_client_message_stream(&client_id, &stream_id);
}
}
}
@@ -1877,6 +1891,29 @@ mod tests {
.expect("writer should stop cleanly");
}
/// Staging an empty batch must yield no payloads, record no next sequence id for the stream,
/// and leave the outbound buffer empty.
#[test]
fn websocket_state_does_not_consume_seq_id_without_wire_payloads() {
    let seq_key = (
        ClientId("client-1".to_string()),
        StreamId("stream-1".to_string()),
    );
    let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
    let mut state = WebsocketState {
        outbound_buffer,
        subscribe_cursor: None,
        next_seq_id_by_stream: HashMap::new(),
        last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
        client_segment_reassembler: ClientSegmentReassembler::default(),
    };

    let staged = state.stage_server_envelopes(seq_key.clone(), /*seq_id*/ 1, Vec::new());

    assert_eq!(staged, None);
    assert!(!state.next_seq_id_by_stream.contains_key(&seq_key));
    let buffered_envelopes = state.outbound_buffer.server_envelopes().count();
    assert_eq!(buffered_envelopes, 0);
}
#[tokio::test]
async fn run_websocket_reader_inner_times_out_without_pong_frames() {
let (client_stream, _server_stream) = connected_websocket_pair().await;
@@ -2150,7 +2187,7 @@ mod tests {
}
#[test]
fn websocket_state_allows_replay_after_rejected_out_of_order_chunk() {
fn websocket_state_resets_stream_after_segment_gap() {
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
let mut state = WebsocketState {
outbound_buffer,
@@ -2159,10 +2196,6 @@ mod tests {
last_completed_client_chunk_seq_id_by_stream: HashMap::new(),
client_segment_reassembler: ClientSegmentReassembler::default(),
};
let first_chunk = client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 0,
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
);
let second_chunk = client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 4, /*segment_id*/ 1,
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"y",
@@ -2170,11 +2203,36 @@ mod tests {
assert!(matches!(
observe_client_message(&mut state, second_chunk),
ClientSegmentObservation::Dropped
ClientSegmentObservation::ResetStream(_, _)
));
}
#[test]
fn websocket_state_resets_stream_after_client_chunk_seq_gap() {
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
let mut state = WebsocketState {
outbound_buffer,
subscribe_cursor: None,
next_seq_id_by_stream: HashMap::new(),
last_completed_client_chunk_seq_id_by_stream: HashMap::from([(
(
ClientId("client-1".to_string()),
Some(StreamId("stream-1".to_string())),
),
4,
)]),
client_segment_reassembler: ClientSegmentReassembler::default(),
};
assert!(matches!(
observe_client_message(&mut state, first_chunk),
ClientSegmentObservation::Pending
observe_client_message(
&mut state,
client_chunk_envelope(
"client-1", "stream-1", /*seq_id*/ 6, /*segment_id*/ 0,
/*segment_count*/ 2, /*message_size_bytes*/ 2, b"x",
)
),
ClientSegmentObservation::ResetStream(_, _)
));
}

View File

@@ -64,6 +64,11 @@ Backpressure behavior:
- When request ingress is saturated, new requests are rejected with a JSON-RPC error code `-32001` and message `"Server overloaded; retry later."`.
- Clients should treat this as retryable and use exponential backoff with jitter.
Remote-control stream recovery:
- Remote-control envelopes use contiguous `seq_id` values per `(client_id, stream_id)` on the wire.
- An inbound `seq_id` gap or segmented-message `segment_id` gap closes the affected remote-control stream. Clients should reconnect and initialize a fresh stream after either gap.
## Message Schema
Currently, you can dump a TypeScript version of the schema using `codex app-server generate-ts`, or a JSON Schema bundle via `codex app-server generate-json-schema`. Each output is specific to the version of Codex you used to run the command, so the generated artifacts are guaranteed to match that version.