Compare commits

...

4 Commits

Author SHA1 Message Date
Ruslan Nigmatullin
82be4e88b0 base url 2026-03-25 17:54:19 -07:00
Ruslan Nigmatullin
685d77d2c1 fixes 2026-03-25 17:07:48 -07:00
Ruslan Nigmatullin
8d62dd3257 sqlite 2026-03-25 15:35:58 -07:00
Ruslan Nigmatullin
8c644a154b app-server: Add transport for remote control 2026-03-25 15:11:52 -07:00
22 changed files with 4307 additions and 1492 deletions

3
codex-rs/Cargo.lock generated
View File

@@ -1473,9 +1473,11 @@ dependencies = [
"codex-utils-cli",
"codex-utils-json-to-toml",
"codex-utils-pty",
"codex-utils-rustls-provider",
"constant_time_eq",
"core_test_support",
"futures",
"gethostname",
"hmac",
"jsonwebtoken",
"opentelemetry",
@@ -1498,6 +1500,7 @@ dependencies = [
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"url",
"uuid",
"wiremock",
]

View File

@@ -51,10 +51,12 @@ codex-sandboxing = { workspace = true }
codex-state = { workspace = true }
codex-utils-absolute-path = { workspace = true }
codex-utils-json-to-toml = { workspace = true }
codex-utils-rustls-provider = { workspace = true }
chrono = { workspace = true }
clap = { workspace = true, features = ["derive"] }
constant_time_eq = { workspace = true }
futures = { workspace = true }
gethostname = { workspace = true }
hmac = { workspace = true }
jsonwebtoken = { workspace = true }
owo-colors = { workspace = true, features = ["supports-colors"] }
@@ -75,6 +77,7 @@ tokio-util = { workspace = true }
tokio-tungstenite = { workspace = true }
tracing = { workspace = true, features = ["log"] }
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] }
url = { workspace = true }
uuid = { workspace = true, features = ["serde", "v7"] }
[dev-dependencies]

View File

