Compare commits

...

1 Commits

Author SHA1 Message Date
Ruslan Nigmatullin
2350059789 app-server: Add transport for remote control 2026-03-26 21:26:32 -07:00
25 changed files with 3548 additions and 88 deletions

3
codex-rs/Cargo.lock generated
View File

@@ -1423,9 +1423,11 @@ dependencies = [
"codex-utils-cli",
"codex-utils-json-to-toml",
"codex-utils-pty",
"codex-utils-rustls-provider",
"constant_time_eq",
"core_test_support",
"futures",
"gethostname",
"hmac",
"jsonwebtoken",
"opentelemetry",
@@ -1448,6 +1450,7 @@ dependencies = [
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"url",
"uuid",
"wiremock",
]

View File

@@ -52,10 +52,12 @@ codex-state = { workspace = true }
codex-tools = { workspace = true }
codex-utils-absolute-path = { workspace = true }
codex-utils-json-to-toml = { workspace = true }
codex-utils-rustls-provider = { workspace = true }
chrono = { workspace = true }
clap = { workspace = true, features = ["derive"] }
constant_time_eq = { workspace = true }
futures = { workspace = true }
gethostname = { workspace = true }
hmac = { workspace = true }
jsonwebtoken = { workspace = true }
owo-colors = { workspace = true, features = ["supports-colors"] }
@@ -76,6 +78,7 @@ tokio-util = { workspace = true }
tokio-tungstenite = { workspace = true }
tracing = { workspace = true, features = ["log"] }
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] }
url = { workspace = true }
uuid = { workspace = true, features = ["serde", "v7"] }
[dev-dependencies]

View File

@@ -25,6 +25,7 @@ Supported transports:
- stdio (`--listen stdio://`, default): newline-delimited JSON (JSONL)
- websocket (`--listen ws://IP:PORT`): one JSON-RPC message per websocket text frame (**experimental / unsupported**)
- off (`--listen off`): do not expose a local transport
When running with `--listen ws://IP:PORT`, the same listener also serves basic HTTP health probes:

View File

@@ -86,6 +86,7 @@ fn transport_name(transport: AppServerTransport) -> &'static str {
match transport {
AppServerTransport::Stdio => "stdio",
AppServerTransport::WebSocket { .. } => "websocket",
AppServerTransport::Off => "off",
}
}

View File

@@ -74,6 +74,7 @@ use codex_app_server_protocol::Result;
use codex_app_server_protocol::ServerNotification;
use codex_app_server_protocol::ServerRequest;
use codex_arg0::Arg0DispatchPaths;
use codex_core::AuthManager;
use codex_core::config::Config;
use codex_core::config_loader::CloudRequirementsLoader;
use codex_core::config_loader::LoaderOverrides;
@@ -377,6 +378,14 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
}
});
let auth_manager = AuthManager::shared(
args.config.codex_home.clone(),
args.enable_codex_api_key_env,
args.config.cli_auth_credentials_store_mode,
);
auth_manager
.set_forced_chatgpt_workspace_id(args.config.forced_chatgpt_workspace_id.clone());
let processor_outgoing = Arc::clone(&outgoing_message_sender);
let (processor_tx, mut processor_rx) = mpsc::channel::<ProcessorCommand>(channel_capacity);
let mut processor_handle = tokio::spawn(async move {
@@ -392,7 +401,7 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
log_db: None,
config_warnings: args.config_warnings,
session_source: args.session_source,
enable_codex_api_key_env: args.enable_codex_api_key_env,
auth_manager,
});
let mut thread_created_rx = processor.thread_created_receiver();
let mut session = ConnectionSessionState::default();

View File

