diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index afb3cbbed9..f1d78b7479 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2740,6 +2740,7 @@ dependencies = [ "ctor 0.6.3", "futures", "pretty_assertions", + "prost 0.14.3", "reqwest", "serde", "serde_json", diff --git a/codex-rs/exec-server/BUILD.bazel b/codex-rs/exec-server/BUILD.bazel index 224536da8e..e94a5c0043 100644 --- a/codex-rs/exec-server/BUILD.bazel +++ b/codex-rs/exec-server/BUILD.bazel @@ -10,6 +10,9 @@ codex_rust_crate( # they install process-global test-binary dispatch state, and the remote # exec-server cases already rely on serialization around the full CLI path. integration_test_args = ["--test-threads=1"], + integration_compile_data_extra = [ + "src/proto/codex.exec_server.relay.v1.rs", + ], extra_binaries = [ "//codex-rs/bwrap:bwrap", ], diff --git a/codex-rs/exec-server/Cargo.toml b/codex-rs/exec-server/Cargo.toml index 936fa412f1..09a9a71ea0 100644 --- a/codex-rs/exec-server/Cargo.toml +++ b/codex-rs/exec-server/Cargo.toml @@ -26,6 +26,7 @@ codex-utils-pty = { workspace = true } codex-utils-rustls-provider = { workspace = true } futures = { workspace = true } reqwest = { workspace = true, features = ["json", "rustls-tls", "stream"] } +prost = "0.14.3" serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } thiserror = { workspace = true } diff --git a/codex-rs/exec-server/README.md b/codex-rs/exec-server/README.md index 81664eaca0..1eaf6e69eb 100644 --- a/codex-rs/exec-server/README.md +++ b/codex-rs/exec-server/README.md @@ -30,7 +30,65 @@ It requires a bearer token in `CODEX_EXEC_SERVER_REMOTE_BEARER_TOKEN`. 
Wire framing: -- websocket: one JSON-RPC message per websocket text frame +- local websocket: one JSON-RPC message per websocket frame +- remote websocket: binary protobuf relay frames carrying JSON-RPC payloads + +## Remote Relay Message Format + +In remote mode, the harness and executor communicate through rendezvous using +`codex.exec_server.relay.v1.RelayMessageFrame`; the checked-in schema is in +`src/proto/codex.exec_server.relay.v1.proto`. The relay frame carries stream +identity plus endpoint-owned reliability metadata: + +```text +version +stream_id +body // data | ack_frame | resume | reset | heartbeat +ack // highest contiguous peer segment seq received +ack_bits // bitset for peer segment seqs after ack +seq // data only: segment sequence number +segment_index // data only: 0-based index within message +segment_count // data only: number of segments in message +payload // data only: JSON-RPC message bytes or segment bytes +next_seq // resume only: next sender seq +reason // reset only: reset reason +``` + +`stream_id` identifies one virtual harness/executor JSON-RPC session on the +executor websocket. The harness generates a UUIDv4 `stream_id`; the executor +demuxes frames by `stream_id` and runs an independent `ConnectionProcessor` per +stream. + +Use segment-level sequence numbers for reliability: + +```text +seq = 0, 1, 2, 3, ... +``` + +Use contiguous segment sequence ranges to identify and stitch a segmented +application message: + +```text +message_start_seq = seq - segment_index +segment_index = 0 +segment_count = 1 +``` + +`message_start_seq` is derived by the receiver, not sent on the wire. For +unsplit messages, `message_start_seq == seq`, `segment_index == 0`, and +`segment_count == 1`. 
/// Reports whether `websocket_url` addresses the rendezvous relay as a harness.
///
/// The check is purely textual: the URL must have a query string that contains
/// the exact pair `role=harness`. Keys and values are compared verbatim (no
/// percent-decoding, no case folding), and pairs without `=` are ignored.
///
/// Returns `false` when the URL has no query string at all.
fn is_rendezvous_harness_url(websocket_url: &str) -> bool {
    // A `#fragment` is never part of the query. Cut it off first so that
    // `?role=harness#x` still matches and `#x?role=harness` does not.
    let without_fragment = websocket_url
        .split_once('#')
        .map_or(websocket_url, |(before_fragment, _fragment)| before_fragment);
    let Some((_path, query)) = without_fragment.split_once('?') else {
        return false;
    };
    query
        .split('&')
        .filter_map(|pair| pair.split_once('='))
        .any(|(key, value)| key == "role" && value == "harness")
}
// Relay frame schema used by the harness and the executor when they talk
// through rendezvous in remote mode.
syntax = "proto3";

package codex.exec_server.relay.v1;

// Envelope for every relay frame: stream identity, acknowledgement metadata
// and one `body` variant.
message RelayMessageFrame {
  // Frame format version.
  uint32 version = 1;
  // Identifies one virtual harness/executor JSON-RPC session on the executor
  // websocket.
  string stream_id = 2;
  // Highest contiguous peer segment seq received.
  uint32 ack = 3;
  // Bitset for peer segment seqs after `ack`: bit i acknowledges
  // seq = ack + 1 + i.
  uint32 ack_bits = 4;

  oneof body {
    RelayData data = 5;
    RelayAck ack_frame = 6;
    RelayResume resume = 7;
    RelayReset reset = 8;
    RelayHeartbeat heartbeat = 9;
  }
}

// One segment of a JSON-RPC message.
message RelayData {
  // Segment sequence number.
  uint32 seq = 1;
  // 0-based index of this segment within its message.
  uint32 segment_index = 2;
  // Number of segments in the message.
  uint32 segment_count = 3;
  // JSON-RPC message bytes, or the bytes of one segment.
  bytes payload = 4;
}

// Body of a frame that carries only the envelope's ack metadata.
message RelayAck {}

message RelayResume {
  // Next sender seq.
  uint32 next_seq = 1;
}

message RelayReset {
  // Reset reason.
  string reason = 1;
}

// Body of a heartbeat frame; it has no fields.
message RelayHeartbeat {}
+#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RelayMessageFrame { + #[prost(uint32, tag = "1")] + pub version: u32, + #[prost(string, tag = "2")] + pub stream_id: ::prost::alloc::string::String, + #[prost(uint32, tag = "3")] + pub ack: u32, + #[prost(uint32, tag = "4")] + pub ack_bits: u32, + #[prost(oneof = "relay_message_frame::Body", tags = "5, 6, 7, 8, 9")] + pub body: ::core::option::Option, +} +pub mod relay_message_frame { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Body { + #[prost(message, tag = "5")] + Data(super::RelayData), + #[prost(message, tag = "6")] + AckFrame(super::RelayAck), + #[prost(message, tag = "7")] + Resume(super::RelayResume), + #[prost(message, tag = "8")] + Reset(super::RelayReset), + #[prost(message, tag = "9")] + Heartbeat(super::RelayHeartbeat), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RelayData { + #[prost(uint32, tag = "1")] + pub seq: u32, + #[prost(uint32, tag = "2")] + pub segment_index: u32, + #[prost(uint32, tag = "3")] + pub segment_count: u32, + #[prost(bytes = "vec", tag = "4")] + pub payload: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct RelayAck {} +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct RelayResume { + #[prost(uint32, tag = "1")] + pub next_seq: u32, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct RelayReset { + #[prost(string, tag = "1")] + pub reason: ::prost::alloc::string::String, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct RelayHeartbeat {} diff --git a/codex-rs/exec-server/src/relay.rs b/codex-rs/exec-server/src/relay.rs new file mode 100644 index 0000000000..bce787cfc2 --- /dev/null +++ b/codex-rs/exec-server/src/relay.rs @@ -0,0 +1,455 @@ +use std::collections::HashMap; + +use codex_app_server_protocol::JSONRPCMessage; +use futures::SinkExt; +use futures::StreamExt; +use prost::Message as ProstMessage; +use 
tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::sync::mpsc; +use tokio::sync::watch; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::tungstenite::Message; +use tracing::debug; +use tracing::warn; +use uuid::Uuid; + +use crate::ExecServerError; +use crate::connection::CHANNEL_CAPACITY; +use crate::connection::JsonRpcConnection; +use crate::connection::JsonRpcConnectionEvent; +use crate::connection::JsonRpcTransport; +use crate::relay_proto::RelayData; +use crate::relay_proto::RelayMessageFrame; +use crate::relay_proto::RelayResume; +use crate::relay_proto::relay_message_frame; +use crate::server::ConnectionProcessor; + +const RELAY_MESSAGE_FRAME_VERSION: u32 = 1; + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +enum RelayFrameBodyKind { + Data, + Ack, + Resume, + Reset, + Heartbeat, +} + +impl RelayMessageFrame { + fn data(stream_id: String, seq: u32, payload: Vec) -> Self { + Self { + version: RELAY_MESSAGE_FRAME_VERSION, + stream_id, + ack: 0, + ack_bits: 0, + body: Some(relay_message_frame::Body::Data(RelayData { + seq, + segment_index: 0, + segment_count: 1, + payload, + })), + } + } + + fn resume(stream_id: String) -> Self { + Self { + version: RELAY_MESSAGE_FRAME_VERSION, + stream_id, + ack: 0, + ack_bits: 0, + body: Some(relay_message_frame::Body::Resume(RelayResume { + next_seq: 0, + })), + } + } + + fn validate(&self) -> Result { + if self.version != RELAY_MESSAGE_FRAME_VERSION { + return Err(ExecServerError::Protocol(format!( + "unsupported relay message frame version {}", + self.version + ))); + } + if self.stream_id.trim().is_empty() { + return Err(ExecServerError::Protocol( + "relay message frame is missing stream_id".to_string(), + )); + } + match self.body.as_ref() { + Some(relay_message_frame::Body::Data(data)) => { + if data.segment_index != 0 || data.segment_count != 1 || data.payload.is_empty() { + return Err(ExecServerError::Protocol( + "relay data message frame is missing required fields".to_string(), + )); + } + 
Ok(RelayFrameBodyKind::Data) + } + Some(relay_message_frame::Body::AckFrame(_)) => Ok(RelayFrameBodyKind::Ack), + Some(relay_message_frame::Body::Resume(_)) => Ok(RelayFrameBodyKind::Resume), + Some(relay_message_frame::Body::Reset(reset)) => { + if reset.reason.is_empty() { + return Err(ExecServerError::Protocol( + "relay reset message frame is missing reason".to_string(), + )); + } + Ok(RelayFrameBodyKind::Reset) + } + Some(relay_message_frame::Body::Heartbeat(_)) => Ok(RelayFrameBodyKind::Heartbeat), + None => Err(ExecServerError::Protocol( + "relay message frame is missing body".to_string(), + )), + } + } + + fn into_jsonrpc_message(self) -> Result { + let kind = self.validate()?; + if kind != RelayFrameBodyKind::Data { + return Err(ExecServerError::Protocol( + "expected relay data message frame".to_string(), + )); + } + let payload = match self.body { + Some(relay_message_frame::Body::Data(data)) => data.payload, + _ => Vec::new(), + }; + serde_json::from_slice(&payload).map_err(ExecServerError::Json) + } + + fn into_reset_reason(self) -> Option { + match self.body { + Some(relay_message_frame::Body::Reset(reset)) if !reset.reason.is_empty() => { + Some(reset.reason) + } + _ => None, + } + } +} + +fn encode_relay_message_frame(frame: &RelayMessageFrame) -> Vec { + frame.encode_to_vec() +} + +fn decode_relay_message_frame(payload: &[u8]) -> Result { + RelayMessageFrame::decode(payload) + .map_err(|err| ExecServerError::Protocol(format!("invalid relay message frame: {err}"))) +} + +fn jsonrpc_payload(message: &JSONRPCMessage) -> Result, ExecServerError> { + serde_json::to_vec(message).map_err(ExecServerError::Json) +} + +pub(crate) fn harness_connection_from_websocket( + stream: WebSocketStream, + connection_label: String, +) -> JsonRpcConnection +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + let stream_id = Uuid::new_v4().to_string(); + let (mut websocket_writer, mut websocket_reader) = stream.split(); + let (outgoing_tx, mut outgoing_rx) = 
mpsc::channel(CHANNEL_CAPACITY); + let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (disconnected_tx, disconnected_rx) = watch::channel(false); + + let reader_label = connection_label; + let reader_stream_id = stream_id.clone(); + let incoming_tx_for_reader = incoming_tx; + let disconnected_tx_for_reader = disconnected_tx.clone(); + let reader_task = tokio::spawn(async move { + loop { + match websocket_reader.next().await { + Some(Ok(Message::Binary(payload))) => { + let frame = match decode_relay_message_frame(payload.as_ref()) { + Ok(frame) => frame, + Err(err) => { + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: format!( + "failed to parse relay message frame from {reader_label}: {err}" + ), + }) + .await; + continue; + } + }; + if frame.stream_id != reader_stream_id { + continue; + } + let kind = match frame.validate() { + Ok(kind) => kind, + Err(err) => { + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: err.to_string(), + }) + .await; + continue; + } + }; + match kind { + RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() { + Ok(message) => { + if incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Err(err) => { + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: err.to_string(), + }) + .await; + } + }, + RelayFrameBodyKind::Reset => { + let _ = disconnected_tx_for_reader.send(true); + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Disconnected { + reason: frame.into_reset_reason(), + }) + .await; + break; + } + RelayFrameBodyKind::Ack + | RelayFrameBodyKind::Resume + | RelayFrameBodyKind::Heartbeat => {} + } + } + Some(Ok(Message::Close(_))) | None => { + let _ = disconnected_tx_for_reader.send(true); + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Disconnected { reason: None }) + .await; + break; 
+ } + Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {} + Some(Ok(Message::Text(_))) => { + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: "relay exec-server transport expects binary protobuf frames" + .to_string(), + }) + .await; + } + Some(Err(err)) => { + let _ = disconnected_tx_for_reader.send(true); + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Disconnected { + reason: Some(format!( + "failed to read relay websocket frame from {reader_label}: {err}" + )), + }) + .await; + break; + } + } + } + }); + + let writer_task = tokio::spawn(async move { + let resume = RelayMessageFrame::resume(stream_id.clone()); + if websocket_writer + .send(Message::Binary(encode_relay_message_frame(&resume).into())) + .await + .is_err() + { + let _ = disconnected_tx.send(true); + return; + } + + let mut next_seq = 0u32; + while let Some(message) = outgoing_rx.recv().await { + let payload = match jsonrpc_payload(&message) { + Ok(payload) => payload, + Err(err) => { + warn!("failed to serialize JSON-RPC payload for relay transport: {err}"); + break; + } + }; + let frame = RelayMessageFrame::data(stream_id.clone(), next_seq, payload); + next_seq = next_seq.wrapping_add(1); + if websocket_writer + .send(Message::Binary(encode_relay_message_frame(&frame).into())) + .await + .is_err() + { + let _ = disconnected_tx.send(true); + break; + } + } + }); + + JsonRpcConnection { + outgoing_tx, + incoming_rx, + disconnected_rx, + task_handles: vec![reader_task, writer_task], + transport: JsonRpcTransport::Plain, + } +} + +pub(crate) async fn run_multiplexed_executor( + stream: WebSocketStream, + processor: ConnectionProcessor, +) where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + let (mut websocket_writer, mut websocket_reader) = stream.split(); + let (physical_outgoing_tx, mut physical_outgoing_rx) = + mpsc::channel::>(CHANNEL_CAPACITY); + let writer_task = tokio::spawn(async move { + while let 
Some(encoded) = physical_outgoing_rx.recv().await { + if websocket_writer + .send(Message::Binary(encoded.into())) + .await + .is_err() + { + break; + } + } + }); + + let mut streams: HashMap = HashMap::new(); + loop { + let frame = match websocket_reader.next().await { + Some(Ok(Message::Binary(payload))) => { + match decode_relay_message_frame(payload.as_ref()) { + Ok(frame) => frame, + Err(err) => { + warn!("dropping malformed relay message frame from harness: {err}"); + continue; + } + } + } + Some(Ok(Message::Close(_))) | None => break, + Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue, + Some(Ok(Message::Text(_))) => { + warn!("dropping non-binary relay message frame from harness"); + continue; + } + Some(Err(err)) => { + debug!("multiplexed executor websocket read failed: {err}"); + break; + } + }; + + let kind = match frame.validate() { + Ok(kind) => kind, + Err(err) => { + warn!("dropping invalid relay message frame: {err}"); + continue; + } + }; + + match kind { + RelayFrameBodyKind::Data => { + let stream_id = frame.stream_id.clone(); + let message = match frame.into_jsonrpc_message() { + Ok(message) => message, + Err(err) => { + warn!("dropping malformed relay data message frame: {err}"); + continue; + } + }; + let stream = streams.entry(stream_id.clone()).or_insert_with(|| { + spawn_virtual_stream( + stream_id.clone(), + processor.clone(), + physical_outgoing_tx.clone(), + ) + }); + if stream + .incoming_tx + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + streams.remove(&stream_id); + } + } + RelayFrameBodyKind::Reset => { + if let Some(stream) = streams.remove(&frame.stream_id) { + stream.disconnect(frame.into_reset_reason()).await; + } + } + RelayFrameBodyKind::Ack + | RelayFrameBodyKind::Resume + | RelayFrameBodyKind::Heartbeat => {} + } + } + + for (_stream_id, stream) in streams { + stream.disconnect(/*reason*/ None).await; + } + drop(physical_outgoing_tx); + let _ = writer_task.await; +} + 
+struct VirtualStream { + incoming_tx: mpsc::Sender, + disconnected_tx: watch::Sender, +} + +impl VirtualStream { + async fn disconnect(self, reason: Option) { + let _ = self.disconnected_tx.send(true); + let _ = self + .incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { reason }) + .await; + } +} + +fn spawn_virtual_stream( + stream_id: String, + processor: ConnectionProcessor, + physical_outgoing_tx: mpsc::Sender>, +) -> VirtualStream { + let (json_outgoing_tx, mut json_outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (disconnected_tx, disconnected_rx) = watch::channel(false); + + let writer_stream_id = stream_id; + let writer_task = tokio::spawn(async move { + let mut next_seq = 0u32; + while let Some(message) = json_outgoing_rx.recv().await { + let payload = match jsonrpc_payload(&message) { + Ok(payload) => payload, + Err(err) => { + warn!("failed to serialize virtual stream JSON-RPC payload: {err}"); + break; + } + }; + let frame = RelayMessageFrame::data(writer_stream_id.clone(), next_seq, payload); + next_seq = next_seq.wrapping_add(1); + if physical_outgoing_tx + .send(encode_relay_message_frame(&frame)) + .await + .is_err() + { + break; + } + } + }); + + let connection = JsonRpcConnection { + outgoing_tx: json_outgoing_tx, + incoming_rx, + disconnected_rx, + task_handles: vec![writer_task], + transport: JsonRpcTransport::Plain, + }; + tokio::spawn(async move { + processor.run_connection(connection).await; + }); + + VirtualStream { + incoming_tx, + disconnected_tx, + } +} diff --git a/codex-rs/exec-server/src/relay_proto.rs b/codex-rs/exec-server/src/relay_proto.rs new file mode 100644 index 0000000000..b8a938b8c7 --- /dev/null +++ b/codex-rs/exec-server/src/relay_proto.rs @@ -0,0 +1,7 @@ +#[path = "proto/codex.exec_server.relay.v1.rs"] +mod generated; + +pub(crate) use generated::RelayData; +pub(crate) use generated::RelayMessageFrame; +pub(crate) use generated::RelayResume; 
+pub(crate) use generated::relay_message_frame; diff --git a/codex-rs/exec-server/src/remote.rs b/codex-rs/exec-server/src/remote.rs index bb22105c19..32c0d5bc8e 100644 --- a/codex-rs/exec-server/src/remote.rs +++ b/codex-rs/exec-server/src/remote.rs @@ -11,7 +11,7 @@ use codex_utils_rustls_provider::ensure_rustls_crypto_provider; use crate::ExecServerError; use crate::ExecServerRuntimePaths; -use crate::connection::JsonRpcConnection; +use crate::relay::run_multiplexed_executor; use crate::server::ConnectionProcessor; pub const CODEX_EXEC_SERVER_REMOTE_BEARER_TOKEN_ENV_VAR: &str = @@ -113,7 +113,7 @@ impl RemoteExecutorConfig { Self::with_bearer_token(base_url, executor_id, read_remote_bearer_token_from_env()?) } - fn with_bearer_token( + pub fn with_bearer_token( base_url: String, executor_id: String, bearer_token: String, @@ -150,12 +150,7 @@ pub async fn run_remote_executor( match connect_async(response.url.as_str()).await { Ok((websocket, _)) => { backoff = Duration::from_secs(1); - processor - .run_connection(JsonRpcConnection::from_websocket( - websocket, - "remote exec-server websocket".to_string(), - )) - .await; + run_multiplexed_executor(websocket, processor.clone()).await; } Err(err) => { warn!("failed to connect remote exec-server websocket: {err}"); diff --git a/codex-rs/exec-server/tests/relay.rs b/codex-rs/exec-server/tests/relay.rs new file mode 100644 index 0000000000..d228db9ee5 --- /dev/null +++ b/codex-rs/exec-server/tests/relay.rs @@ -0,0 +1,353 @@ +mod common; + +#[path = "../src/proto/codex.exec_server.relay.v1.rs"] +mod relay_proto; + +use std::collections::HashMap; +use std::time::Duration; + +use anyhow::Context; +use anyhow::Result; +use anyhow::anyhow; +use anyhow::bail; +use codex_app_server_protocol::JSONRPCError; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCRequest; +use codex_app_server_protocol::JSONRPCResponse; +use 
codex_app_server_protocol::RequestId; +use codex_exec_server::ExecServerRuntimePaths; +use codex_exec_server::InitializeParams; +use codex_exec_server::InitializeResponse; +use codex_exec_server::RemoteExecutorConfig; +use futures::SinkExt; +use futures::StreamExt; +use pretty_assertions::assert_eq; +use prost::Message as ProstMessage; +use relay_proto::RelayData; +use relay_proto::RelayMessageFrame; +use relay_proto::RelayReset; +use relay_proto::relay_message_frame; +use tokio::net::TcpListener; +use tokio::time::timeout; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::tungstenite::Message; +use uuid::Uuid; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::ResponseTemplate; +use wiremock::matchers::header; +use wiremock::matchers::method; +use wiremock::matchers::path; + +const EXECUTOR_ID: &str = "exec-mux-test"; +const REGISTRY_TOKEN: &str = "registry-token"; +const RELAY_MESSAGE_FRAME_VERSION: u32 = 1; +const TEST_TIMEOUT: Duration = Duration::from_secs(5); + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn multiplexed_remote_executor_routes_independent_virtual_streams() -> Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let rendezvous_url = format!("ws://{}", listener.local_addr()?); + let registry = MockServer::start().await; + Mock::given(method("POST")) + .and(path(format!("/cloud/executor/{EXECUTOR_ID}/register"))) + .and(header("authorization", format!("Bearer {REGISTRY_TOKEN}"))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "executor_id": EXECUTOR_ID, + "url": rendezvous_url, + }))) + .mount(®istry) + .await; + + let (codex_exe, codex_linux_sandbox_exe) = common::current_test_binary_helper_paths()?; + let runtime_paths = ExecServerRuntimePaths::new(codex_exe, codex_linux_sandbox_exe)?; + let config = RemoteExecutorConfig::with_bearer_token( + registry.uri(), + EXECUTOR_ID.to_string(), + 
REGISTRY_TOKEN.to_string(), + )?; + let remote_executor = tokio::spawn(codex_exec_server::run_remote_executor( + config, + runtime_paths, + )); + + let (socket, _peer_addr) = timeout(TEST_TIMEOUT, listener.accept()) + .await + .context("remote executor should connect to fake rendezvous")??; + let mut websocket = timeout(TEST_TIMEOUT, accept_async(socket)) + .await + .context("fake rendezvous should accept executor websocket")??; + + let stream_a = "stream-a"; + let stream_b = "stream-b"; + send_relay_message( + &mut websocket, + stream_a, + /*seq*/ 0, + initialize_request(/*id*/ 1, "relay-test-a")?, + ) + .await?; + send_relay_message( + &mut websocket, + stream_b, + /*seq*/ 0, + initialize_request(/*id*/ 1, "relay-test-b")?, + ) + .await?; + + let initialize_responses = read_relay_messages_by_stream(&mut websocket, /*count*/ 2).await?; + let session_a = + assert_initialize_response(initialize_responses.get(stream_a), stream_a, /*id*/ 1)?; + let session_b = + assert_initialize_response(initialize_responses.get(stream_b), stream_b, /*id*/ 1)?; + assert_ne!(session_a, session_b); + + send_relay_message( + &mut websocket, + stream_a, + /*seq*/ 1, + notification("initialized", serde_json::json!({})), + ) + .await?; + send_relay_message( + &mut websocket, + stream_b, + /*seq*/ 1, + notification("initialized", serde_json::json!({})), + ) + .await?; + + send_relay_message( + &mut websocket, + stream_a, + /*seq*/ 2, + request(/*id*/ 2, "test/unknown-a", serde_json::json!({})), + ) + .await?; + send_relay_message( + &mut websocket, + stream_b, + /*seq*/ 2, + request(/*id*/ 2, "test/unknown-b", serde_json::json!({})), + ) + .await?; + + let unknown_method_responses = + read_relay_messages_by_stream(&mut websocket, /*count*/ 2).await?; + assert_error_response( + unknown_method_responses.get(stream_a), + stream_a, + /*id*/ 2, + "test/unknown-a", + )?; + assert_error_response( + unknown_method_responses.get(stream_b), + stream_b, + /*id*/ 2, + "test/unknown-b", + )?; + + 
send_relay_reset(&mut websocket, stream_a, "test_reset").await?; + send_relay_message( + &mut websocket, + stream_b, + /*seq*/ 3, + request( + /*id*/ 3, + "test/unknown-b-after-reset", + serde_json::json!({}), + ), + ) + .await?; + + let (stream_id, message) = read_relay_message(&mut websocket).await?; + assert_eq!(stream_id, stream_b); + assert_error_response( + Some(&message), + stream_b, + /*id*/ 3, + "test/unknown-b-after-reset", + )?; + + websocket.close(None).await?; + remote_executor.abort(); + let _ = remote_executor.await; + Ok(()) +} + +async fn send_relay_message( + websocket: &mut WebSocketStream, + stream_id: &str, + seq: u32, + message: JSONRPCMessage, +) -> Result<()> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + let payload = serde_json::to_vec(&message)?; + let frame = RelayMessageFrame { + version: RELAY_MESSAGE_FRAME_VERSION, + stream_id: stream_id.to_string(), + ack: 0, + ack_bits: 0, + body: Some(relay_message_frame::Body::Data(RelayData { + seq, + segment_index: 0, + segment_count: 1, + payload, + })), + }; + send_relay_frame(websocket, frame).await +} + +async fn send_relay_reset( + websocket: &mut WebSocketStream, + stream_id: &str, + reason: &str, +) -> Result<()> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + send_relay_frame( + websocket, + RelayMessageFrame { + version: RELAY_MESSAGE_FRAME_VERSION, + stream_id: stream_id.to_string(), + ack: 0, + ack_bits: 0, + body: Some(relay_message_frame::Body::Reset(RelayReset { + reason: reason.to_string(), + })), + }, + ) + .await +} + +async fn send_relay_frame( + websocket: &mut WebSocketStream, + frame: RelayMessageFrame, +) -> Result<()> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + websocket + .send(Message::Binary(frame.encode_to_vec().into())) + .await?; + Ok(()) +} + +async fn read_relay_messages_by_stream( + websocket: &mut WebSocketStream, + count: usize, +) -> Result> +where + S: tokio::io::AsyncRead + 
tokio::io::AsyncWrite + Unpin, +{ + let mut messages = HashMap::new(); + for _ in 0..count { + let (stream_id, message) = read_relay_message(websocket).await?; + if messages.insert(stream_id.clone(), message).is_some() { + bail!("received duplicate response for stream {stream_id}"); + } + } + Ok(messages) +} + +async fn read_relay_message( + websocket: &mut WebSocketStream, +) -> Result<(String, JSONRPCMessage)> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + loop { + let frame = timeout(TEST_TIMEOUT, websocket.next()) + .await + .context("timed out waiting for relay frame")? + .ok_or_else(|| anyhow!("executor websocket closed"))??; + match frame { + Message::Binary(bytes) => { + let frame = RelayMessageFrame::decode(bytes.as_ref())?; + let stream_id = frame.stream_id; + let Some(relay_message_frame::Body::Data(data)) = frame.body else { + continue; + }; + let message = serde_json::from_slice(&data.payload)?; + return Ok((stream_id, message)); + } + Message::Ping(_) | Message::Pong(_) => {} + Message::Close(_) => bail!("executor websocket closed"), + Message::Text(_) => bail!("executor sent text frame on relay websocket"), + Message::Frame(_) => {} + } + } +} + +fn initialize_request(id: i64, client_name: &str) -> Result { + Ok(request( + id, + "initialize", + serde_json::to_value(InitializeParams { + client_name: client_name.to_string(), + resume_session_id: None, + })?, + )) +} + +fn request(id: i64, method: &str, params: serde_json::Value) -> JSONRPCMessage { + JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(id), + method: method.to_string(), + params: Some(params), + trace: None, + }) +} + +fn notification(method: &str, params: serde_json::Value) -> JSONRPCMessage { + JSONRPCMessage::Notification(JSONRPCNotification { + method: method.to_string(), + params: Some(params), + }) +} + +fn assert_initialize_response( + message: Option<&JSONRPCMessage>, + stream_id: &str, + id: i64, +) -> Result { + let message = 
message.ok_or_else(|| anyhow!("missing initialize response for {stream_id}"))?; + let JSONRPCMessage::Response(JSONRPCResponse { + id: response_id, + result, + }) = message + else { + bail!("expected initialize response for {stream_id}, got {message:?}"); + }; + assert_eq!(response_id, &RequestId::Integer(id)); + let response: InitializeResponse = serde_json::from_value(result.clone())?; + Ok(Uuid::parse_str(&response.session_id)?) +} + +fn assert_error_response( + message: Option<&JSONRPCMessage>, + stream_id: &str, + id: i64, + expected_method: &str, +) -> Result<()> { + let message = message.ok_or_else(|| anyhow!("missing error response for {stream_id}"))?; + let JSONRPCMessage::Error(JSONRPCError { + id: response_id, + error, + }) = message + else { + bail!("expected error response for {stream_id}, got {message:?}"); + }; + assert_eq!(response_id, &RequestId::Integer(id)); + assert!( + error.message.contains(expected_method), + "expected error for {stream_id} to mention {expected_method}, got {}", + error.message + ); + Ok(()) +}