@@ -2,6 +2,17 @@
`codex app-server` is the interface Codex uses to power rich interfaces such as the [Codex VS Code extension](https://marketplace.visualstudio.com/items?itemName=openai.chatgpt).
For remote-control-only deployments, use `codexd`. It runs the same app-server runtime in a headless daemon mode, connects outbound to the ChatGPT remote control server using ChatGPT auth, and does not expose a local stdio or websocket transport.
Remote control is configured in `~/.codex/config.toml`:
```toml
chatgpt_base_url = "https://chatgpt.com/backend-api/"
[features]
remote_control = true
```
## Table of Contents
- [Protocol](#protocol)

View File

@@ -86,6 +86,7 @@ fn transport_name(transport: AppServerTransport) -> &'static str {
match transport {
AppServerTransport::Stdio => "stdio",
AppServerTransport::WebSocket { .. } => "websocket",
AppServerTransport::Headless => "headless",
}
}

View File

@@ -8,6 +8,7 @@ use codex_core::config::ConfigBuilder;
use codex_core::config_loader::CloudRequirementsLoader;
use codex_core::config_loader::ConfigLayerStackOrdering;
use codex_core::config_loader::LoaderOverrides;
use codex_features::Feature;
use codex_utils_cli::CliConfigOverrides;
use std::collections::HashMap;
use std::collections::HashSet;
@@ -29,8 +30,10 @@ use crate::transport::OutboundConnectionState;
use crate::transport::TransportEvent;
use crate::transport::auth::policy_from_settings;
use crate::transport::route_outgoing_envelope;
use crate::transport::start_remote_control;
use crate::transport::start_stdio_connection;
use crate::transport::start_websocket_acceptor;
use crate::transport::validate_remote_control_auth;
use codex_app_server_protocol::ConfigLayerSource;
use codex_app_server_protocol::ConfigWarningNotification;
use codex_app_server_protocol::JSONRPCMessage;
@@ -94,6 +97,37 @@ enum LogFormat {
Json,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct TransportRuntimeMode {
single_client_mode: bool,
shutdown_when_no_connections: bool,
graceful_ctrl_c_restart_enabled: bool,
ctrl_c_shutdown_enabled: bool,
}
fn transport_runtime_mode(transport: AppServerTransport) -> TransportRuntimeMode {
match transport {
AppServerTransport::Stdio => TransportRuntimeMode {
single_client_mode: true,
shutdown_when_no_connections: true,
graceful_ctrl_c_restart_enabled: false,
ctrl_c_shutdown_enabled: false,
},
AppServerTransport::WebSocket { .. } => TransportRuntimeMode {
single_client_mode: false,
shutdown_when_no_connections: false,
graceful_ctrl_c_restart_enabled: true,
ctrl_c_shutdown_enabled: false,
},
AppServerTransport::Headless => TransportRuntimeMode {
single_client_mode: false,
shutdown_when_no_connections: false,
graceful_ctrl_c_restart_enabled: false,
ctrl_c_shutdown_enabled: true,
},
}
}
type StderrLogLayer = Box<dyn Layer<Registry> + Send + Sync + 'static>;
/// Control-plane messages from the processor/transport side to the outbound router task.
@@ -361,38 +395,6 @@ pub async fn run_main_with_transport(
let (outbound_control_tx, mut outbound_control_rx) =
mpsc::channel::<OutboundControlEvent>(CHANNEL_CAPACITY);
enum TransportRuntime {
Stdio,
WebSocket {
accept_handle: JoinHandle<()>,
shutdown_token: CancellationToken,
},
}
let mut stdio_handles = Vec::<JoinHandle<()>>::new();
let transport_runtime = match transport {
AppServerTransport::Stdio => {
start_stdio_connection(transport_event_tx.clone(), &mut stdio_handles).await?;
TransportRuntime::Stdio
}
AppServerTransport::WebSocket { bind_address } => {
let shutdown_token = CancellationToken::new();
let accept_handle = start_websocket_acceptor(
bind_address,
transport_event_tx.clone(),
shutdown_token.clone(),
policy_from_settings(&auth)?,
)
.await?;
TransportRuntime::WebSocket {
accept_handle,
shutdown_token,
}
}
};
let single_client_mode = matches!(&transport_runtime, TransportRuntime::Stdio);
let shutdown_when_no_connections = single_client_mode;
let graceful_signal_restart_enabled = !single_client_mode;
// Parse CLI overrides once and derive the base Config eagerly so later
// components do not need to work with raw TOML values.
let cli_kv_overrides = cli_config_overrides.parse_overrides().map_err(|e| {
@@ -529,13 +531,13 @@ pub async fn run_main_with_transport(
let feedback_layer = feedback.logger_layer();
let feedback_metadata_layer = feedback.metadata_layer();
let log_db = codex_state::StateRuntime::init(
let state_db = codex_state::StateRuntime::init(
config.sqlite_home.clone(),
config.model_provider_id.clone(),
)
.await
.ok()
.map(log_db::start);
.ok();
let log_db = state_db.clone().map(log_db::start);
let log_db_layer = log_db
.clone()
.map(|layer| layer.with_filter(Targets::new().with_default(Level::TRACE)));
@@ -556,6 +558,57 @@ pub async fn run_main_with_transport(
}
}
let transport_shutdown_token = CancellationToken::new();
let mut transport_accept_handles = Vec::<JoinHandle<()>>::new();
let runtime_mode = transport_runtime_mode(transport);
match transport {
AppServerTransport::Stdio => {
start_stdio_connection(transport_event_tx.clone(), &mut transport_accept_handles)
.await?;
}
AppServerTransport::WebSocket { bind_address } => {
let accept_handle = start_websocket_acceptor(
bind_address,
transport_event_tx.clone(),
transport_shutdown_token.clone(),
policy_from_settings(&auth)?,
)
.await?;
transport_accept_handles.push(accept_handle);
}
AppServerTransport::Headless => {}
}
let shutdown_when_no_connections = runtime_mode.shutdown_when_no_connections;
let graceful_ctrl_c_restart_enabled = runtime_mode.graceful_ctrl_c_restart_enabled;
let graceful_signal_restart_enabled = runtime_mode.graceful_ctrl_c_restart_enabled;
let auth_manager = AuthManager::shared(
config.codex_home.clone(),
/*enable_codex_api_key_env*/ false,
config.cli_auth_credentials_store_mode,
);
auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone());
if config.features.enabled(Feature::RemoteControl) {
validate_remote_control_auth(auth_manager.as_ref()).await?;
let accept_handle = start_remote_control(
config.chatgpt_base_url.clone(),
state_db.clone(),
auth_manager.clone(),
transport_event_tx.clone(),
transport_shutdown_token.clone(),
)
.await?;
transport_accept_handles.push(accept_handle);
}
if transport_accept_handles.is_empty() {
return Err(std::io::Error::new(
ErrorKind::InvalidInput,
"no transport configured; use --listen or enable remote control",
));
}
let outbound_handle = tokio::spawn(async move {
let mut outbound_connections = HashMap::<ConnectionId, OutboundConnectionState>::new();
loop {
@@ -632,10 +685,7 @@ pub async fn run_main_with_transport(
let mut thread_created_rx = processor.thread_created_receiver();
let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count();
let mut connections = HashMap::<ConnectionId, ConnectionState>::new();
let websocket_accept_shutdown = match &transport_runtime {
TransportRuntime::WebSocket { shutdown_token, .. } => Some(shutdown_token.clone()),
TransportRuntime::Stdio => None,
};
let transport_shutdown_token = transport_shutdown_token.clone();
async move {
let mut listen_for_threads = true;
let mut shutdown_state = ShutdownState::default();
@@ -648,9 +698,7 @@ pub async fn run_main_with_transport(
shutdown_state.update(running_turn_count, connections.len()),
ShutdownAction::Finish
) {
if let Some(shutdown_token) = &websocket_accept_shutdown {
shutdown_token.cancel();
}
transport_shutdown_token.cancel();
let _ = outbound_control_tx
.send(OutboundControlEvent::DisconnectAll)
.await;
@@ -665,6 +713,24 @@ pub async fn run_main_with_transport(
let running_turn_count = *running_turn_count_rx.borrow();
shutdown_state.on_signal(connections.len(), running_turn_count);
}
ctrl_c_result = tokio::signal::ctrl_c(), if runtime_mode.ctrl_c_shutdown_enabled => {
if let Err(err) = ctrl_c_result {
warn!("failed to listen for Ctrl-C during daemon shutdown: {err}");
}
info!("received Ctrl-C; shutting down codexd remote-control daemon");
transport_shutdown_token.cancel();
let _ = outbound_control_tx
.send(OutboundControlEvent::DisconnectAll)
.await;
break;
}
ctrl_c_result = tokio::signal::ctrl_c(), if graceful_ctrl_c_restart_enabled && !shutdown_state.forced() => {
if let Err(err) = ctrl_c_result {
warn!("failed to listen for Ctrl-C during graceful restart drain: {err}");
}
let running_turn_count = *running_turn_count_rx.borrow();
shutdown_state.on_signal(connections.len(), running_turn_count);
}
changed = running_turn_count_rx.changed(), if graceful_signal_restart_enabled && shutdown_state.requested() => {
if changed.is_err() {
warn!("running-turn watcher closed during graceful restart drain");
@@ -844,16 +910,8 @@ pub async fn run_main_with_transport(
let _ = processor_handle.await;
let _ = outbound_handle.await;
if let TransportRuntime::WebSocket {
accept_handle,
shutdown_token,
} = transport_runtime
{
shutdown_token.cancel();
let _ = accept_handle.await;
}
for handle in stdio_handles {
transport_shutdown_token.cancel();
for handle in transport_accept_handles {
let _ = handle.await;
}
@@ -867,6 +925,9 @@ pub async fn run_main_with_transport(
#[cfg(test)]
mod tests {
use super::LogFormat;
use super::TransportRuntimeMode;
use super::transport_runtime_mode;
use crate::AppServerTransport;
use pretty_assertions::assert_eq;
#[test]
@@ -883,4 +944,17 @@ mod tests {
assert_eq!(LogFormat::from_env_value(Some("text")), LogFormat::Default);
assert_eq!(LogFormat::from_env_value(Some("jsonl")), LogFormat::Default);
}
#[test]
fn headless_transport_runtime_mode_uses_daemon_shutdown_behavior() {
assert_eq!(
transport_runtime_mode(AppServerTransport::Headless),
TransportRuntimeMode {
single_client_mode: false,
shutdown_when_no_connections: false,
graceful_ctrl_c_restart_enabled: false,
ctrl_c_shutdown_enabled: true,
}
);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,385 @@
use crate::error_code::OVERLOADED_ERROR_CODE;
use crate::message_processor::ConnectionSessionState;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::OutgoingEnvelope;
use crate::outgoing_message::OutgoingError;
use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::QueuedOutgoingMessage;
use codex_app_server_protocol::JSONRPCErrorError;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::ServerRequest;
use std::collections::HashMap;
use std::collections::HashSet;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::RwLock;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::error;
use tracing::warn;
/// Size of the bounded channels used to communicate between tasks. The value
/// is a balance between throughput and memory usage - 128 messages should be
/// plenty for an interactive CLI.
pub(crate) const CHANNEL_CAPACITY: usize = 128;
pub(crate) mod auth;
mod remote_control;
mod stdio;
mod websocket;
pub(crate) use remote_control::start_remote_control;
pub(crate) use remote_control::validate_remote_control_auth;
pub(crate) use stdio::start_stdio_connection;
pub(crate) use websocket::start_websocket_acceptor;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AppServerTransport {
Stdio,
WebSocket { bind_address: SocketAddr },
Headless,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AppServerTransportParseError {
UnsupportedListenUrl(String),
InvalidWebSocketListenUrl(String),
}
impl std::fmt::Display for AppServerTransportParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!(
f,
"unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`"
),
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!(
f,
"invalid websocket --listen URL `{listen_url}`; expected `ws://IP:PORT`"
),
}
}
}
impl std::error::Error for AppServerTransportParseError {}
impl AppServerTransport {
pub const DEFAULT_LISTEN_URL: &'static str = "stdio://";
pub fn from_listen_url(listen_url: &str) -> Result<Self, AppServerTransportParseError> {
if listen_url == Self::DEFAULT_LISTEN_URL {
return Ok(Self::Stdio);
}
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
})?;
return Ok(Self::WebSocket { bind_address });
}
Err(AppServerTransportParseError::UnsupportedListenUrl(
listen_url.to_string(),
))
}
}
impl FromStr for AppServerTransport {
type Err = AppServerTransportParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_listen_url(s)
}
}
#[derive(Debug)]
pub(crate) enum TransportEvent {
ConnectionOpened {
connection_id: ConnectionId,
writer: mpsc::Sender<QueuedOutgoingMessage>,
disconnect_sender: Option<CancellationToken>,
},
ConnectionClosed {
connection_id: ConnectionId,
},
IncomingMessage {
connection_id: ConnectionId,
message: JSONRPCMessage,
},
}
pub(crate) struct ConnectionState {
pub(crate) outbound_initialized: Arc<AtomicBool>,
pub(crate) outbound_experimental_api_enabled: Arc<AtomicBool>,
pub(crate) outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
pub(crate) session: ConnectionSessionState,
}
impl ConnectionState {
pub(crate) fn new(
outbound_initialized: Arc<AtomicBool>,
outbound_experimental_api_enabled: Arc<AtomicBool>,
outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
) -> Self {
Self {
outbound_initialized,
outbound_experimental_api_enabled,
outbound_opted_out_notification_methods,
session: ConnectionSessionState::default(),
}
}
}
pub(crate) struct OutboundConnectionState {
pub(crate) initialized: Arc<AtomicBool>,
pub(crate) experimental_api_enabled: Arc<AtomicBool>,
pub(crate) opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
pub(crate) writer: mpsc::Sender<QueuedOutgoingMessage>,
disconnect_sender: Option<CancellationToken>,
}
impl OutboundConnectionState {
pub(crate) fn new(
writer: mpsc::Sender<QueuedOutgoingMessage>,
initialized: Arc<AtomicBool>,
experimental_api_enabled: Arc<AtomicBool>,
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
disconnect_sender: Option<CancellationToken>,
) -> Self {
Self {
initialized,
experimental_api_enabled,
opted_out_notification_methods,
writer,
disconnect_sender,
}
}
fn can_disconnect(&self) -> bool {
self.disconnect_sender.is_some()
}
pub(crate) fn request_disconnect(&self) {
if let Some(disconnect_sender) = &self.disconnect_sender {
disconnect_sender.cancel();
}
}
}
static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(1);
fn next_connection_id() -> ConnectionId {
ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed))
}
async fn forward_incoming_message(
transport_event_tx: &mpsc::Sender<TransportEvent>,
writer: &mpsc::Sender<QueuedOutgoingMessage>,
connection_id: ConnectionId,
payload: &str,
) -> bool {
match serde_json::from_str::<JSONRPCMessage>(payload) {
Ok(message) => {
enqueue_incoming_message(transport_event_tx, writer, connection_id, message).await
}
Err(err) => {
error!("Failed to deserialize JSONRPCMessage: {err}");
true
}
}
}
async fn enqueue_incoming_message(
transport_event_tx: &mpsc::Sender<TransportEvent>,
writer: &mpsc::Sender<QueuedOutgoingMessage>,
connection_id: ConnectionId,
message: JSONRPCMessage,
) -> bool {
let event = TransportEvent::IncomingMessage {
connection_id,
message,
};
match transport_event_tx.try_send(event) {
Ok(()) => true,
Err(mpsc::error::TrySendError::Closed(_)) => false,
Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage {
connection_id,
message: JSONRPCMessage::Request(request),
})) => {
let overload_error = OutgoingMessage::Error(OutgoingError {
id: request.id,
error: JSONRPCErrorError {
code: OVERLOADED_ERROR_CODE,
message: "Server overloaded; retry later.".to_string(),
data: None,
},
});
match writer.try_send(QueuedOutgoingMessage::new(overload_error)) {
Ok(()) => true,
Err(mpsc::error::TrySendError::Closed(_)) => false,
Err(mpsc::error::TrySendError::Full(_overload_error)) => {
warn!(
"dropping overload response for connection {:?}: outbound queue is full",
connection_id
);
true
}
}
}
Err(mpsc::error::TrySendError::Full(event)) => transport_event_tx.send(event).await.is_ok(),
}
}
fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option<String> {
let value = match serde_json::to_value(outgoing_message) {
Ok(value) => value,
Err(err) => {
error!("Failed to convert OutgoingMessage to JSON value: {err}");
return None;
}
};
match serde_json::to_string(&value) {
Ok(json) => Some(json),
Err(err) => {
error!("Failed to serialize JSONRPCMessage: {err}");
None
}
}
}
fn should_skip_notification_for_connection(
connection_state: &OutboundConnectionState,
message: &OutgoingMessage,
) -> bool {
let Ok(opted_out_notification_methods) = connection_state.opted_out_notification_methods.read()
else {
warn!("failed to read outbound opted-out notifications");
return false;
};
match message {
OutgoingMessage::AppServerNotification(notification) => {
let method = notification.to_string();
opted_out_notification_methods.contains(method.as_str())
}
_ => false,
}
}
fn disconnect_connection(
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
connection_id: ConnectionId,
) -> bool {
if let Some(connection_state) = connections.remove(&connection_id) {
connection_state.request_disconnect();
return true;
}
false
}
async fn send_message_to_connection(
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
connection_id: ConnectionId,
message: OutgoingMessage,
write_complete_tx: Option<tokio::sync::oneshot::Sender<()>>,
) -> bool {
let Some(connection_state) = connections.get(&connection_id) else {
warn!("dropping message for disconnected connection: {connection_id:?}");
return false;
};
let message = filter_outgoing_message_for_connection(connection_state, message);
if should_skip_notification_for_connection(connection_state, &message) {
return false;
}
let writer = connection_state.writer.clone();
let queued_message = QueuedOutgoingMessage {
message,
write_complete_tx,
};
if connection_state.can_disconnect() {
match writer.try_send(queued_message) {
Ok(()) => false,
Err(mpsc::error::TrySendError::Full(_)) => {
warn!(
"disconnecting slow connection after outbound queue filled: {connection_id:?}"
);
disconnect_connection(connections, connection_id)
}
Err(mpsc::error::TrySendError::Closed(_)) => {
disconnect_connection(connections, connection_id)
}
}
} else if writer.send(queued_message).await.is_err() {
disconnect_connection(connections, connection_id)
} else {
false
}
}
fn filter_outgoing_message_for_connection(
connection_state: &OutboundConnectionState,
message: OutgoingMessage,
) -> OutgoingMessage {
let experimental_api_enabled = connection_state
.experimental_api_enabled
.load(Ordering::Acquire);
match message {
OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
request_id,
mut params,
}) => {
if !experimental_api_enabled {
params.strip_experimental_fields();
}
OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
request_id,
params,
})
}
_ => message,
}
}
pub(crate) async fn route_outgoing_envelope(
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
envelope: OutgoingEnvelope,
) {
match envelope {
OutgoingEnvelope::ToConnection {
connection_id,
message,
write_complete_tx,
} => {
let _ =
send_message_to_connection(connections, connection_id, message, write_complete_tx)
.await;
}
OutgoingEnvelope::Broadcast { message } => {
let target_connections: Vec<ConnectionId> = connections
.iter()
.filter_map(|(connection_id, connection_state)| {
if connection_state.initialized.load(Ordering::Acquire)
&& !should_skip_notification_for_connection(connection_state, &message)
{
Some(*connection_id)
} else {
None
}
})
.collect();
for connection_id in target_connections {
let _ = send_message_to_connection(
connections,
connection_id,
message.clone(),
/*write_complete_tx*/ None,
)
.await;
}
}
}
}

View File

@@ -0,0 +1,370 @@
use super::protocol::EnrollRemoteServerRequest;
use super::protocol::EnrollRemoteServerResponse;
use super::protocol::RemoteControlTarget;
use axum::http::HeaderMap;
use base64::Engine;
use codex_core::AuthManager;
use codex_core::default_client::build_reqwest_client;
use codex_state::StateRuntime;
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
use gethostname::gethostname;
use io::ErrorKind;
use std::io;
use tokio::net::TcpStream;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::HeaderValue;
use tracing::warn;
const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
const REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES: usize = 4096;
pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2";
pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor";
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct RemoteControlEnrollment {
pub(super) server_id: String,
pub(super) server_name: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct RemoteControlConnectionAuth {
pub(super) bearer_token: String,
pub(super) account_id: Option<String>,
}
pub(super) struct RemoteControlWebsocketConnection {
pub(super) websocket_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
}
pub(super) async fn load_persisted_remote_control_enrollment(
state_db: Option<&StateRuntime>,
remote_control_target: &RemoteControlTarget,
account_id: Option<&str>,
) -> Option<RemoteControlEnrollment> {
let state_db = state_db?;
let enrollment = match state_db
.get_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
.await
{
Ok(enrollment) => enrollment,
Err(err) => {
warn!("{err}");
return None;
}
};
enrollment.map(|(server_id, server_name)| RemoteControlEnrollment {
server_id,
server_name,
})
}
pub(super) async fn update_persisted_remote_control_enrollment(
state_db: Option<&StateRuntime>,
remote_control_target: &RemoteControlTarget,
account_id: Option<&str>,
enrollment: Option<&RemoteControlEnrollment>,
) -> io::Result<()> {
let Some(state_db) = state_db else {
return Ok(());
};
if let Some(enrollment) = enrollment {
state_db
.upsert_remote_control_enrollment(
&remote_control_target.websocket_url,
account_id,
&enrollment.server_id,
&enrollment.server_name,
)
.await
.map_err(io::Error::other)
} else {
state_db
.delete_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
.await
.map(|_| ())
.map_err(io::Error::other)
}
}
pub(super) async fn load_remote_control_auth(
auth_manager: &AuthManager,
) -> io::Result<RemoteControlConnectionAuth> {
let auth = match auth_manager.auth().await {
Some(auth) => auth,
None => {
auth_manager.reload();
auth_manager.auth().await.ok_or_else(|| {
io::Error::new(
ErrorKind::PermissionDenied,
"remote control requires ChatGPT authentication",
)
})?
}
};
if !auth.is_chatgpt_auth() {
return Err(io::Error::new(
ErrorKind::PermissionDenied,
"remote control requires ChatGPT authentication; API key auth is not supported",
));
}
Ok(RemoteControlConnectionAuth {
bearer_token: auth.get_token().map_err(io::Error::other)?,
account_id: auth.get_account_id(),
})
}
fn preview_remote_control_response_body(body: &[u8]) -> String {
let body = String::from_utf8_lossy(body);
let trimmed = body.trim();
if trimmed.is_empty() {
return "<empty>".to_string();
}
if trimmed.len() <= REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES {
return trimmed.to_string();
}
let mut cut = REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES;
while !trimmed.is_char_boundary(cut) {
cut = cut.saturating_sub(1);
}
let mut truncated = trimmed[..cut].to_string();
truncated.push_str("...");
truncated
}
fn format_headers(headers: &HeaderMap) -> String {
let mut headers = headers
.iter()
.map(|(name, value)| {
format!(
"{}: {}",
name.as_str(),
value.to_str().unwrap_or("<invalid utf-8>")
)
})
.collect::<Vec<_>>();
headers.sort();
format!("{{{}}}", headers.join(", "))
}
fn format_remote_control_websocket_connect_error(
websocket_url: &str,
err: &tungstenite::Error,
) -> String {
let mut message =
format!("failed to connect app-server remote control websocket `{websocket_url}`: {err}");
let tungstenite::Error::Http(response) = err else {
return message;
};
message.push_str(&format!(
", headers: {}",
format_headers(response.headers())
));
if let Some(body) = response.body().as_ref()
&& !body.is_empty()
{
let body_preview = preview_remote_control_response_body(body);
message.push_str(&format!(", body: {body_preview}"));
}
message
}
pub(super) async fn enroll_remote_control_server(
remote_control_target: &RemoteControlTarget,
auth: &RemoteControlConnectionAuth,
) -> io::Result<RemoteControlEnrollment> {
let enroll_url = &remote_control_target.enroll_url;
let server_name = gethostname().to_string_lossy().trim().to_string();
let request = EnrollRemoteServerRequest {
name: server_name.clone(),
os: std::env::consts::OS,
arch: std::env::consts::ARCH,
app_server_version: env!("CARGO_PKG_VERSION"),
};
let client = build_reqwest_client();
let mut http_request = client
.post(enroll_url)
.timeout(REMOTE_CONTROL_ENROLL_TIMEOUT)
.bearer_auth(&auth.bearer_token)
.json(&request);
if let Some(account_id) = auth.account_id.as_deref() {
http_request = http_request.header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id);
}
let response = http_request.send().await.map_err(|err| {
io::Error::other(format!(
"failed to enroll remote control server at `{enroll_url}`: {err}"
))
})?;
let headers = response.headers().clone();
let status = response.status();
let body = response.bytes().await.map_err(|err| {
io::Error::other(format!(
"failed to read remote control enrollment response from `{enroll_url}`: {err}"
))
})?;
let body_preview = preview_remote_control_response_body(&body);
if !status.is_success() {
let headers_str = format_headers(&headers);
return Err(io::Error::other(format!(
"remote control server enrollment failed at `{enroll_url}`: HTTP {status}, headers: {headers_str}, body: {body_preview}"
)));
}
let enrollment = serde_json::from_slice::<EnrollRemoteServerResponse>(&body).map_err(|err| {
let headers_str = format_headers(&headers);
io::Error::other(format!(
"failed to parse remote control enrollment response from `{enroll_url}`: HTTP {status}, headers: {headers_str}, body: {body_preview}, decode error: {err}"
))
})?;
Ok(RemoteControlEnrollment {
server_id: enrollment.server_id,
server_name,
})
}
fn set_remote_control_header(
headers: &mut tungstenite::http::HeaderMap,
name: &'static str,
value: &str,
) -> io::Result<()> {
let header_value = HeaderValue::from_str(value).map_err(|err| {
io::Error::new(
ErrorKind::InvalidInput,
format!("invalid remote control header `{name}`: {err}"),
)
})?;
headers.insert(name, header_value);
Ok(())
}
fn build_remote_control_websocket_request(
websocket_url: &str,
enrollment: &RemoteControlEnrollment,
auth: &RemoteControlConnectionAuth,
subscribe_cursor: Option<&str>,
) -> io::Result<tungstenite::http::Request<()>> {
let mut request = websocket_url.into_client_request().map_err(|err| {
io::Error::new(
ErrorKind::InvalidInput,
format!("invalid remote control websocket URL `{websocket_url}`: {err}"),
)
})?;
let headers = request.headers_mut();
set_remote_control_header(headers, "x-codex-server-id", &enrollment.server_id)?;
set_remote_control_header(
headers,
"x-codex-name",
&base64::engine::general_purpose::STANDARD.encode(&enrollment.server_name),
)?;
set_remote_control_header(
headers,
"x-codex-protocol-version",
REMOTE_CONTROL_PROTOCOL_VERSION,
)?;
set_remote_control_header(
headers,
"authorization",
&format!("Bearer {}", auth.bearer_token),
)?;
if let Some(account_id) = auth.account_id.as_deref() {
set_remote_control_header(headers, REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id)?;
}
if let Some(subscribe_cursor) = subscribe_cursor {
set_remote_control_header(
headers,
REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER,
subscribe_cursor,
)?;
}
Ok(request)
}
pub(super) async fn connect_remote_control_websocket(
remote_control_target: &RemoteControlTarget,
state_db: Option<&StateRuntime>,
auth_manager: &AuthManager,
enrollment: &mut Option<RemoteControlEnrollment>,
subscribe_cursor: Option<&str>,
) -> io::Result<RemoteControlWebsocketConnection> {
ensure_rustls_crypto_provider();
let auth = load_remote_control_auth(auth_manager).await?;
if enrollment.is_none() {
*enrollment = load_persisted_remote_control_enrollment(
state_db,
remote_control_target,
auth.account_id.as_deref(),
)
.await;
}
if enrollment.is_none() {
let new_enrollment = enroll_remote_control_server(remote_control_target, &auth).await?;
if let Err(err) = update_persisted_remote_control_enrollment(
state_db,
remote_control_target,
auth.account_id.as_deref(),
Some(&new_enrollment),
)
.await
{
warn!("failed to persist remote control enrollment in sqlite state db: {err}");
}
*enrollment = Some(new_enrollment);
}
let enrollment_ref = enrollment.as_ref().ok_or_else(|| {
io::Error::other("missing remote control enrollment after enrollment step")
})?;
let request = build_remote_control_websocket_request(
&remote_control_target.websocket_url,
enrollment_ref,
&auth,
subscribe_cursor,
)?;
let (websocket_stream, _response) = match connect_async(request).await {
Ok((websocket_stream, response)) => (websocket_stream, response),
Err(err) => {
if matches!(
&err,
tungstenite::Error::Http(response) if response.status().as_u16() == 404
) {
if let Err(clear_err) = update_persisted_remote_control_enrollment(
state_db,
remote_control_target,
auth.account_id.as_deref(),
/*enrollment*/ None,
)
.await
{
warn!(
"failed to clear stale remote control enrollment in sqlite state db: {clear_err}"
);
}
*enrollment = None;
}
return Err(io::Error::other(
format_remote_control_websocket_connect_error(
&remote_control_target.websocket_url,
&err,
),
));
}
};
Ok(RemoteControlWebsocketConnection { websocket_stream })
}

View File

@@ -0,0 +1,346 @@
mod enroll;
mod protocol;
mod websocket;
use self::enroll::load_remote_control_auth;
use self::protocol::ClientEnvelope;
pub use self::protocol::ClientEvent;
pub use self::protocol::ClientId;
use self::protocol::PongStatus;
use self::protocol::ServerEnvelope;
use self::protocol::ServerEvent;
use self::protocol::normalize_remote_control_url;
use self::websocket::run_remote_control_websocket_loop;
use super::CHANNEL_CAPACITY;
use super::TransportEvent;
use super::next_connection_id;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::QueuedOutgoingMessage;
use codex_app_server_protocol::JSONRPCMessage;
use codex_core::AuthManager;
use codex_state::StateRuntime;
use std::collections::HashMap;
use std::io;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio::task::JoinSet;
use tokio::time::Duration;
use tokio::time::Instant;
use tokio::time::MissedTickBehavior;
use tokio_util::sync::CancellationToken;
const REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60);
const REMOTE_CONTROL_IDLE_SWEEP_INTERVAL: Duration = Duration::from_secs(30);
struct RemoteControlClientState {
connection_id: ConnectionId,
disconnect_token: CancellationToken,
last_activity_at: Instant,
last_inbound_seq_id: Option<u64>,
}
pub(super) struct RemoteControlQueuedServerEnvelope {
pub(super) envelope: ServerEnvelope,
pub(super) write_complete_tx: Option<oneshot::Sender<()>>,
}
pub(crate) async fn start_remote_control(
remote_control_url: String,
state_db: Option<Arc<StateRuntime>>,
auth_manager: Arc<AuthManager>,
transport_event_tx: mpsc::Sender<TransportEvent>,
shutdown_token: CancellationToken,
) -> io::Result<JoinHandle<()>> {
let remote_control_url = normalize_remote_control_url(&remote_control_url)?;
Ok(tokio::spawn(async move {
let local_shutdown_token = shutdown_token.child_token();
let (client_event_tx, client_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (server_event_tx, server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (writer_exited_tx, writer_exited_rx) = mpsc::channel(CHANNEL_CAPACITY);
let mut join_set = JoinSet::new();
join_set.spawn(run_remote_control_websocket_loop(
remote_control_url,
state_db,
auth_manager,
client_event_tx,
server_event_rx,
local_shutdown_token.clone(),
));
join_set.spawn(run_remote_control_manager(
transport_event_tx,
client_event_rx,
server_event_tx,
writer_exited_tx,
writer_exited_rx,
local_shutdown_token.clone(),
));
tokio::select! {
_ = local_shutdown_token.cancelled() => {}
_ = join_set.join_next() => local_shutdown_token.cancel(),
}
join_set.shutdown().await;
}))
}
async fn run_remote_control_manager(
transport_event_tx: mpsc::Sender<TransportEvent>,
mut client_event_rx: mpsc::Receiver<ClientEnvelope>,
server_event_tx: mpsc::Sender<RemoteControlQueuedServerEnvelope>,
writer_exited_tx: mpsc::Sender<ClientId>,
mut writer_exited_rx: mpsc::Receiver<ClientId>,
shutdown_token: CancellationToken,
) {
let mut clients = HashMap::<ClientId, RemoteControlClientState>::new();
let mut idle_sweep = tokio::time::interval(REMOTE_CONTROL_IDLE_SWEEP_INTERVAL);
idle_sweep.set_missed_tick_behavior(MissedTickBehavior::Skip);
loop {
tokio::select! {
_ = shutdown_token.cancelled() => {
break;
}
_ = idle_sweep.tick() => {
if !close_expired_remote_control_clients(&transport_event_tx, &mut clients).await {
break;
}
}
writer_exited = writer_exited_rx.recv() => {
let Some(client_id) = writer_exited else {
break;
};
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
break;
}
}
client_event = client_event_rx.recv() => {
let Some(client_event) = client_event else {
break;
};
match client_event.event {
ClientEvent::ClientMessage { message } => {
let client_id = client_event.client_id;
let is_initialize = remote_control_message_starts_connection(&message);
if let Some(seq_id) = client_event.seq_id
&& let Some(client) = clients.get(&client_id)
&& client.last_inbound_seq_id.is_some_and(|last_seq_id| last_seq_id >= seq_id)
&& !is_initialize
{
continue;
}
if is_initialize && clients.contains_key(&client_id)
&& !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
break;
}
if let Some(connection_id) = clients.get_mut(&client_id).map(|client| {
client.last_activity_at = Instant::now();
if let Some(seq_id) = client_event.seq_id {
client.last_inbound_seq_id = Some(seq_id);
}
client.connection_id
}) {
if transport_event_tx
.send(TransportEvent::IncomingMessage {
connection_id,
message,
})
.await
.is_err()
{
break;
}
continue;
}
if !is_initialize {
continue;
}
let connection_id = next_connection_id();
let (writer_tx, writer_rx) =
mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
let disconnect_token = CancellationToken::new();
if transport_event_tx
.send(TransportEvent::ConnectionOpened {
connection_id,
writer: writer_tx,
disconnect_sender: Some(disconnect_token.clone()),
})
.await
.is_err()
{
break;
}
tokio::spawn(run_remote_control_client_outbound(
client_id.clone(),
writer_rx,
server_event_tx.clone(),
writer_exited_tx.clone(),
disconnect_token.clone(),
));
clients.insert(
client_id,
RemoteControlClientState {
connection_id,
disconnect_token,
last_activity_at: Instant::now(),
last_inbound_seq_id: client_event.seq_id,
},
);
if transport_event_tx
.send(TransportEvent::IncomingMessage {
connection_id,
message,
})
.await
.is_err()
{
break;
}
}
ClientEvent::Ack { .. } => continue,
ClientEvent::Ping => {
let client_id = client_event.client_id;
let status = match clients.get_mut(&client_id) {
Some(client) => {
client.last_activity_at = Instant::now();
PongStatus::Active
}
None => PongStatus::Unknown,
};
if server_event_tx
.send(RemoteControlQueuedServerEnvelope {
envelope: ServerEnvelope {
event: ServerEvent::Pong { status },
client_id,
seq_id: None,
},
write_complete_tx: None,
})
.await
.is_err()
{
break;
}
}
ClientEvent::ClientClosed => {
let client_id = client_event.client_id;
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
break;
}
}
}
}
}
}
while let Some(client_id) = clients.keys().next().cloned() {
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
break;
}
}
}
fn remote_control_message_starts_connection(message: &JSONRPCMessage) -> bool {
matches!(
message,
JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { method, .. })
if method == "initialize"
)
}
fn remote_control_client_is_alive(client: &RemoteControlClientState, now: Instant) -> bool {
now.duration_since(client.last_activity_at) < REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT
}
async fn close_expired_remote_control_clients(
transport_event_tx: &mpsc::Sender<TransportEvent>,
clients: &mut HashMap<ClientId, RemoteControlClientState>,
) -> bool {
let now = Instant::now();
let expired_client_ids: Vec<ClientId> = clients
.iter()
.filter_map(|(client_id, client)| {
(!remote_control_client_is_alive(client, now)).then_some(client_id.clone())
})
.collect();
for client_id in expired_client_ids {
if !close_remote_control_client(transport_event_tx, clients, &client_id).await {
return false;
}
}
true
}
async fn close_remote_control_client(
transport_event_tx: &mpsc::Sender<TransportEvent>,
clients: &mut HashMap<ClientId, RemoteControlClientState>,
client_id: &ClientId,
) -> bool {
let Some(client) = clients.remove(client_id) else {
return true;
};
client.disconnect_token.cancel();
transport_event_tx
.send(TransportEvent::ConnectionClosed {
connection_id: client.connection_id,
})
.await
.is_ok()
}
async fn run_remote_control_client_outbound(
client_id: ClientId,
mut writer_rx: mpsc::Receiver<QueuedOutgoingMessage>,
server_event_tx: mpsc::Sender<RemoteControlQueuedServerEnvelope>,
writer_exited_tx: mpsc::Sender<ClientId>,
disconnect_token: CancellationToken,
) {
let mut seq_id = 0_u64;
loop {
tokio::select! {
_ = disconnect_token.cancelled() => {
break;
}
queued_message = writer_rx.recv() => {
let Some(queued_message) = queued_message else {
break;
};
if server_event_tx
.send(RemoteControlQueuedServerEnvelope {
envelope: ServerEnvelope {
event: ServerEvent::ServerMessage {
message: Box::new(queued_message.message),
},
client_id: client_id.clone(),
seq_id: Some(seq_id),
},
write_complete_tx: queued_message.write_complete_tx,
})
.await
.is_err()
{
break;
}
seq_id = seq_id.wrapping_add(1);
}
}
}
let _ = writer_exited_tx.send(client_id).await;
}
pub(crate) async fn validate_remote_control_auth(auth_manager: &AuthManager) -> io::Result<()> {
load_remote_control_auth(auth_manager).await.map(|_| ())
}
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,188 @@
use crate::outgoing_message::OutgoingMessage;
use codex_app_server_protocol::JSONRPCMessage;
use serde::Deserialize;
use serde::Serialize;
use std::io;
use std::io::ErrorKind;
use url::Url;
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct RemoteControlTarget {
pub(super) websocket_url: String,
pub(super) enroll_url: String,
}
#[derive(Debug, Serialize)]
pub(super) struct EnrollRemoteServerRequest {
pub(super) name: String,
pub(super) os: &'static str,
pub(super) arch: &'static str,
pub(super) app_server_version: &'static str,
}
#[derive(Debug, Deserialize)]
pub(super) struct EnrollRemoteServerResponse {
pub(super) server_id: String,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ClientId(pub String);
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientEvent {
ClientMessage {
message: JSONRPCMessage,
},
Ack {
#[serde(rename = "acked_seq_id", alias = "ackedSeqId")]
acked_seq_id: u64,
},
Ping,
ClientClosed,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub(crate) struct ClientEnvelope {
#[serde(flatten)]
pub(crate) event: ClientEvent,
#[serde(rename = "client_id", alias = "clientId")]
pub(crate) client_id: ClientId,
#[serde(
rename = "seq_id",
alias = "seqId",
skip_serializing_if = "Option::is_none"
)]
pub(crate) seq_id: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) cursor: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PongStatus {
Active,
Unknown,
}
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerEvent {
ServerMessage {
message: Box<OutgoingMessage>,
},
#[allow(dead_code)]
Ack,
Pong {
status: PongStatus,
},
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub(crate) struct ServerEnvelope {
#[serde(flatten)]
pub(crate) event: ServerEvent,
#[serde(rename = "client_id", alias = "clientId")]
pub(crate) client_id: ClientId,
#[serde(
rename = "seq_id",
alias = "seqId",
skip_serializing_if = "Option::is_none"
)]
pub(crate) seq_id: Option<u64>,
}
pub(super) fn normalize_remote_control_url(
remote_control_url: &str,
) -> io::Result<RemoteControlTarget> {
let map_url_parse_error = |err: url::ParseError| -> io::Error {
io::Error::new(
ErrorKind::InvalidInput,
format!("invalid remote control URL `{remote_control_url}`: {err}"),
)
};
let map_scheme_error = |_: ()| -> io::Error {
io::Error::new(
ErrorKind::InvalidInput,
format!(
"invalid remote control URL `{remote_control_url}`; expected absolute URL with http:// or https:// scheme"
),
)
};
let mut remote_control_url = Url::parse(remote_control_url).map_err(map_url_parse_error)?;
match remote_control_url.scheme() {
"https" | "http" => {}
_ => return Err(map_scheme_error(())),
}
if !remote_control_url.path().ends_with('/') {
let normalized_path = format!("{}/", remote_control_url.path());
remote_control_url.set_path(&normalized_path);
}
let mut enroll_url = remote_control_url
.join("wham/remote/control/server/enroll")
.map_err(map_url_parse_error)?;
let mut websocket_url = remote_control_url
.join("wham/remote/control/server")
.map_err(map_url_parse_error)?;
match remote_control_url.scheme() {
"https" => {
enroll_url.set_scheme("https").map_err(map_scheme_error)?;
websocket_url.set_scheme("wss").map_err(map_scheme_error)?;
}
"http" => {
enroll_url.set_scheme("http").map_err(map_scheme_error)?;
websocket_url.set_scheme("ws").map_err(map_scheme_error)?;
}
_ => return Err(map_scheme_error(())),
}
Ok(RemoteControlTarget {
websocket_url: websocket_url.to_string(),
enroll_url: enroll_url.to_string(),
})
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn normalize_remote_control_url_rewrites_http_schemes() {
assert_eq!(
normalize_remote_control_url("http://example.com/backend-api")
.expect("valid http prefix"),
RemoteControlTarget {
websocket_url: "ws://example.com/backend-api/wham/remote/control/server"
.to_string(),
enroll_url: "http://example.com/backend-api/wham/remote/control/server/enroll"
.to_string(),
}
);
assert_eq!(
normalize_remote_control_url("https://example.com/backend-api/")
.expect("valid https prefix"),
RemoteControlTarget {
websocket_url: "wss://example.com/backend-api/wham/remote/control/server"
.to_string(),
enroll_url: "https://example.com/backend-api/wham/remote/control/server/enroll"
.to_string(),
}
);
}
#[test]
fn normalize_remote_control_url_rejects_unsupported_schemes() {
let err = normalize_remote_control_url("ftp://example.com/control")
.expect_err("unsupported scheme should fail");
assert_eq!(
err.to_string(),
"invalid remote control URL `ftp://example.com/control`; expected absolute URL with http:// or https:// scheme"
);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,297 @@
use super::CHANNEL_CAPACITY;
use super::RemoteControlQueuedServerEnvelope;
use super::enroll::connect_remote_control_websocket;
use super::protocol::ClientEnvelope;
use super::protocol::ClientEvent;
use super::protocol::ClientId;
use super::protocol::RemoteControlTarget;
use super::protocol::ServerEnvelope;
use super::protocol::ServerEvent;
use codex_core::AuthManager;
use codex_state::StateRuntime;
use futures::SinkExt;
use futures::StreamExt;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::Duration;
use tokio_tungstenite::tungstenite;
use tokio_util::sync::CancellationToken;
use tracing::error;
use tracing::info;
use tracing::warn;
const REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF: Duration = Duration::from_secs(1);
const REMOTE_CONTROL_RECONNECT_MAX_BACKOFF: Duration = Duration::from_secs(30);
enum RemoteControlWriteCommand {
ServerEnvelope(RemoteControlQueuedServerEnvelope),
Pong(tungstenite::Bytes),
}
struct BufferedServerEvent {
event: ServerEvent,
write_complete_tx: Option<oneshot::Sender<()>>,
}
#[allow(clippy::print_stderr)]
pub(super) async fn run_remote_control_websocket_loop(
remote_control_target: RemoteControlTarget,
state_db: Option<Arc<StateRuntime>>,
auth_manager: Arc<AuthManager>,
client_event_tx: mpsc::Sender<ClientEnvelope>,
mut server_event_rx: mpsc::Receiver<RemoteControlQueuedServerEnvelope>,
shutdown_token: CancellationToken,
) {
let mut reconnect_backoff = REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF;
let mut reconnect_attempt = 0_u64;
let mut wait_before_connect = false;
let mut enrollment = None;
let mut outbound_buffer = HashMap::<ClientId, BTreeMap<u64, BufferedServerEvent>>::new();
let mut subscribe_cursor: Option<String> = None;
loop {
if wait_before_connect {
tokio::select! {
_ = shutdown_token.cancelled() => break,
_ = tokio::time::sleep(reconnect_backoff) => {}
}
reconnect_attempt = reconnect_attempt.saturating_add(1);
warn!(
"app-server remote control websocket reconnect attempt {reconnect_attempt} after {reconnect_backoff:?}"
);
reconnect_backoff = reconnect_backoff
.saturating_mul(2)
.min(REMOTE_CONTROL_RECONNECT_MAX_BACKOFF);
} else {
wait_before_connect = true;
}
let websocket_connection = tokio::select! {
_ = shutdown_token.cancelled() => break,
connect_result = connect_remote_control_websocket(
&remote_control_target,
state_db.as_deref(),
auth_manager.as_ref(),
&mut enrollment,
subscribe_cursor.as_deref(),
) => {
match connect_result {
Ok(websocket_connection) => {
reconnect_backoff = REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF;
reconnect_attempt = 0;
info!(
"connected to app-server remote control websocket: {}",
remote_control_target.websocket_url
);
websocket_connection
}
Err(err) => {
warn!("{err}");
continue;
}
}
}
};
let (mut websocket_writer, mut websocket_reader) =
websocket_connection.websocket_stream.split();
let (write_command_tx, mut write_command_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (reader_event_tx, mut reader_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
let mut buffered_events_to_resend = Vec::new();
for (client_id, buffered_events) in outbound_buffer.iter_mut() {
for (seq_id, buffered_event) in buffered_events.iter_mut() {
buffered_events_to_resend.push(RemoteControlQueuedServerEnvelope {
envelope: ServerEnvelope {
event: buffered_event.event.clone(),
client_id: client_id.clone(),
seq_id: Some(*seq_id),
},
write_complete_tx: buffered_event.write_complete_tx.take(),
});
}
}
let mut write_task = tokio::spawn(async move {
for server_envelope in buffered_events_to_resend {
let payload = match serde_json::to_string(&server_envelope.envelope) {
Ok(payload) => payload,
Err(err) => {
error!("failed to serialize remote-control server event: {err}");
continue;
}
};
info!("YOLO sending to codex backend: {payload}");
if websocket_writer
.send(tungstenite::Message::Text(payload.into()))
.await
.is_err()
{
return;
}
if let Some(write_complete_tx) = server_envelope.write_complete_tx {
let _ = write_complete_tx.send(());
}
}
while let Some(write_command) = write_command_rx.recv().await {
match write_command {
RemoteControlWriteCommand::ServerEnvelope(server_envelope) => {
let payload = match serde_json::to_string(&server_envelope.envelope) {
Ok(payload) => payload,
Err(err) => {
error!("failed to serialize remote-control server event: {err}");
continue;
}
};
info!("YOLO sending to codex backend: {payload}");
if websocket_writer
.send(tungstenite::Message::Text(payload.into()))
.await
.is_err()
{
return;
}
if let Some(write_complete_tx) = server_envelope.write_complete_tx {
let _ = write_complete_tx.send(());
}
}
RemoteControlWriteCommand::Pong(payload) => {
if websocket_writer
.send(tungstenite::Message::Pong(payload))
.await
.is_err()
{
return;
}
}
}
}
});
let write_command_tx_for_reader = write_command_tx.clone();
let mut read_task = tokio::spawn(async move {
while let Some(incoming_message) = websocket_reader.next().await {
match incoming_message {
Ok(tungstenite::Message::Text(text)) => {
if let Ok(client_envelope) = serde_json::from_str::<ClientEnvelope>(&text) {
if reader_event_tx.send(client_envelope).await.is_err() {
return;
}
} else {
warn!("failed to deserialize remote-control client event");
}
}
Ok(tungstenite::Message::Ping(payload)) => {
if write_command_tx_for_reader
.send(RemoteControlWriteCommand::Pong(payload))
.await
.is_err()
{
return;
}
}
Ok(tungstenite::Message::Pong(_)) => {}
Ok(tungstenite::Message::Binary(_)) => {
warn!("dropping unsupported binary remote-control websocket message");
}
Ok(tungstenite::Message::Frame(_)) => {}
Ok(tungstenite::Message::Close(_)) => {
warn!("remote control websocket disconnected");
return;
}
Err(err) => {
warn!("remote control websocket receive failed: {err}");
return;
}
}
}
warn!("remote control websocket disconnected");
});
loop {
tokio::select! {
_ = shutdown_token.cancelled() => {
write_task.abort();
read_task.abort();
return;
}
_ = &mut write_task => {
read_task.abort();
break;
}
_ = &mut read_task => {
write_task.abort();
break;
}
client_envelope = reader_event_rx.recv() => {
let Some(client_envelope) = client_envelope else {
write_task.abort();
read_task.abort();
break;
};
if let Some(cursor) = client_envelope.cursor.as_deref() {
subscribe_cursor = Some(cursor.to_string());
}
if let ClientEvent::Ack { acked_seq_id } = &client_envelope.event
&& let Some(buffered_events) = outbound_buffer.get_mut(&client_envelope.client_id)
{
let acknowledged_seq_ids: Vec<u64> = buffered_events
.range(..=*acked_seq_id)
.map(|(seq_id, _)| *seq_id)
.collect();
for acknowledged_seq_id in acknowledged_seq_ids {
buffered_events.remove(&acknowledged_seq_id);
}
if buffered_events.is_empty() {
outbound_buffer.remove(&client_envelope.client_id);
}
}
if client_event_tx.send(client_envelope).await.is_err() {
write_task.abort();
read_task.abort();
return;
}
}
server_envelope = server_event_rx.recv() => {
let Some(server_envelope) = server_envelope else {
write_task.abort();
read_task.abort();
return;
};
if let ServerEvent::ServerMessage { .. } = &server_envelope.envelope.event
&& let Some(seq_id) = server_envelope.envelope.seq_id
{
outbound_buffer
.entry(server_envelope.envelope.client_id.clone())
.or_default()
.insert(seq_id, BufferedServerEvent {
event: server_envelope.envelope.event.clone(),
write_complete_tx: None,
});
}
if let Err(err) = write_command_tx
.send(RemoteControlWriteCommand::ServerEnvelope(server_envelope))
.await
{
let RemoteControlWriteCommand::ServerEnvelope(server_envelope) = err.0 else {
unreachable!();
};
if let ServerEvent::ServerMessage { .. } = &server_envelope.envelope.event
&& let Some(seq_id) = server_envelope.envelope.seq_id
&& let Some(buffered_events) = outbound_buffer.get_mut(&server_envelope.envelope.client_id)
&& let Some(buffered_event) = buffered_events.get_mut(&seq_id)
{
buffered_event.write_complete_tx = server_envelope.write_complete_tx;
}
write_task.abort();
read_task.abort();
break;
}
}
}
}
}
}

View File

@@ -0,0 +1,88 @@
use super::CHANNEL_CAPACITY;
use super::TransportEvent;
use super::forward_incoming_message;
use super::serialize_outgoing_message;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::QueuedOutgoingMessage;
use std::io::ErrorKind;
use std::io::Result as IoResult;
use tokio::io;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tracing::debug;
use tracing::error;
use tracing::info;
pub(crate) async fn start_stdio_connection(
transport_event_tx: mpsc::Sender<TransportEvent>,
stdio_handles: &mut Vec<JoinHandle<()>>,
) -> IoResult<()> {
let connection_id = ConnectionId(0);
let (writer_tx, mut writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
let writer_tx_for_reader = writer_tx.clone();
transport_event_tx
.send(TransportEvent::ConnectionOpened {
connection_id,
writer: writer_tx,
disconnect_sender: None,
})
.await
.map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "processor unavailable"))?;
let transport_event_tx_for_reader = transport_event_tx.clone();
stdio_handles.push(tokio::spawn(async move {
let stdin = io::stdin();
let reader = BufReader::new(stdin);
let mut lines = reader.lines();
loop {
match lines.next_line().await {
Ok(Some(line)) => {
if !forward_incoming_message(
&transport_event_tx_for_reader,
&writer_tx_for_reader,
connection_id,
&line,
)
.await
{
break;
}
}
Ok(None) => break,
Err(err) => {
error!("Failed reading stdin: {err}");
break;
}
}
}
let _ = transport_event_tx_for_reader
.send(TransportEvent::ConnectionClosed { connection_id })
.await;
debug!("stdin reader finished (EOF)");
}));
stdio_handles.push(tokio::spawn(async move {
let mut stdout = io::stdout();
while let Some(queued_message) = writer_rx.recv().await {
let Some(mut json) = serialize_outgoing_message(queued_message.message) else {
continue;
};
json.push('\n');
if let Err(err) = stdout.write_all(json.as_bytes()).await {
error!("Failed to write to stdout: {err}");
break;
}
if let Some(write_complete_tx) = queued_message.write_complete_tx {
let _ = write_complete_tx.send(());
}
}
info!("stdout writer exited (channel closed)");
}));
Ok(())
}

View File

@@ -0,0 +1,308 @@
use super::CHANNEL_CAPACITY;
use super::TransportEvent;
use super::auth::WebsocketAuthPolicy;
use super::auth::authorize_upgrade;
use super::auth::should_warn_about_unauthenticated_non_loopback_listener;
use super::forward_incoming_message;
use super::serialize_outgoing_message;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::QueuedOutgoingMessage;
use axum::Router;
use axum::body::Body;
use axum::extract::ConnectInfo;
use axum::extract::State;
use axum::extract::ws::Message as WebSocketMessage;
use axum::extract::ws::WebSocket;
use axum::extract::ws::WebSocketUpgrade;
use axum::http::HeaderMap;
use axum::http::Request;
use axum::http::StatusCode;
use axum::http::header::ORIGIN;
use axum::middleware;
use axum::middleware::Next;
use axum::response::IntoResponse;
use axum::response::Response;
use axum::routing::any;
use axum::routing::get;
use futures::SinkExt;
use futures::StreamExt;
use owo_colors::OwoColorize;
use owo_colors::Stream;
use owo_colors::Style;
use std::io::Result as IoResult;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::error;
use tracing::info;
use tracing::warn;
fn colorize(text: &str, style: Style) -> String {
text.if_supports_color(Stream::Stderr, |value| value.style(style))
.to_string()
}
#[allow(clippy::print_stderr)]
fn print_websocket_startup_banner(addr: SocketAddr) {
let title = colorize("codex app-server (WebSockets)", Style::new().bold().cyan());
let listening_label = colorize("listening on:", Style::new().dimmed());
let listen_url = colorize(&format!("ws://{addr}"), Style::new().green());
let ready_label = colorize("readyz:", Style::new().dimmed());
let ready_url = colorize(&format!("http://{addr}/readyz"), Style::new().green());
let health_label = colorize("healthz:", Style::new().dimmed());
let health_url = colorize(&format!("http://{addr}/healthz"), Style::new().green());
let note_label = colorize("note:", Style::new().dimmed());
eprintln!("{title}");
eprintln!(" {listening_label} {listen_url}");
eprintln!(" {ready_label} {ready_url}");
eprintln!(" {health_label} {health_url}");
if addr.ip().is_loopback() {
eprintln!(
" {note_label} binds localhost only (use SSH port-forwarding for remote access)"
);
} else {
eprintln!(
" {note_label} this is a raw WS server; consider running behind TLS/auth for real remote use"
);
}
}
#[derive(Clone)]
struct WebSocketListenerState {
transport_event_tx: mpsc::Sender<TransportEvent>,
connection_counter: Arc<AtomicU64>,
auth_policy: Arc<WebsocketAuthPolicy>,
}
async fn health_check_handler() -> StatusCode {
StatusCode::OK
}
async fn reject_requests_with_origin_header(
request: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
if request.headers().contains_key(ORIGIN) {
warn!(
method = %request.method(),
uri = %request.uri(),
"rejecting websocket listener request with Origin header"
);
Err(StatusCode::FORBIDDEN)
} else {
Ok(next.run(request).await)
}
}
async fn websocket_upgrade_handler(
websocket: WebSocketUpgrade,
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
State(state): State<WebSocketListenerState>,
headers: HeaderMap,
) -> impl IntoResponse {
if let Err(err) = authorize_upgrade(&headers, state.auth_policy.as_ref()) {
warn!(
%peer_addr,
message = err.message(),
"rejecting websocket client during upgrade"
);
return (err.status_code(), err.message()).into_response();
}
let connection_id = ConnectionId(state.connection_counter.fetch_add(1, Ordering::Relaxed));
info!(%peer_addr, "websocket client connected");
websocket
.on_upgrade(move |stream| async move {
run_websocket_connection(connection_id, stream, state.transport_event_tx).await;
})
.into_response()
}
pub(crate) async fn start_websocket_acceptor(
bind_address: SocketAddr,
transport_event_tx: mpsc::Sender<TransportEvent>,
shutdown_token: CancellationToken,
auth_policy: WebsocketAuthPolicy,
) -> IoResult<JoinHandle<()>> {
if should_warn_about_unauthenticated_non_loopback_listener(bind_address, &auth_policy) {
warn!(
%bind_address,
"starting non-loopback websocket listener without auth; websocket auth is opt-in for now and will become the default in a future release"
);
}
let listener = TcpListener::bind(bind_address).await?;
let local_addr = listener.local_addr()?;
print_websocket_startup_banner(local_addr);
info!("app-server websocket listening on ws://{local_addr}");
let router = Router::new()
.route("/readyz", get(health_check_handler))
.route("/healthz", get(health_check_handler))
.fallback(any(websocket_upgrade_handler))
.layer(middleware::from_fn(reject_requests_with_origin_header))
.with_state(WebSocketListenerState {
transport_event_tx,
connection_counter: Arc::new(AtomicU64::new(1)),
auth_policy: Arc::new(auth_policy),
});
let server = axum::serve(
listener,
router.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(async move {
shutdown_token.cancelled().await;
});
Ok(tokio::spawn(async move {
if let Err(err) = server.await {
error!("websocket acceptor failed: {err}");
}
info!("websocket acceptor shutting down");
}))
}
async fn run_websocket_connection(
connection_id: ConnectionId,
websocket_stream: WebSocket,
transport_event_tx: mpsc::Sender<TransportEvent>,
) {
let (writer_tx, writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
let writer_tx_for_reader = writer_tx.clone();
let disconnect_token = CancellationToken::new();
if transport_event_tx
.send(TransportEvent::ConnectionOpened {
connection_id,
writer: writer_tx,
disconnect_sender: Some(disconnect_token.clone()),
})
.await
.is_err()
{
return;
}
let (websocket_writer, websocket_reader) = websocket_stream.split();
let (writer_control_tx, writer_control_rx) =
mpsc::channel::<WebSocketMessage>(CHANNEL_CAPACITY);
let mut outbound_task = tokio::spawn(run_websocket_outbound_loop(
websocket_writer,
writer_rx,
writer_control_rx,
disconnect_token.clone(),
));
let mut inbound_task = tokio::spawn(run_websocket_inbound_loop(
websocket_reader,
transport_event_tx.clone(),
writer_tx_for_reader,
writer_control_tx,
connection_id,
disconnect_token.clone(),
));
tokio::select! {
_ = &mut outbound_task => {
disconnect_token.cancel();
inbound_task.abort();
}
_ = &mut inbound_task => {
disconnect_token.cancel();
outbound_task.abort();
}
}
let _ = transport_event_tx
.send(TransportEvent::ConnectionClosed { connection_id })
.await;
}
async fn run_websocket_outbound_loop(
mut websocket_writer: futures::stream::SplitSink<WebSocket, WebSocketMessage>,
mut writer_rx: mpsc::Receiver<QueuedOutgoingMessage>,
mut writer_control_rx: mpsc::Receiver<WebSocketMessage>,
disconnect_token: CancellationToken,
) {
loop {
tokio::select! {
_ = disconnect_token.cancelled() => {
break;
}
message = writer_control_rx.recv() => {
let Some(message) = message else {
break;
};
if websocket_writer.send(message).await.is_err() {
break;
}
}
queued_message = writer_rx.recv() => {
let Some(queued_message) = queued_message else {
break;
};
let Some(json) = serialize_outgoing_message(queued_message.message) else {
continue;
};
if websocket_writer.send(WebSocketMessage::Text(json.into())).await.is_err() {
break;
}
if let Some(write_complete_tx) = queued_message.write_complete_tx {
let _ = write_complete_tx.send(());
}
}
}
}
}
async fn run_websocket_inbound_loop(
mut websocket_reader: futures::stream::SplitStream<WebSocket>,
transport_event_tx: mpsc::Sender<TransportEvent>,
writer_tx_for_reader: mpsc::Sender<QueuedOutgoingMessage>,
writer_control_tx: mpsc::Sender<WebSocketMessage>,
connection_id: ConnectionId,
disconnect_token: CancellationToken,
) {
loop {
tokio::select! {
_ = disconnect_token.cancelled() => {
break;
}
incoming_message = websocket_reader.next() => {
match incoming_message {
Some(Ok(WebSocketMessage::Text(text))) => {
if !forward_incoming_message(
&transport_event_tx,
&writer_tx_for_reader,
connection_id,
text.as_ref(),
)
.await
{
break;
}
}
Some(Ok(WebSocketMessage::Ping(payload))) => {
match writer_control_tx.try_send(WebSocketMessage::Pong(payload)) {
Ok(()) => {}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => break,
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
warn!("websocket control queue full while replying to ping; closing connection");
break;
}
}
}
Some(Ok(WebSocketMessage::Pong(_))) => {}
Some(Ok(WebSocketMessage::Close(_))) | None => break,
Some(Ok(WebSocketMessage::Binary(_))) => {
warn!("dropping unsupported binary websocket message");
}
Some(Err(err)) => {
warn!("websocket receive error: {err}");
break;
}
}
}
}
}
}

View File

@@ -1,8 +1,10 @@
use anyhow::Result;
use app_test_support::ChatGptAuthFixture;
use app_test_support::McpProcess;
use app_test_support::test_path_buf_with_windows;
use app_test_support::test_tmp_path_buf;
use app_test_support::to_response;
use app_test_support::write_chatgpt_auth;
use codex_app_server_protocol::AppConfig;
use codex_app_server_protocol::AppToolApproval;
use codex_app_server_protocol::AppsConfig;
@@ -21,6 +23,7 @@ use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::SandboxMode;
use codex_app_server_protocol::ToolsV2;
use codex_app_server_protocol::WriteStatus;
use codex_core::auth::AuthCredentialsStoreMode;
use codex_core::config::set_project_trust_level;
use codex_protocol::config_types::TrustLevel;
use codex_protocol::config_types::WebSearchContextSize;
@@ -88,6 +91,69 @@ sandbox_mode = "workspace-write"
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn config_read_includes_chatgpt_base_url() -> Result<()> {
let codex_home = TempDir::new()?;
write_config(
&codex_home,
r#"
chatgpt_base_url = "https://example.com/backend-api/"
[features]
remote_control = true
"#,
)?;
write_chatgpt_auth(
codex_home.path(),
ChatGptAuthFixture::new("chatgpt-token"),
AuthCredentialsStoreMode::File,
)?;
let codex_home_path = codex_home.path().canonicalize()?;
let user_file = AbsolutePathBuf::try_from(codex_home_path.join("config.toml"))?;
let mut mcp = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
let request_id = mcp
.send_config_read_request(ConfigReadParams {
include_layers: true,
cwd: None,
})
.await?;
let resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
)
.await??;
let ConfigReadResponse {
config,
origins,
layers,
} = to_response(resp)?;
assert_eq!(
config.additional.get("chatgpt_base_url"),
Some(&json!("https://example.com/backend-api/"))
);
assert_eq!(
origins.get("chatgpt_base_url").expect("origin").name,
ConfigLayerSource::User {
file: user_file.clone(),
}
);
assert_eq!(
origins.get("features.remote_control").expect("origin").name,
ConfigLayerSource::User {
file: user_file.clone(),
}
);
let layers = layers.expect("layers present");
assert_layers_user_then_optional_system(&layers, user_file)?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn config_read_includes_tools() -> Result<()> {
let codex_home = TempDir::new()?;

View File

@@ -328,6 +328,7 @@ impl CloudRequirementsService {
return Ok(None);
};
if !auth.is_chatgpt_auth()
|| auth.is_external_chatgpt_tokens()
|| !matches!(
auth.account_plan_type(),
Some(PlanType::Business | PlanType::Enterprise)

View File

@@ -437,6 +437,9 @@
"realtime_conversation": {
"type": "boolean"
},
"remote_control": {
"type": "boolean"
},
"remote_models": {
"type": "boolean"
},
@@ -2054,6 +2057,9 @@
"realtime_conversation": {
"type": "boolean"
},
"remote_control": {
"type": "boolean"
},
"remote_models": {
"type": "boolean"
},

View File

@@ -176,6 +176,8 @@ pub enum Feature {
VoiceTranscription,
/// Enable experimental realtime voice conversation mode in the TUI.
RealtimeConversation,
/// Connect app-server to the ChatGPT remote control service.
RemoteControl,
/// Route interactive startup to the app-server-backed TUI implementation.
TuiAppServer,
/// Prevent idle system sleep while a turn is actively running.
@@ -819,6 +821,12 @@ pub const FEATURES: &[FeatureSpec] = &[
stage: Stage::UnderDevelopment,
default_enabled: false,
},
FeatureSpec {
id: Feature::RemoteControl,
key: "remote_control",
stage: Stage::UnderDevelopment,
default_enabled: false,
},
FeatureSpec {
id: Feature::TuiAppServer,
key: "tui_app_server",

View File

@@ -159,6 +159,12 @@ fn image_detail_original_feature_is_under_development() {
assert_eq!(Feature::ImageDetailOriginal.default_enabled(), false);
}
#[test]
fn remote_control_is_under_development() {
assert_eq!(Feature::RemoteControl.stage(), Stage::UnderDevelopment);
assert_eq!(Feature::RemoteControl.default_enabled(), false);
}
#[test]
fn collab_is_legacy_alias_for_multi_agent() {
assert_eq!(feature_for_key("multi_agent"), Some(Feature::Collab));

View File

@@ -0,0 +1,8 @@
CREATE TABLE remote_control_enrollments (
websocket_url TEXT NOT NULL,
account_id TEXT NOT NULL,
server_id TEXT NOT NULL,
server_name TEXT NOT NULL,
updated_at INTEGER NOT NULL,
PRIMARY KEY (websocket_url, account_id)
);

View File

@@ -53,6 +53,7 @@ mod agent_jobs;
mod backfill;
mod logs;
mod memories;
mod remote_control;
#[cfg(test)]
mod test_support;
mod threads;

View File

@@ -0,0 +1,197 @@
use super::*;
const REMOTE_CONTROL_ACCOUNT_ID_NONE: &str = "";
fn remote_control_account_id_key(account_id: Option<&str>) -> &str {
account_id.unwrap_or(REMOTE_CONTROL_ACCOUNT_ID_NONE)
}
impl StateRuntime {
pub async fn get_remote_control_enrollment(
&self,
websocket_url: &str,
account_id: Option<&str>,
) -> anyhow::Result<Option<(String, String)>> {
let row = sqlx::query(
r#"
SELECT server_id, server_name
FROM remote_control_enrollments
WHERE websocket_url = ? AND account_id = ?
"#,
)
.bind(websocket_url)
.bind(remote_control_account_id_key(account_id))
.fetch_optional(self.pool.as_ref())
.await?;
row.map(|row| Ok((row.try_get("server_id")?, row.try_get("server_name")?)))
.transpose()
}
pub async fn upsert_remote_control_enrollment(
&self,
websocket_url: &str,
account_id: Option<&str>,
server_id: &str,
server_name: &str,
) -> anyhow::Result<()> {
sqlx::query(
r#"
INSERT INTO remote_control_enrollments (
websocket_url,
account_id,
server_id,
server_name,
updated_at
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(websocket_url, account_id) DO UPDATE SET
server_id = excluded.server_id,
server_name = excluded.server_name,
updated_at = excluded.updated_at
"#,
)
.bind(websocket_url)
.bind(remote_control_account_id_key(account_id))
.bind(server_id)
.bind(server_name)
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.await?;
Ok(())
}
pub async fn delete_remote_control_enrollment(
&self,
websocket_url: &str,
account_id: Option<&str>,
) -> anyhow::Result<u64> {
let result = sqlx::query(
r#"
DELETE FROM remote_control_enrollments
WHERE websocket_url = ? AND account_id = ?
"#,
)
.bind(websocket_url)
.bind(remote_control_account_id_key(account_id))
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected())
}
}
#[cfg(test)]
mod tests {
use super::StateRuntime;
use super::test_support::unique_temp_dir;
use pretty_assertions::assert_eq;
#[tokio::test]
async fn remote_control_enrollment_round_trips_by_target_and_account() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string())
.await
.expect("initialize runtime");
runtime
.upsert_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
Some("account-a"),
"srv_e_first",
"first-server",
)
.await
.expect("insert first enrollment");
runtime
.upsert_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
Some("account-b"),
"srv_e_second",
"second-server",
)
.await
.expect("insert second enrollment");
assert_eq!(
runtime
.get_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
Some("account-a"),
)
.await
.expect("load first enrollment"),
Some(("srv_e_first".to_string(), "first-server".to_string()))
);
assert_eq!(
runtime
.get_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
None,
)
.await
.expect("load missing enrollment"),
None
);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn delete_remote_control_enrollment_removes_only_matching_entry() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string())
.await
.expect("initialize runtime");
runtime
.upsert_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
None,
"srv_e_first",
"first-server",
)
.await
.expect("insert first enrollment");
runtime
.upsert_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
Some("account-a"),
"srv_e_second",
"second-server",
)
.await
.expect("insert second enrollment");
assert_eq!(
runtime
.delete_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
None,
)
.await
.expect("delete first enrollment"),
1
);
assert_eq!(
runtime
.get_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
None,
)
.await
.expect("load deleted enrollment"),
None
);
assert_eq!(
runtime
.get_remote_control_enrollment(
"wss://example.com/backend-api/wham/remote/control/server",
Some("account-a"),
)
.await
.expect("load retained enrollment"),
Some(("srv_e_second".to_string(), "second-server".to_string()))
);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
}