@@ -8,6 +8,7 @@ use codex_core::config::ConfigBuilder;
use codex_core::config_loader::CloudRequirementsLoader;
use codex_core::config_loader::ConfigLayerStackOrdering;
use codex_core::config_loader::LoaderOverrides;
use codex_features::Feature;
use codex_utils_cli::CliConfigOverrides;
use std::collections::HashMap;
use std::collections::HashSet;
@@ -29,6 +30,7 @@ use crate::transport::OutboundConnectionState;
use crate::transport::TransportEvent;
use crate::transport::auth::policy_from_settings;
use crate::transport::route_outgoing_envelope;
use crate::transport::start_remote_control;
use crate::transport::start_stdio_connection;
use crate::transport::start_websocket_acceptor;
use codex_app_server_protocol::ConfigLayerSource;
@@ -499,13 +501,13 @@ pub async fn run_main_with_transport(
let feedback_layer = feedback.logger_layer();
let feedback_metadata_layer = feedback.metadata_layer();
let log_db = codex_state::StateRuntime::init(
let state_db = codex_state::StateRuntime::init(
config.sqlite_home.clone(),
config.model_provider_id.clone(),
)
.await
.ok()
.map(log_db::start);
.ok();
let log_db = state_db.clone().map(log_db::start);
let log_db_layer = log_db
.clone()
.map(|layer| layer.with_filter(Targets::new().with_default(Level::TRACE)));
@@ -548,6 +550,32 @@ pub async fn run_main_with_transport(
.await?;
transport_accept_handles.push(accept_handle);
}
AppServerTransport::Off => {}
}
let auth_manager = AuthManager::shared(
config.codex_home.clone(),
/*enable_codex_api_key_env*/ false,
config.cli_auth_credentials_store_mode,
);
auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone());
if config.features.enabled(Feature::RemoteControl) {
let accept_handle = start_remote_control(
config.chatgpt_base_url.clone(),
state_db.clone(),
auth_manager.clone(),
transport_event_tx.clone(),
transport_shutdown_token.clone(),
)
.await?;
transport_accept_handles.push(accept_handle);
}
if transport_accept_handles.is_empty() {
return Err(std::io::Error::new(
ErrorKind::InvalidInput,
"no transport configured; use --listen or enable remote control",
));
}
let outbound_handle = tokio::spawn(async move {
@@ -622,7 +650,7 @@ pub async fn run_main_with_transport(
log_db,
config_warnings,
session_source,
enable_codex_api_key_env: false,
auth_manager,
});
let mut thread_created_rx = processor.thread_created_receiver();
let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count();

View File

@@ -16,7 +16,7 @@ const MANAGED_CONFIG_PATH_ENV_VAR: &str = "CODEX_APP_SERVER_MANAGED_CONFIG_PATH"
#[derive(Debug, Parser)]
struct AppServerArgs {
/// Transport endpoint URL. Supported values: `stdio://` (default),
/// `ws://IP:PORT`.
/// `ws://IP:PORT`, `off`.
#[arg(
long = "listen",
value_name = "URL",

View File

@@ -187,7 +187,7 @@ pub(crate) struct MessageProcessorArgs {
pub(crate) log_db: Option<LogDbLayer>,
pub(crate) config_warnings: Vec<ConfigWarningNotification>,
pub(crate) session_source: SessionSource,
pub(crate) enable_codex_api_key_env: bool,
pub(crate) auth_manager: Arc<AuthManager>,
}
impl MessageProcessor {
@@ -206,13 +206,8 @@ impl MessageProcessor {
log_db,
config_warnings,
session_source,
enable_codex_api_key_env,
auth_manager,
} = args;
let auth_manager = AuthManager::shared(
config.codex_home.clone(),
enable_codex_api_key_env,
config.cli_auth_credentials_store_mode,
);
let thread_manager = Arc::new(ThreadManager::new(
config.as_ref(),
auth_manager.clone(),
@@ -224,7 +219,6 @@ impl MessageProcessor {
},
environment_manager,
));
auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone());
auth_manager.set_external_auth_refresher(Arc::new(ExternalAuthRefreshBridge {
outgoing: outgoing.clone(),
}));

View File

@@ -20,6 +20,7 @@ use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStartResponse;
use codex_app_server_protocol::UserInput;
use codex_arg0::Arg0DispatchPaths;
use codex_core::AuthManager;
use codex_core::config::Config;
use codex_core::config::ConfigBuilder;
use codex_core::config_loader::CloudRequirementsLoader;
@@ -231,6 +232,13 @@ fn build_test_processor(
MessageProcessor,
mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
) {
let auth_manager = AuthManager::shared(
config.codex_home.clone(),
/*enable_codex_api_key_env*/ false,
config.cli_auth_credentials_store_mode,
);
auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone());
let (outgoing_tx, outgoing_rx) = mpsc::channel(16);
let outgoing = Arc::new(OutgoingMessageSender::new(outgoing_tx));
let processor = MessageProcessor::new(MessageProcessorArgs {
@@ -245,7 +253,7 @@ fn build_test_processor(
log_db: None,
config_warnings: Vec::new(),
session_source: SessionSource::VSCode,
enable_codex_api_key_env: false,
auth_manager,
});
(processor, outgoing_rx)
}

View File

@@ -17,6 +17,7 @@ use std::str::FromStr;
use std::sync::Arc;
use std::sync::RwLock;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
@@ -28,9 +29,11 @@ use tracing::warn;
/// plenty for an interactive CLI.
pub(crate) const CHANNEL_CAPACITY: usize = 128;
mod remote_control;
mod stdio;
mod websocket;
pub(crate) use remote_control::start_remote_control;
pub(crate) use stdio::start_stdio_connection;
pub(crate) use websocket::start_websocket_acceptor;
@@ -38,6 +41,7 @@ pub(crate) use websocket::start_websocket_acceptor;
pub enum AppServerTransport {
Stdio,
WebSocket { bind_address: SocketAddr },
Off,
}
#[derive(Debug, Clone, Eq, PartialEq)]
@@ -51,7 +55,7 @@ impl std::fmt::Display for AppServerTransportParseError {
match self {
AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!(
f,
"unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`"
"unsupported --listen URL `{listen_url}`; expected `stdio://`, `ws://IP:PORT`, or `off`"
),
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!(
f,
@@ -71,6 +75,10 @@ impl AppServerTransport {
return Ok(Self::Stdio);
}
if listen_url == "off" {
return Ok(Self::Off);
}
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
@@ -166,6 +174,12 @@ impl OutboundConnectionState {
}
}
static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
fn next_connection_id() -> ConnectionId {
ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed))
}
async fn forward_incoming_message(
transport_event_tx: &mpsc::Sender<TransportEvent>,
writer: &mpsc::Sender<QueuedOutgoingMessage>,
@@ -378,8 +392,11 @@ pub(crate) async fn route_outgoing_envelope(
#[cfg(test)]
mod tests {
use super::*;
use crate::error_code::OVERLOADED_ERROR_CODE;
use codex_app_server_protocol::ConfigWarningNotification;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ServerNotification;
use codex_utils_absolute_path::AbsolutePathBuf;
use pretty_assertions::assert_eq;
@@ -393,41 +410,10 @@ mod tests {
}
#[test]
fn app_server_transport_parses_stdio_listen_url() {
let transport = AppServerTransport::from_listen_url(AppServerTransport::DEFAULT_LISTEN_URL)
.expect("stdio listen URL should parse");
assert_eq!(transport, AppServerTransport::Stdio);
}
#[test]
fn app_server_transport_parses_websocket_listen_url() {
let transport = AppServerTransport::from_listen_url("ws://127.0.0.1:1234")
.expect("websocket listen URL should parse");
fn listen_off_parses_as_off_transport() {
assert_eq!(
transport,
AppServerTransport::WebSocket {
bind_address: "127.0.0.1:1234".parse().expect("valid socket address"),
}
);
}
#[test]
fn app_server_transport_rejects_invalid_websocket_listen_url() {
let err = AppServerTransport::from_listen_url("ws://localhost:1234")
.expect_err("hostname bind address should be rejected");
assert_eq!(
err.to_string(),
"invalid websocket --listen URL `ws://localhost:1234`; expected `ws://IP:PORT`"
);
}
#[test]
fn app_server_transport_rejects_unsupported_listen_url() {
let err = AppServerTransport::from_listen_url("http://127.0.0.1:1234")
.expect_err("unsupported scheme should fail");
assert_eq!(
err.to_string(),
"unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`"
AppServerTransport::from_listen_url("off"),
Ok(AppServerTransport::Off)
);
}
@@ -437,11 +423,10 @@ mod tests {
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
let (writer_tx, mut writer_rx) = mpsc::channel(1);
let first_message =
JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
transport_event_tx
.send(TransportEvent::IncomingMessage {
connection_id,
@@ -450,8 +435,8 @@ mod tests {
.await
.expect("queue should accept first message");
let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest {
id: codex_app_server_protocol::RequestId::Integer(7),
let request = JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(7),
method: "config/read".to_string(),
params: Some(json!({ "includeLayers": false })),
trace: None,
@@ -499,11 +484,10 @@ mod tests {
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
let (writer_tx, _writer_rx) = mpsc::channel(1);
let first_message =
JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
transport_event_tx
.send(TransportEvent::IncomingMessage {
connection_id,
@@ -512,8 +496,8 @@ mod tests {
.await
.expect("queue should accept first message");
let response = JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse {
id: codex_app_server_protocol::RequestId::Integer(7),
let response = JSONRPCMessage::Response(JSONRPCResponse {
id: RequestId::Integer(7),
result: json!({"ok": true}),
});
let transport_event_tx_for_enqueue = transport_event_tx.clone();
@@ -553,11 +537,10 @@ mod tests {
match forwarded_event {
TransportEvent::IncomingMessage {
connection_id: queued_connection_id,
message:
JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { id, result }),
message: JSONRPCMessage::Response(JSONRPCResponse { id, result }),
} => {
assert_eq!(queued_connection_id, connection_id);
assert_eq!(id, codex_app_server_protocol::RequestId::Integer(7));
assert_eq!(id, RequestId::Integer(7));
assert_eq!(result, json!({"ok": true}));
}
_ => panic!("expected forwarded response message"),
@@ -573,12 +556,10 @@ mod tests {
transport_event_tx
.send(TransportEvent::IncomingMessage {
connection_id,
message: JSONRPCMessage::Notification(
codex_app_server_protocol::JSONRPCNotification {
method: "initialized".to_string(),
params: None,
},
),
message: JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
}),
})
.await
.expect("transport queue should accept first message");
@@ -597,15 +578,15 @@ mod tests {
.await
.expect("writer queue should accept first message");
let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest {
id: codex_app_server_protocol::RequestId::Integer(7),
let request = JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(7),
method: "config/read".to_string(),
params: Some(json!({ "includeLayers": false })),
trace: None,
});
let enqueue_result = tokio::time::timeout(
std::time::Duration::from_millis(100),
let enqueue_result = timeout(
Duration::from_millis(100),
enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request),
)
.await
@@ -781,7 +762,7 @@ mod tests {
OutgoingEnvelope::ToConnection {
connection_id,
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
request_id: codex_app_server_protocol::RequestId::Integer(1),
request_id: RequestId::Integer(1),
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
thread_id: "thr_123".to_string(),
turn_id: "turn_123".to_string(),
@@ -843,7 +824,7 @@ mod tests {
OutgoingEnvelope::ToConnection {
connection_id,
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
request_id: codex_app_server_protocol::RequestId::Integer(1),
request_id: RequestId::Integer(1),
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
thread_id: "thr_123".to_string(),
turn_id: "turn_123".to_string(),

View File

@@ -0,0 +1,422 @@
use super::CHANNEL_CAPACITY;
use super::TransportEvent;
use super::next_connection_id;
use super::protocol::ClientEnvelope;
pub use super::protocol::ClientEvent;
pub use super::protocol::ClientId;
use super::protocol::PongStatus;
use super::protocol::ServerEvent;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::QueuedOutgoingMessage;
use crate::transport::remote_control::QueuedServerEnvelope;
use codex_app_server_protocol::JSONRPCMessage;
use std::collections::HashMap;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio::task::JoinSet;
use tokio::time::Duration;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
const REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60);
pub(crate) const REMOTE_CONTROL_IDLE_SWEEP_INTERVAL: Duration = Duration::from_secs(30);
#[derive(Debug)]
pub(crate) struct Stopped;
struct ClientState {
connection_id: ConnectionId,
disconnect_token: CancellationToken,
last_activity_at: Instant,
last_inbound_seq_id: Option<u64>,
status_tx: watch::Sender<PongStatus>,
}
pub(crate) struct ClientTracker {
clients: HashMap<ClientId, ClientState>,
join_set: JoinSet<()>,
server_event_tx: mpsc::Sender<QueuedServerEnvelope>,
transport_event_tx: mpsc::Sender<TransportEvent>,
shutdown_token: CancellationToken,
}
impl ClientTracker {
pub(crate) fn new(
server_event_tx: mpsc::Sender<QueuedServerEnvelope>,
transport_event_tx: mpsc::Sender<TransportEvent>,
shutdown_token: &CancellationToken,
) -> Self {
Self {
clients: HashMap::new(),
join_set: JoinSet::new(),
server_event_tx,
transport_event_tx,
shutdown_token: shutdown_token.child_token(),
}
}
pub(crate) async fn bookkeep_join_set(&mut self) {
while self.join_set.join_next().await.is_some() {}
futures::future::pending().await
}
pub(crate) async fn shutdown(&mut self) {
self.shutdown_token.cancel();
while let Some(client_id) = self.clients.keys().next().cloned() {
let _ = self.close_client(&client_id).await;
}
self.drain_join_set().await;
}
async fn drain_join_set(&mut self) {
while self.join_set.join_next().await.is_some() {}
}
pub(crate) async fn handle_message(
&mut self,
client_envelope: ClientEnvelope,
) -> Result<(), Stopped> {
let ClientEnvelope {
client_id,
event,
seq_id,
cursor: _,
} = client_envelope;
match event {
ClientEvent::ClientMessage { message } => {
let is_initialize = remote_control_message_starts_connection(&message);
if let Some(seq_id) = seq_id
&& let Some(client) = self.clients.get(&client_id)
&& client
.last_inbound_seq_id
.is_some_and(|last_seq_id| last_seq_id >= seq_id)
&& !is_initialize
{
return Ok(());
}
if is_initialize && self.clients.contains_key(&client_id) {
self.close_client(&client_id).await?;
}
if let Some(connection_id) = self.clients.get_mut(&client_id).map(|client| {
client.last_activity_at = Instant::now();
if let Some(seq_id) = seq_id {
client.last_inbound_seq_id = Some(seq_id);
}
client.connection_id
}) {
self.transport_event_tx
.send(TransportEvent::IncomingMessage {
connection_id,
message,
})
.await
.map_err(|_| Stopped)?;
return Ok(());
}
if !is_initialize {
return Ok(());
}
let connection_id = next_connection_id();
let (writer_tx, writer_rx) =
mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
let disconnect_token = self.shutdown_token.child_token();
self.transport_event_tx
.send(TransportEvent::ConnectionOpened {
connection_id,
writer: writer_tx,
disconnect_sender: Some(disconnect_token.clone()),
})
.await
.map_err(|_| Stopped)?;
let (status_tx, status_rx) = watch::channel(PongStatus::Active);
self.join_set.spawn(Self::run_client_outbound(
client_id.clone(),
self.server_event_tx.clone(),
writer_rx,
status_rx,
disconnect_token.clone(),
));
self.clients.insert(
client_id,
ClientState {
connection_id,
disconnect_token,
last_activity_at: Instant::now(),
last_inbound_seq_id: seq_id,
status_tx,
},
);
self.send_transport_event(TransportEvent::IncomingMessage {
connection_id,
message,
})
.await
}
ClientEvent::Ack => Ok(()),
ClientEvent::Ping => {
if let Some(client) = self.clients.get_mut(&client_id) {
client.last_activity_at = Instant::now();
let _ = client.status_tx.send(PongStatus::Active);
return Ok(());
}
let server_event_tx = self.server_event_tx.clone();
self.join_set.spawn(async move {
let server_envelope = QueuedServerEnvelope {
event: ServerEvent::Pong {
status: PongStatus::Unknown,
},
client_id,
write_complete_tx: None,
};
let _ = server_event_tx.send(server_envelope).await;
});
Ok(())
}
ClientEvent::ClientClosed => self.close_client(&client_id).await,
}
}
async fn run_client_outbound(
client_id: ClientId,
server_event_tx: mpsc::Sender<QueuedServerEnvelope>,
mut writer_rx: mpsc::Receiver<QueuedOutgoingMessage>,
mut status_rx: watch::Receiver<PongStatus>,
disconnect_token: CancellationToken,
) {
loop {
let (event, write_complete_tx) = tokio::select! {
_ = disconnect_token.cancelled() => {
break;
}
queued_message = writer_rx.recv() => {
let Some(queued_message) = queued_message else {
break;
};
let event = ServerEvent::ServerMessage {
message: Box::new(queued_message.message),
};
(event, queued_message.write_complete_tx)
}
changed = status_rx.changed() => {
if changed.is_err() {
break;
}
let event = ServerEvent::Pong { status: status_rx.borrow().clone() };
(event, None)
}
};
let send_result = tokio::select! {
_ = disconnect_token.cancelled() => {
break;
}
send_result = server_event_tx.send(QueuedServerEnvelope {
event,
client_id: client_id.clone(),
write_complete_tx,
}) => send_result,
};
if send_result.is_err() {
break;
}
}
}
pub(crate) async fn close_expired_clients(&mut self) -> Result<Vec<ClientId>, Stopped> {
let now = Instant::now();
let expired_client_ids: Vec<ClientId> = self
.clients
.iter()
.filter_map(|(client_id, client)| {
(!remote_control_client_is_alive(client, now)).then_some(client_id.clone())
})
.collect();
for client_id in &expired_client_ids {
self.close_client(client_id).await?;
}
Ok(expired_client_ids)
}
async fn close_client(&mut self, client_id: &ClientId) -> Result<(), Stopped> {
let Some(client) = self.clients.remove(client_id) else {
return Ok(());
};
client.disconnect_token.cancel();
self.send_transport_event(TransportEvent::ConnectionClosed {
connection_id: client.connection_id,
})
.await
}
async fn send_transport_event(&self, event: TransportEvent) -> Result<(), Stopped> {
self.transport_event_tx
.send(event)
.await
.map_err(|_| Stopped)
}
}
fn remote_control_message_starts_connection(message: &JSONRPCMessage) -> bool {
matches!(
message,
JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { method, .. })
if method == "initialize"
)
}
fn remote_control_client_is_alive(client: &ClientState, now: Instant) -> bool {
now.duration_since(client.last_activity_at) < REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT
}
#[cfg(test)]
mod tests {
use super::*;
use crate::outgoing_message::OutgoingMessage;
use crate::transport::remote_control::protocol::ClientEnvelope;
use crate::transport::remote_control::protocol::ClientEvent;
use codex_app_server_protocol::ConfigWarningNotification;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ServerNotification;
use pretty_assertions::assert_eq;
use serde_json::json;
use tokio::time::timeout;
fn initialize_envelope(client_id: &str) -> ClientEnvelope {
ClientEnvelope {
event: ClientEvent::ClientMessage {
message: JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(1),
method: "initialize".to_string(),
params: Some(json!({
"clientInfo": {
"name": "remote-test-client",
"version": "0.1.0"
}
})),
trace: None,
}),
},
client_id: ClientId(client_id.to_string()),
seq_id: Some(0),
cursor: None,
}
}
#[tokio::test]
async fn cancelled_outbound_task_emits_connection_closed() {
let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
let shutdown_token = CancellationToken::new();
let mut client_tracker =
ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token);
client_tracker
.handle_message(initialize_envelope("client-1"))
.await
.expect("initialize should open client");
let (connection_id, disconnect_sender) = match transport_event_rx
.recv()
.await
.expect("connection opened should be sent")
{
TransportEvent::ConnectionOpened {
connection_id,
disconnect_sender: Some(disconnect_sender),
..
} => (connection_id, disconnect_sender),
other => panic!("expected connection opened, got {other:?}"),
};
match transport_event_rx
.recv()
.await
.expect("initialize should be forwarded")
{
TransportEvent::IncomingMessage {
connection_id: incoming_connection_id,
..
} => assert_eq!(incoming_connection_id, connection_id),
other => panic!("expected incoming initialize, got {other:?}"),
}
disconnect_sender.cancel();
timeout(Duration::from_secs(1), client_tracker.bookkeep_join_set())
.await
.expect_err("bookkeeping should process the closed task and stay pending");
match transport_event_rx
.recv()
.await
.expect("connection closed should be sent")
{
TransportEvent::ConnectionClosed {
connection_id: closed_connection_id,
} => assert_eq!(closed_connection_id, connection_id),
other => panic!("expected connection closed, got {other:?}"),
}
}
#[tokio::test]
async fn shutdown_cancels_blocked_outbound_forwarding() {
let (server_event_tx, _server_event_rx) = mpsc::channel(1);
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
let shutdown_token = CancellationToken::new();
let mut client_tracker =
ClientTracker::new(server_event_tx.clone(), transport_event_tx, &shutdown_token);
server_event_tx
.send(QueuedServerEnvelope {
event: ServerEvent::Pong {
status: PongStatus::Unknown,
},
client_id: ClientId("queued-client".to_string()),
write_complete_tx: None,
})
.await
.expect("server event queue should accept prefill");
client_tracker
.handle_message(initialize_envelope("client-1"))
.await
.expect("initialize should open client");
let writer = match transport_event_rx
.recv()
.await
.expect("connection opened should be sent")
{
TransportEvent::ConnectionOpened { writer, .. } => writer,
other => panic!("expected connection opened, got {other:?}"),
};
let _ = transport_event_rx
.recv()
.await
.expect("initialize should be forwarded");
writer
.send(QueuedOutgoingMessage::new(
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
ConfigWarningNotification {
summary: "test".to_string(),
details: None,
path: None,
range: None,
},
)),
))
.await
.expect("writer should accept queued message");
timeout(Duration::from_secs(1), client_tracker.shutdown())
.await
.expect("shutdown should not hang on blocked server forwarding");
}
}

View File

@@ -0,0 +1,399 @@
use super::protocol::EnrollRemoteServerRequest;
use super::protocol::EnrollRemoteServerResponse;
use super::protocol::RemoteControlTarget;
use axum::http::HeaderMap;
use codex_core::default_client::build_reqwest_client;
use codex_state::StateRuntime;
use gethostname::gethostname;
use std::io;
use std::io::ErrorKind;
use tracing::warn;
const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
const REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES: usize = 4096;
const REQUEST_ID_HEADER: &str = "x-request-id";
const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id";
const CF_RAY_HEADER: &str = "cf-ray";
pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct RemoteControlEnrollment {
pub(super) account_id: Option<String>,
pub(super) server_id: String,
pub(super) server_name: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct RemoteControlConnectionAuth {
pub(super) bearer_token: String,
pub(super) account_id: Option<String>,
}
pub(super) async fn load_persisted_remote_control_enrollment(
state_db: Option<&StateRuntime>,
remote_control_target: &RemoteControlTarget,
account_id: Option<&str>,
) -> Option<RemoteControlEnrollment> {
let state_db = state_db?;
let enrollment = match state_db
.get_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
.await
{
Ok(enrollment) => enrollment,
Err(err) => {
warn!("{err}");
return None;
}
};
enrollment.map(|(server_id, server_name)| RemoteControlEnrollment {
account_id: account_id.map(&str::to_string),
server_id,
server_name,
})
}
pub(super) async fn update_persisted_remote_control_enrollment(
state_db: Option<&StateRuntime>,
remote_control_target: &RemoteControlTarget,
account_id: Option<&str>,
enrollment: Option<&RemoteControlEnrollment>,
) -> io::Result<()> {
let Some(state_db) = state_db else {
return Ok(());
};
if let &Some(enrollment) = &enrollment
&& enrollment.account_id.as_deref() != account_id
{
return Err(io::Error::other(format!(
"enrollment account_id does not match expected account_id `{account_id:?}`"
)));
}
if let Some(enrollment) = enrollment {
state_db
.upsert_remote_control_enrollment(
&remote_control_target.websocket_url,
account_id,
&enrollment.server_id,
&enrollment.server_name,
)
.await
.map_err(io::Error::other)
} else {
state_db
.delete_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
.await
.map(|_| ())
.map_err(io::Error::other)
}
}
pub(crate) fn preview_remote_control_response_body(body: &[u8]) -> String {
let body = String::from_utf8_lossy(body);
let trimmed = body.trim();
if trimmed.is_empty() {
return "<empty>".to_string();
}
if trimmed.len() <= REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES {
return trimmed.to_string();
}
let mut cut = REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES;
while !trimmed.is_char_boundary(cut) {
cut = cut.saturating_sub(1);
}
let mut truncated = trimmed[..cut].to_string();
truncated.push_str("...");
truncated
}
pub(crate) fn format_headers(headers: &HeaderMap) -> String {
let request_id_str = headers
.get(REQUEST_ID_HEADER)
.or_else(|| headers.get(OAI_REQUEST_ID_HEADER))
.map(|value| value.to_str().unwrap_or("<invalid utf-8>").to_owned())
.unwrap_or_else(|| "<none>".to_owned());
let cf_ray_str = headers
.get(CF_RAY_HEADER)
.map(|value| value.to_str().unwrap_or("<invalid utf-8>").to_owned())
.unwrap_or_else(|| "<none>".to_owned());
format!("request-id: {request_id_str}, cf-ray: {cf_ray_str}")
}
pub(super) async fn enroll_remote_control_server(
remote_control_target: &RemoteControlTarget,
auth: &RemoteControlConnectionAuth,
) -> io::Result<RemoteControlEnrollment> {
let enroll_url = &remote_control_target.enroll_url;
let server_name = gethostname().to_string_lossy().trim().to_string();
let request = EnrollRemoteServerRequest {
name: server_name.clone(),
os: std::env::consts::OS,
arch: std::env::consts::ARCH,
app_server_version: env!("CARGO_PKG_VERSION"),
};
let client = build_reqwest_client();
let mut http_request = client
.post(enroll_url)
.timeout(REMOTE_CONTROL_ENROLL_TIMEOUT)
.bearer_auth(&auth.bearer_token)
.json(&request);
let account_id = auth.account_id.as_deref();
if let Some(account_id) = account_id {
http_request = http_request.header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id);
}
let response = http_request.send().await.map_err(|err| {
io::Error::other(format!(
"failed to enroll remote control server at `{enroll_url}`: {err}"
))
})?;
let headers = response.headers().clone();
let status = response.status();
let body = response.bytes().await.map_err(|err| {
io::Error::other(format!(
"failed to read remote control enrollment response from `{enroll_url}`: {err}"
))
})?;
let body_preview = preview_remote_control_response_body(&body);
if !status.is_success() {
let headers_str = format_headers(&headers);
let error_kind = if matches!(status.as_u16(), 401 | 403) {
ErrorKind::PermissionDenied
} else {
ErrorKind::Other
};
return Err(io::Error::new(
error_kind,
format!(
"remote control server enrollment failed at `{enroll_url}`: HTTP {status}, {headers_str}, body: {body_preview}"
),
));
}
let enrollment = serde_json::from_slice::<EnrollRemoteServerResponse>(&body).map_err(|err| {
let headers_str = format_headers(&headers);
io::Error::other(format!(
"failed to parse remote control enrollment response from `{enroll_url}`: HTTP {status}, {headers_str}, body: {body_preview}, decode error: {err}"
))
})?;
Ok(RemoteControlEnrollment {
account_id: account_id.map(&str::to_string),
server_id: enrollment.server_id,
server_name,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::remote_control::protocol::normalize_remote_control_url;
use codex_state::StateRuntime;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::sync::Arc;
use tempfile::TempDir;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::time::Duration;
use tokio::time::timeout;
async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc<StateRuntime> {
StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string())
.await
.expect("state runtime should initialize")
}
#[tokio::test]
async fn persisted_remote_control_enrollment_round_trips_by_target_and_account() {
let codex_home = TempDir::new().expect("temp dir should create");
let state_db = remote_control_state_runtime(&codex_home).await;
let first_target = normalize_remote_control_url("http://example.com/remote/control")
.expect("first target should parse");
let second_target = normalize_remote_control_url("http://example.com/other/control")
.expect("second target should parse");
let first_enrollment = RemoteControlEnrollment {
account_id: Some("account-a".to_string()),
server_id: "srv_e_first".to_string(),
server_name: "first-server".to_string(),
};
let second_enrollment = RemoteControlEnrollment {
account_id: Some("account-a".to_string()),
server_id: "srv_e_second".to_string(),
server_name: "second-server".to_string(),
};
update_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
&first_target,
Some("account-a"),
Some(&first_enrollment),
)
.await
.expect("first enrollment should persist");
update_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
&second_target,
Some("account-a"),
Some(&second_enrollment),
)
.await
.expect("second enrollment should persist");
assert_eq!(
load_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
&first_target,
Some("account-a"),
)
.await,
Some(first_enrollment.clone())
);
assert_eq!(
load_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
&first_target,
Some("account-b"),
)
.await,
None
);
assert_eq!(
load_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
&second_target,
Some("account-a"),
)
.await,
Some(second_enrollment)
);
}
#[tokio::test]
async fn clearing_persisted_remote_control_enrollment_removes_only_matching_entry() {
let codex_home = TempDir::new().expect("temp dir should create");
let state_db = remote_control_state_runtime(&codex_home).await;
let first_target = normalize_remote_control_url("http://example.com/remote/control")
.expect("first target should parse");
let second_target = normalize_remote_control_url("http://example.com/other/control")
.expect("second target should parse");
let first_enrollment = RemoteControlEnrollment {
account_id: Some("account-a".to_string()),
server_id: "srv_e_first".to_string(),
server_name: "first-server".to_string(),
};
let second_enrollment = RemoteControlEnrollment {
account_id: Some("account-a".to_string()),
server_id: "srv_e_second".to_string(),
server_name: "second-server".to_string(),
};
update_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
&first_target,
Some("account-a"),
Some(&first_enrollment),
)
.await
.expect("first enrollment should persist");
update_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
&second_target,
Some("account-a"),
Some(&second_enrollment),
)
.await
.expect("second enrollment should persist");
update_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
&first_target,
Some("account-a"),
None,
)
.await
.expect("matching enrollment should clear");
assert_eq!(
load_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
&first_target,
Some("account-a"),
)
.await,
None
);
assert_eq!(
load_persisted_remote_control_enrollment(
Some(state_db.as_ref()),
&second_target,
Some("account-a"),
)
.await,
Some(second_enrollment)
);
}
#[tokio::test]
async fn enroll_remote_control_server_parse_failure_includes_response_body() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = format!(
"http://{}/backend-api/",
listener
.local_addr()
.expect("listener should have a local addr")
);
let remote_control_target =
normalize_remote_control_url(&remote_control_url).expect("target should parse");
let enroll_url = remote_control_target.enroll_url.clone();
let response_body = json!({
"error": "not enrolled",
});
let expected_body = response_body.to_string();
let server_task = tokio::spawn(async move {
let (stream, _) = timeout(Duration::from_secs(5), listener.accept())
.await
.expect("HTTP request should arrive in time")
.expect("listener accept should succeed");
respond_with_json(stream, response_body).await;
});
let err = enroll_remote_control_server(
&remote_control_target,
&RemoteControlConnectionAuth {
bearer_token: "Access Token".to_string(),
account_id: Some("account_id".to_string()),
},
)
.await
.expect_err("invalid response should fail to parse");
server_task.await.expect("server task should succeed");
assert_eq!(
err.to_string(),
format!(
"failed to parse remote control enrollment response from `{enroll_url}`: HTTP 200 OK, request-id: <none>, cf-ray: <none>, body: {expected_body}, decode error: missing field `server_id` at line 1 column {}",
expected_body.len()
)
);
}
async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) {
let body = body.to_string();
let response = format!(
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
body.len()
);
stream
.write_all(response.as_bytes())
.await
.expect("response should write");
stream.flush().await.expect("response should flush");
}
}

View File

@@ -0,0 +1,59 @@
mod client_tracker;
mod enroll;
mod protocol;
mod websocket;
use crate::transport::remote_control::websocket::load_remote_control_auth;
pub use self::protocol::ClientId;
use self::protocol::ServerEvent;
use self::protocol::normalize_remote_control_url;
use self::websocket::run_remote_control_websocket_loop;
use super::CHANNEL_CAPACITY;
use super::TransportEvent;
use super::next_connection_id;
use codex_core::AuthManager;
use codex_state::StateRuntime;
use std::io;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
pub(super) struct QueuedServerEnvelope {
pub(super) event: ServerEvent,
pub(super) client_id: ClientId,
pub(super) write_complete_tx: Option<oneshot::Sender<()>>,
}
pub(crate) async fn start_remote_control(
remote_control_url: String,
state_db: Option<Arc<StateRuntime>>,
auth_manager: Arc<AuthManager>,
transport_event_tx: mpsc::Sender<TransportEvent>,
shutdown_token: CancellationToken,
) -> io::Result<JoinHandle<()>> {
let remote_control_target = normalize_remote_control_url(&remote_control_url)?;
validate_remote_control_auth(&auth_manager).await?;
Ok(tokio::spawn(async move {
run_remote_control_websocket_loop(
remote_control_target,
state_db,
auth_manager,
transport_event_tx,
shutdown_token.child_token(),
)
.await;
}))
}
pub(crate) async fn validate_remote_control_auth(
auth_manager: &Arc<AuthManager>,
) -> io::Result<()> {
load_remote_control_auth(auth_manager).await.map(|_| ())
}
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,139 @@
use crate::outgoing_message::OutgoingMessage;
use codex_app_server_protocol::JSONRPCMessage;
use serde::Deserialize;
use serde::Serialize;
use std::io;
use std::io::ErrorKind;
use url::Url;
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct RemoteControlTarget {
pub(super) websocket_url: String,
pub(super) enroll_url: String,
}
#[derive(Debug, Serialize)]
pub(super) struct EnrollRemoteServerRequest {
pub(super) name: String,
pub(super) os: &'static str,
pub(super) arch: &'static str,
pub(super) app_server_version: &'static str,
}
#[derive(Debug, Deserialize)]
pub(super) struct EnrollRemoteServerResponse {
pub(super) server_id: String,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ClientId(pub String);
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientEvent {
ClientMessage { message: JSONRPCMessage },
Ack,
Ping,
ClientClosed,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub(crate) struct ClientEnvelope {
#[serde(flatten)]
pub(crate) event: ClientEvent,
#[serde(rename = "client_id", alias = "clientId")]
pub(crate) client_id: ClientId,
#[serde(
rename = "seq_id",
alias = "seqId",
skip_serializing_if = "Option::is_none"
)]
pub(crate) seq_id: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) cursor: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PongStatus {
Active,
Unknown,
}
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerEvent {
ServerMessage {
message: Box<OutgoingMessage>,
},
#[allow(dead_code)]
Ack,
Pong {
status: PongStatus,
},
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub(crate) struct ServerEnvelope {
#[serde(flatten)]
pub(crate) event: ServerEvent,
#[serde(rename = "client_id", alias = "clientId")]
pub(crate) client_id: ClientId,
#[serde(rename = "seq_id", alias = "seqId")]
pub(crate) seq_id: u64,
}
pub(super) fn normalize_remote_control_url(
remote_control_url: &str,
) -> io::Result<RemoteControlTarget> {
let map_url_parse_error = |err: url::ParseError| -> io::Error {
io::Error::new(
ErrorKind::InvalidInput,
format!("invalid remote control URL `{remote_control_url}`: {err}"),
)
};
let map_scheme_error = |_: ()| -> io::Error {
io::Error::new(
ErrorKind::InvalidInput,
format!(
"invalid remote control URL `{remote_control_url}`; expected absolute URL with http:// or https:// scheme"
),
)
};
let mut remote_control_url = Url::parse(remote_control_url).map_err(map_url_parse_error)?;
match remote_control_url.scheme() {
"https" | "http" => {}
_ => return Err(map_scheme_error(())),
}
if !remote_control_url.path().ends_with('/') {
let normalized_path = format!("{}/", remote_control_url.path());
remote_control_url.set_path(&normalized_path);
}
let mut enroll_url = remote_control_url
.join("wham/remote/control/server/enroll")
.map_err(map_url_parse_error)?;
let mut websocket_url = remote_control_url
.join("wham/remote/control/server")
.map_err(map_url_parse_error)?;
match remote_control_url.scheme() {
"https" => {
enroll_url.set_scheme("https").map_err(map_scheme_error)?;
websocket_url.set_scheme("wss").map_err(map_scheme_error)?;
}
"http" => {
enroll_url.set_scheme("http").map_err(map_scheme_error)?;
websocket_url.set_scheme("ws").map_err(map_scheme_error)?;
}
_ => return Err(map_scheme_error(())),
}
Ok(RemoteControlTarget {
websocket_url: websocket_url.to_string(),
enroll_url: enroll_url.to_string(),
})
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,8 @@
use super::CHANNEL_CAPACITY;
use super::TransportEvent;
use super::forward_incoming_message;
use super::next_connection_id;
use super::serialize_outgoing_message;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::QueuedOutgoingMessage;
use std::io::ErrorKind;
use std::io::Result as IoResult;
@@ -20,7 +20,7 @@ pub(crate) async fn start_stdio_connection(
transport_event_tx: mpsc::Sender<TransportEvent>,
stdio_handles: &mut Vec<JoinHandle<()>>,
) -> IoResult<()> {
let connection_id = ConnectionId(0);
let connection_id = next_connection_id();
let (writer_tx, mut writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
let writer_tx_for_reader = writer_tx.clone();
transport_event_tx

View File

@@ -4,6 +4,7 @@ use super::auth::WebsocketAuthPolicy;
use super::auth::authorize_upgrade;
use super::auth::should_warn_about_unauthenticated_non_loopback_listener;
use super::forward_incoming_message;
use super::next_connection_id;
use super::serialize_outgoing_message;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::QueuedOutgoingMessage;
@@ -32,8 +33,6 @@ use owo_colors::Style;
use std::io::Result as IoResult;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
@@ -75,7 +74,6 @@ fn print_websocket_startup_banner(addr: SocketAddr) {
#[derive(Clone)]
struct WebSocketListenerState {
transport_event_tx: mpsc::Sender<TransportEvent>,
connection_counter: Arc<AtomicU64>,
auth_policy: Arc<WebsocketAuthPolicy>,
}
@@ -113,7 +111,7 @@ async fn websocket_upgrade_handler(
);
return (err.status_code(), err.message()).into_response();
}
let connection_id = ConnectionId(state.connection_counter.fetch_add(1, Ordering::Relaxed));
let connection_id = next_connection_id();
info!(%peer_addr, "websocket client connected");
websocket
.on_upgrade(move |stream| async move {
@@ -146,7 +144,6 @@ pub(crate) async fn start_websocket_acceptor(
.layer(middleware::from_fn(reject_requests_with_origin_header))
.with_state(WebSocketListenerState {
transport_event_tx,
connection_counter: Arc::new(AtomicU64::new(1)),
auth_policy: Arc::new(auth_policy),
});
let server = axum::serve(

View File

@@ -328,7 +328,7 @@ struct AppServerCommand {
subcommand: Option<AppServerSubcommand>,
/// Transport endpoint URL. Supported values: `stdio://` (default),
/// `ws://IP:PORT`.
/// `ws://IP:PORT`, `off`.
#[arg(
long = "listen",
value_name = "URL",
@@ -1930,6 +1930,12 @@ mod tests {
);
}
#[test]
fn app_server_listen_off_parses() {
let app_server = app_server_from_args(["codex", "app-server", "--listen", "off"].as_ref());
assert_eq!(app_server.listen, codex_app_server::AppServerTransport::Off);
}
#[test]
fn app_server_listen_invalid_url_fails_to_parse() {
let parse_result =

View File

@@ -434,6 +434,9 @@
"realtime_conversation": {
"type": "boolean"
},
"remote_control": {
"type": "boolean"
},
"remote_models": {
"type": "boolean"
},
@@ -2071,6 +2074,9 @@
"realtime_conversation": {
"type": "boolean"
},
"remote_control": {
"type": "boolean"
},
"remote_models": {
"type": "boolean"
},

View File

@@ -176,6 +176,8 @@ pub enum Feature {
VoiceTranscription,
/// Enable experimental realtime voice conversation mode in the TUI.
RealtimeConversation,
/// Connect app-server to the ChatGPT remote control service.
RemoteControl,
/// Route interactive startup to the app-server-backed TUI implementation.
TuiAppServer,
/// Prevent idle system sleep while a turn is actively running.
@@ -819,6 +821,12 @@ pub const FEATURES: &[FeatureSpec] = &[
stage: Stage::UnderDevelopment,
default_enabled: false,
},
FeatureSpec {
id: Feature::RemoteControl,
key: "remote_control",
stage: Stage::UnderDevelopment,
default_enabled: false,
},
FeatureSpec {
id: Feature::TuiAppServer,
key: "tui_app_server",

View File

@@ -159,6 +159,12 @@ fn image_detail_original_feature_is_under_development() {
assert_eq!(Feature::ImageDetailOriginal.default_enabled(), false);
}
#[test]
fn remote_control_is_under_development() {
assert_eq!(Feature::RemoteControl.stage(), Stage::UnderDevelopment);
assert_eq!(Feature::RemoteControl.default_enabled(), false);
}
#[test]
fn collab_is_legacy_alias_for_multi_agent() {
assert_eq!(feature_for_key("multi_agent"), Some(Feature::Collab));

View File

@@ -0,0 +1,8 @@
CREATE TABLE remote_control_enrollments (
websocket_url TEXT NOT NULL,
account_id TEXT NOT NULL,
server_id TEXT NOT NULL,
server_name TEXT NOT NULL,
updated_at INTEGER NOT NULL,
PRIMARY KEY (websocket_url, account_id)
);

View File

@@ -53,6 +53,7 @@ mod agent_jobs;
mod backfill;
mod logs;
mod memories;
mod remote_control;
#[cfg(test)]
mod test_support;
mod threads;

View File

@@ -0,0 +1,197 @@
use super::*;
const REMOTE_CONTROL_ACCOUNT_ID_NONE: &str = "";
fn remote_control_account_id_key(account_id: Option<&str>) -> &str {
account_id.unwrap_or(REMOTE_CONTROL_ACCOUNT_ID_NONE)
}
impl StateRuntime {
pub async fn get_remote_control_enrollment(
&self,
websocket_url: &str,
account_id: Option<&str>,
) -> anyhow::Result<Option<(String, String)>> {
let row = sqlx::query(
r#"
SELECT server_id, server_name
FROM remote_control_enrollments
WHERE websocket_url = ? AND account_id = ?
"#,
)
.bind(websocket_url)
.bind(remote_control_account_id_key(account_id))
.fetch_optional(self.pool.as_ref())
.await?;
row.map(|row| Ok((row.try_get("server_id")?, row.try_get("server_name")?)))
.transpose()
}
pub async fn upsert_remote_control_enrollment(
&self,
websocket_url: &str,
account_id: Option<&str>,
server_id: &str,
server_name: &str,
) -> anyhow::Result<()> {
sqlx::query(
r#"
INSERT INTO remote_control_enrollments (
websocket_url,
account_id,
server_id,
server_name,
updated_at
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(websocket_url, account_id) DO UPDATE SET
server_id = excluded.server_id,
server_name = excluded.server_name,
updated_at = excluded.updated_at
"#,
)
.bind(websocket_url)
.bind(remote_control_account_id_key(account_id))
.bind(server_id)
.bind(server_name)
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.await?;
Ok(())
}
pub async fn delete_remote_control_enrollment(
&self,
websocket_url: &str,
account_id: Option<&str>,
) -> anyhow::Result<u64> {
let result = sqlx::query(
r#"
DELETE FROM remote_control_enrollments
WHERE websocket_url = ? AND account_id = ?
"#,
)
.bind(websocket_url)
.bind(remote_control_account_id_key(account_id))
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected())
}
}
#[cfg(test)]
mod tests {
use super::StateRuntime;
use super::test_support::unique_temp_dir;
use pretty_assertions::assert_eq;
#[tokio::test]
async fn remote_control_enrollment_round_trips_by_target_and_account() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string())
.await
.expect("initialize runtime");
runtime
.upsert_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
Some("account-a"),
"srv_e_first",
"first-server",
)
.await
.expect("insert first enrollment");
runtime
.upsert_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
Some("account-b"),
"srv_e_second",
"second-server",
)
.await
.expect("insert second enrollment");
assert_eq!(
runtime
.get_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
Some("account-a"),
)
.await
.expect("load first enrollment"),
Some(("srv_e_first".to_string(), "first-server".to_string()))
);
assert_eq!(
runtime
.get_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
None,
)
.await
.expect("load missing enrollment"),
None
);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn delete_remote_control_enrollment_removes_only_matching_entry() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string())
.await
.expect("initialize runtime");
runtime
.upsert_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
None,
"srv_e_first",
"first-server",
)
.await
.expect("insert first enrollment");
runtime
.upsert_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
Some("account-a"),
"srv_e_second",
"second-server",
)
.await
.expect("insert second enrollment");
assert_eq!(
runtime
.delete_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
None,
)
.await
.expect("delete first enrollment"),
1
);
assert_eq!(
runtime
.get_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
None,
)
.await
.expect("load deleted enrollment"),
None
);
assert_eq!(
runtime
.get_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
Some("account-a"),
)
.await
.expect("load retained enrollment"),
Some(("srv_e_second".to_string(), "second-server".to_string()))
);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
}