app-server: Add transport for remote control (#15951)

This commit is contained in:
Ruslan Nigmatullin
2026-04-06 14:55:59 -07:00
committed by GitHub
parent 03c07956cf
commit 73dab2046f
23 changed files with 4557 additions and 81 deletions

View File

@@ -17,6 +17,7 @@ use std::str::FromStr;
use std::sync::Arc;
use std::sync::RwLock;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
@@ -28,9 +29,11 @@ use tracing::warn;
/// plenty for an interactive CLI.
pub(crate) const CHANNEL_CAPACITY: usize = 128;
mod remote_control;
mod stdio;
mod websocket;
pub(crate) use remote_control::start_remote_control;
pub(crate) use stdio::start_stdio_connection;
pub(crate) use websocket::start_websocket_acceptor;
@@ -38,6 +41,7 @@ pub(crate) use websocket::start_websocket_acceptor;
pub enum AppServerTransport {
Stdio,
WebSocket { bind_address: SocketAddr },
Off,
}
#[derive(Debug, Clone, Eq, PartialEq)]
@@ -51,7 +55,7 @@ impl std::fmt::Display for AppServerTransportParseError {
match self {
AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!(
f,
"unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`"
"unsupported --listen URL `{listen_url}`; expected `stdio://`, `ws://IP:PORT`, or `off`"
),
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!(
f,
@@ -71,6 +75,10 @@ impl AppServerTransport {
return Ok(Self::Stdio);
}
if listen_url == "off" {
return Ok(Self::Off);
}
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
@@ -166,6 +174,12 @@ impl OutboundConnectionState {
}
}
static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
fn next_connection_id() -> ConnectionId {
ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed))
}
async fn forward_incoming_message(
transport_event_tx: &mpsc::Sender<TransportEvent>,
writer: &mpsc::Sender<QueuedOutgoingMessage>,
@@ -378,8 +392,11 @@ pub(crate) async fn route_outgoing_envelope(
#[cfg(test)]
mod tests {
use super::*;
use crate::error_code::OVERLOADED_ERROR_CODE;
use codex_app_server_protocol::ConfigWarningNotification;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ServerNotification;
use codex_utils_absolute_path::AbsolutePathBuf;
use pretty_assertions::assert_eq;
@@ -393,41 +410,10 @@ mod tests {
}
#[test]
fn app_server_transport_parses_stdio_listen_url() {
let transport = AppServerTransport::from_listen_url(AppServerTransport::DEFAULT_LISTEN_URL)
.expect("stdio listen URL should parse");
assert_eq!(transport, AppServerTransport::Stdio);
}
#[test]
fn app_server_transport_parses_websocket_listen_url() {
let transport = AppServerTransport::from_listen_url("ws://127.0.0.1:1234")
.expect("websocket listen URL should parse");
fn listen_off_parses_as_off_transport() {
assert_eq!(
transport,
AppServerTransport::WebSocket {
bind_address: "127.0.0.1:1234".parse().expect("valid socket address"),
}
);
}
#[test]
fn app_server_transport_rejects_invalid_websocket_listen_url() {
let err = AppServerTransport::from_listen_url("ws://localhost:1234")
.expect_err("hostname bind address should be rejected");
assert_eq!(
err.to_string(),
"invalid websocket --listen URL `ws://localhost:1234`; expected `ws://IP:PORT`"
);
}
#[test]
fn app_server_transport_rejects_unsupported_listen_url() {
let err = AppServerTransport::from_listen_url("http://127.0.0.1:1234")
.expect_err("unsupported scheme should fail");
assert_eq!(
err.to_string(),
"unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`"
AppServerTransport::from_listen_url("off"),
Ok(AppServerTransport::Off)
);
}
@@ -437,11 +423,10 @@ mod tests {
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
let (writer_tx, mut writer_rx) = mpsc::channel(1);
let first_message =
JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
transport_event_tx
.send(TransportEvent::IncomingMessage {
connection_id,
@@ -450,8 +435,8 @@ mod tests {
.await
.expect("queue should accept first message");
let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest {
id: codex_app_server_protocol::RequestId::Integer(7),
let request = JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(7),
method: "config/read".to_string(),
params: Some(json!({ "includeLayers": false })),
trace: None,
@@ -499,11 +484,10 @@ mod tests {
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
let (writer_tx, _writer_rx) = mpsc::channel(1);
let first_message =
JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
transport_event_tx
.send(TransportEvent::IncomingMessage {
connection_id,
@@ -512,8 +496,8 @@ mod tests {
.await
.expect("queue should accept first message");
let response = JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse {
id: codex_app_server_protocol::RequestId::Integer(7),
let response = JSONRPCMessage::Response(JSONRPCResponse {
id: RequestId::Integer(7),
result: json!({"ok": true}),
});
let transport_event_tx_for_enqueue = transport_event_tx.clone();
@@ -553,11 +537,10 @@ mod tests {
match forwarded_event {
TransportEvent::IncomingMessage {
connection_id: queued_connection_id,
message:
JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { id, result }),
message: JSONRPCMessage::Response(JSONRPCResponse { id, result }),
} => {
assert_eq!(queued_connection_id, connection_id);
assert_eq!(id, codex_app_server_protocol::RequestId::Integer(7));
assert_eq!(id, RequestId::Integer(7));
assert_eq!(result, json!({"ok": true}));
}
_ => panic!("expected forwarded response message"),
@@ -573,12 +556,10 @@ mod tests {
transport_event_tx
.send(TransportEvent::IncomingMessage {
connection_id,
message: JSONRPCMessage::Notification(
codex_app_server_protocol::JSONRPCNotification {
method: "initialized".to_string(),
params: None,
},
),
message: JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
}),
})
.await
.expect("transport queue should accept first message");
@@ -597,15 +578,15 @@ mod tests {
.await
.expect("writer queue should accept first message");
let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest {
id: codex_app_server_protocol::RequestId::Integer(7),
let request = JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(7),
method: "config/read".to_string(),
params: Some(json!({ "includeLayers": false })),
trace: None,
});
let enqueue_result = tokio::time::timeout(
std::time::Duration::from_millis(100),
let enqueue_result = timeout(
Duration::from_millis(100),
enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request),
)
.await
@@ -781,7 +762,7 @@ mod tests {
OutgoingEnvelope::ToConnection {
connection_id,
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
request_id: codex_app_server_protocol::RequestId::Integer(1),
request_id: RequestId::Integer(1),
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
thread_id: "thr_123".to_string(),
turn_id: "turn_123".to_string(),
@@ -843,7 +824,7 @@ mod tests {
OutgoingEnvelope::ToConnection {
connection_id,
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
request_id: codex_app_server_protocol::RequestId::Integer(1),
request_id: RequestId::Integer(1),
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
thread_id: "thr_123".to_string(),
turn_id: "turn_123".to_string(),