diff --git a/codex-rs/exec-server/src/client_transport.rs b/codex-rs/exec-server/src/client_transport.rs index 8ca1eb0280..4bdc09a80e 100644 --- a/codex-rs/exec-server/src/client_transport.rs +++ b/codex-rs/exec-server/src/client_transport.rs @@ -15,6 +15,7 @@ use crate::client_api::RemoteExecServerConnectArgs; use crate::client_api::StdioExecServerCommand; use crate::client_api::StdioExecServerConnectArgs; use crate::connection::JsonRpcConnection; +use crate::relay::harness_connection_from_websocket; const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment"; @@ -69,14 +70,13 @@ impl ExecServerClient { source, })?; - Self::connect( - JsonRpcConnection::from_websocket( - stream, - format!("exec-server websocket {websocket_url}"), - ), - args.into(), - ) - .await + let connection_label = format!("exec-server websocket {websocket_url}"); + let connection = if is_rendezvous_harness_url(&websocket_url) { + harness_connection_from_websocket(stream, connection_label) + } else { + JsonRpcConnection::from_websocket(stream, connection_label) + }; + Self::connect(connection, args.into()).await } pub(crate) async fn connect_stdio_command( @@ -120,6 +120,16 @@ impl ExecServerClient { } } +fn is_rendezvous_harness_url(websocket_url: &str) -> bool { + let Some((_path, query)) = websocket_url.split_once('?') else { + return false; + }; + query + .split('&') + .filter_map(|pair| pair.split_once('=')) + .any(|(key, value)| key == "role" && value == "harness") +} + fn stdio_command_process(stdio_command: &StdioExecServerCommand) -> Command { let mut command = Command::new(&stdio_command.program); command.args(&stdio_command.args); diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs index d8c147127c..9fbca92d7c 100644 --- a/codex-rs/exec-server/src/lib.rs +++ b/codex-rs/exec-server/src/lib.rs @@ -13,6 +13,7 @@ mod local_process; mod process; mod process_id; mod protocol; +mod relay; mod remote; mod remote_file_system; mod remote_process; diff --git a/codex-rs/exec-server/src/relay.rs b/codex-rs/exec-server/src/relay.rs new file mode 100644 index 0000000000..ac1dc1bc67 --- /dev/null +++ b/codex-rs/exec-server/src/relay.rs @@ -0,0 +1,462 @@ +use std::collections::HashMap; + +use base64::Engine; +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use codex_app_server_protocol::JSONRPCMessage; +use futures::SinkExt; +use futures::StreamExt; +use serde::Deserialize; +use serde::Serialize; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::sync::mpsc; +use tokio::sync::watch; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::tungstenite::Message; +use tracing::debug; +use tracing::warn; +use uuid::Uuid; + +use crate::ExecServerError; +use crate::connection::CHANNEL_CAPACITY; +use crate::connection::JsonRpcConnection; +use crate::connection::JsonRpcConnectionEvent; +use crate::connection::JsonRpcTransport; +use crate::server::ConnectionProcessor; + +const RELAY_MESSAGE_FRAME_VERSION: u32 = 1; +#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +enum RelayMessageFrameKind { + Data, + Ack, + Resume, + Reset, + Heartbeat, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +struct RelayMessageFrame { + version: u32, + stream_id: String, + kind: RelayMessageFrameKind, + ack: u32, + ack_bits: u32, + #[serde(skip_serializing_if = "Option::is_none")] + seq: Option, + #[serde(skip_serializing_if = "Option::is_none")] + segment_index: Option, + #[serde(skip_serializing_if = "Option::is_none")] + segment_count: Option, + #[serde(skip_serializing_if = "Option::is_none")] + payload_base64: Option, + #[serde(skip_serializing_if = "Option::is_none")] + reason: Option, +} + +impl RelayMessageFrame { + fn data(stream_id: String, seq: u32, payload: Vec) -> Self { + Self { + version: RELAY_MESSAGE_FRAME_VERSION, + stream_id, + kind: RelayMessageFrameKind::Data, + ack: 0, + ack_bits: 0, + seq: Some(seq), + segment_index: Some(0), + segment_count: Some(1), + payload_base64: Some(BASE64_STANDARD.encode(payload)), + reason: None, + } + } + + fn resume(stream_id: String) -> Self { + Self { + version: RELAY_MESSAGE_FRAME_VERSION, + stream_id, + kind: RelayMessageFrameKind::Resume, + ack: 0, + ack_bits: 0, + seq: None, + segment_index: None, + segment_count: None, + payload_base64: None, + reason: None, + } + } + + fn validate(&self) -> Result<(), ExecServerError> { + if self.version != RELAY_MESSAGE_FRAME_VERSION { + return Err(ExecServerError::Protocol(format!( + "unsupported relay message frame version {}", + self.version + ))); + } + if self.stream_id.trim().is_empty() { + return Err(ExecServerError::Protocol( + "relay message frame is missing streamId".to_string(), + )); + } + if self.kind == RelayMessageFrameKind::Data + && (self.seq.is_none() + || self.segment_index != Some(0) + || self.segment_count != Some(1) + || self.payload_base64.is_none()) + { + return Err(ExecServerError::Protocol( + "relay data message frame is missing required fields".to_string(), + )); + } + if self.kind == RelayMessageFrameKind::Reset && self.reason.is_none() { + return Err(ExecServerError::Protocol( + "relay reset message frame is missing reason".to_string(), + )); + } + Ok(()) + } + + fn into_jsonrpc_message(self) -> Result { + self.validate()?; + if self.kind != RelayMessageFrameKind::Data { + return Err(ExecServerError::Protocol( + "expected relay data message frame".to_string(), + )); + } + let payload = BASE64_STANDARD + .decode(self.payload_base64.unwrap_or_default()) + .map_err(|err| ExecServerError::Protocol(format!("invalid payloadBase64: {err}")))?; + serde_json::from_slice(&payload).map_err(ExecServerError::Json) + } +} + +fn serialize_relay_message_frame(frame: &RelayMessageFrame) -> Result { + serde_json::to_string(frame).map_err(ExecServerError::Json) +} + +fn jsonrpc_payload(message: &JSONRPCMessage) -> Result, ExecServerError> { + serde_json::to_vec(message).map_err(ExecServerError::Json) +} + +pub(crate) fn harness_connection_from_websocket( + stream: WebSocketStream, + connection_label: String, +) -> JsonRpcConnection +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + let stream_id = Uuid::new_v4().to_string(); + let (mut websocket_writer, mut websocket_reader) = stream.split(); + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (disconnected_tx, disconnected_rx) = watch::channel(false); + + let reader_label = connection_label; + let reader_stream_id = stream_id.clone(); + let incoming_tx_for_reader = incoming_tx; + let disconnected_tx_for_reader = disconnected_tx.clone(); + let reader_task = tokio::spawn(async move { + loop { + match websocket_reader.next().await { + Some(Ok(Message::Text(text))) => { + let frame = match serde_json::from_str::(text.as_ref()) { + Ok(frame) => frame, + Err(err) => { + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: format!( + "failed to parse relay message frame from {reader_label}: {err}" + ), + }) + .await; + continue; + } + }; + if frame.stream_id != reader_stream_id { + continue; + } + match frame.kind { + RelayMessageFrameKind::Data => match frame.into_jsonrpc_message() { + Ok(message) => { + if incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Err(err) => { + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: err.to_string(), + }) + .await; + } + }, + RelayMessageFrameKind::Reset => { + let _ = disconnected_tx_for_reader.send(true); + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Disconnected { + reason: frame.reason, + }) + .await; + break; + } + RelayMessageFrameKind::Ack + | RelayMessageFrameKind::Resume + | RelayMessageFrameKind::Heartbeat => {} + } + } + Some(Ok(Message::Close(_))) | None => { + let _ = disconnected_tx_for_reader.send(true); + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Disconnected { reason: None }) + .await; + break; + } + Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {} + Some(Ok(Message::Binary(_))) => { + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: "relay exec-server transport expects JSON text frames" + .to_string(), + }) + .await; + } + Some(Err(err)) => { + let _ = disconnected_tx_for_reader.send(true); + let _ = incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Disconnected { + reason: Some(format!( + "failed to read relay websocket frame from {reader_label}: {err}" + )), + }) + .await; + break; + } + } + } + }); + + let writer_task = tokio::spawn(async move { + let resume = RelayMessageFrame::resume(stream_id.clone()); + match serialize_relay_message_frame(&resume) { + Ok(encoded) => { + if websocket_writer + .send(Message::Text(encoded.into())) + .await + .is_err() + { + let _ = disconnected_tx.send(true); + return; + } + } + Err(err) => { + warn!("failed to serialize relay resume frame: {err}"); + let _ = disconnected_tx.send(true); + return; + } + } + + let mut next_seq = 0u32; + while let Some(message) = outgoing_rx.recv().await { + let payload = match jsonrpc_payload(&message) { + Ok(payload) => payload, + Err(err) => { + warn!("failed to serialize JSON-RPC payload for relay transport: {err}"); + break; + } + }; + let frame = RelayMessageFrame::data(stream_id.clone(), next_seq, payload); + next_seq = next_seq.wrapping_add(1); + match serialize_relay_message_frame(&frame) { + Ok(encoded) => { + if websocket_writer + .send(Message::Text(encoded.into())) + .await + .is_err() + { + let _ = disconnected_tx.send(true); + break; + } + } + Err(err) => { + warn!("failed to serialize relay data message frame: {err}"); + let _ = disconnected_tx.send(true); + break; + } + } + } + }); + + JsonRpcConnection { + outgoing_tx, + incoming_rx, + disconnected_rx, + task_handles: vec![reader_task, writer_task], + transport: JsonRpcTransport::Plain, + } +} + +pub(crate) async fn run_multiplexed_executor( + stream: WebSocketStream, + processor: ConnectionProcessor, +) where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + let (mut websocket_writer, mut websocket_reader) = stream.split(); + let (physical_outgoing_tx, mut physical_outgoing_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let writer_task = tokio::spawn(async move { + while let Some(encoded) = physical_outgoing_rx.recv().await { + if websocket_writer + .send(Message::Text(encoded.into())) + .await + .is_err() + { + break; + } + } + }); + + let mut streams: HashMap = HashMap::new(); + loop { + let frame = match websocket_reader.next().await { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::(text.as_ref()) { + Ok(frame) => frame, + Err(err) => { + warn!("dropping malformed relay message frame from harness: {err}"); + continue; + } + } + } + Some(Ok(Message::Close(_))) | None => break, + Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue, + Some(Ok(Message::Binary(_))) => { + warn!("dropping non-text relay message frame from harness"); + continue; + } + Some(Err(err)) => { + debug!("multiplexed executor websocket read failed: {err}"); + break; + } + }; + + if let Err(err) = frame.validate() { + warn!("dropping invalid relay message frame: {err}"); + continue; + } + + match frame.kind { + RelayMessageFrameKind::Data => { + let stream_id = frame.stream_id.clone(); + let message = match frame.into_jsonrpc_message() { + Ok(message) => message, + Err(err) => { + warn!("dropping malformed relay data message frame: {err}"); + continue; + } + }; + let stream = streams.entry(stream_id.clone()).or_insert_with(|| { + spawn_virtual_stream( + stream_id.clone(), + processor.clone(), + physical_outgoing_tx.clone(), + ) + }); + if stream + .incoming_tx + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + streams.remove(&stream_id); + } + } + RelayMessageFrameKind::Reset => { + if let Some(stream) = streams.remove(&frame.stream_id) { + stream.disconnect(frame.reason).await; + } + } + RelayMessageFrameKind::Ack + | RelayMessageFrameKind::Resume + | RelayMessageFrameKind::Heartbeat => {} + } + } + + for (_stream_id, stream) in streams { + stream.disconnect(/*reason*/ None).await; + } + drop(physical_outgoing_tx); + let _ = writer_task.await; +} + +struct VirtualStream { + incoming_tx: mpsc::Sender, + disconnected_tx: watch::Sender, +} + +impl VirtualStream { + async fn disconnect(self, reason: Option) { + let _ = self.disconnected_tx.send(true); + let _ = self + .incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { reason }) + .await; + } +} + +fn spawn_virtual_stream( + stream_id: String, + processor: ConnectionProcessor, + physical_outgoing_tx: mpsc::Sender, +) -> VirtualStream { + let (json_outgoing_tx, mut json_outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (disconnected_tx, disconnected_rx) = watch::channel(false); + + let writer_stream_id = stream_id; + let writer_task = tokio::spawn(async move { + let mut next_seq = 0u32; + while let Some(message) = json_outgoing_rx.recv().await { + let payload = match jsonrpc_payload(&message) { + Ok(payload) => payload, + Err(err) => { + warn!("failed to serialize virtual stream JSON-RPC payload: {err}"); + break; + } + }; + let frame = RelayMessageFrame::data(writer_stream_id.clone(), next_seq, payload); + next_seq = next_seq.wrapping_add(1); + match serialize_relay_message_frame(&frame) { + Ok(encoded) => { + if physical_outgoing_tx.send(encoded).await.is_err() { + break; + } + } + Err(err) => { + warn!("failed to serialize virtual stream relay message frame: {err}"); + break; + } + } + } + }); + + let connection = JsonRpcConnection { + outgoing_tx: json_outgoing_tx, + incoming_rx, + disconnected_rx, + task_handles: vec![writer_task], + transport: JsonRpcTransport::Plain, + }; + tokio::spawn(async move { + processor.run_connection(connection).await; + }); + + VirtualStream { + incoming_tx, + disconnected_tx, + } +} diff --git a/codex-rs/exec-server/src/remote.rs b/codex-rs/exec-server/src/remote.rs index bb22105c19..4fbdef77b0 100644 --- a/codex-rs/exec-server/src/remote.rs +++ b/codex-rs/exec-server/src/remote.rs @@ -11,7 +11,7 @@ use codex_utils_rustls_provider::ensure_rustls_crypto_provider; use crate::ExecServerError; use crate::ExecServerRuntimePaths; -use crate::connection::JsonRpcConnection; +use crate::relay::run_multiplexed_executor; use crate::server::ConnectionProcessor; pub const CODEX_EXEC_SERVER_REMOTE_BEARER_TOKEN_ENV_VAR: &str = @@ -150,12 +150,7 @@ pub async fn run_remote_executor( match connect_async(response.url.as_str()).await { Ok((websocket, _)) => { backoff = Duration::from_secs(1); - processor - .run_connection(JsonRpcConnection::from_websocket( - websocket, - "remote exec-server websocket".to_string(), - )) - .await; + run_multiplexed_executor(websocket, processor.clone()).await; } Err(err) => { warn!("failed to connect remote exec-server websocket: {err}");