Compare commits

...

2 Commits

Author SHA1 Message Date
Anton Panasenko
c991b670ee update 2026-03-25 16:04:07 -07:00
Anton Panasenko
83056d0474 feat(core) slingshot 2026-03-20 16:13:59 -07:00
22 changed files with 3172 additions and 115 deletions

2
codex-rs/Cargo.lock generated
View File

@@ -1440,8 +1440,10 @@ dependencies = [
"codex-utils-cli",
"codex-utils-json-to-toml",
"codex-utils-pty",
"codex-utils-rustls-provider",
"core_test_support",
"futures",
"gethostname",
"opentelemetry",
"opentelemetry_sdk",
"owo-colors",

View File

@@ -47,9 +47,11 @@ codex-rmcp-client = { workspace = true }
codex-state = { workspace = true }
codex-utils-absolute-path = { workspace = true }
codex-utils-json-to-toml = { workspace = true }
codex-utils-rustls-provider = { workspace = true }
chrono = { workspace = true }
clap = { workspace = true, features = ["derive"] }
futures = { workspace = true }
gethostname = { workspace = true }
owo-colors = { workspace = true, features = ["supports-colors"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }

View File

@@ -2,6 +2,8 @@
`codex app-server` is the interface Codex uses to power rich interfaces such as the [Codex VS Code extension](https://marketplace.visualstudio.com/items?itemName=openai.chatgpt).
For remote-control-only deployments, use `codexd`. It runs the same app-server runtime in a headless daemon mode, connects outbound to the ChatGPT remote control server using ChatGPT auth, and does not expose a local stdio or websocket transport.
## Table of Contents
- [Protocol](#protocol)
@@ -25,6 +27,24 @@ Supported transports:
- stdio (`--listen stdio://`, default): newline-delimited JSON (JSONL)
- websocket (`--listen ws://IP:PORT`): one JSON-RPC message per websocket text frame (**experimental / unsupported**)
- remote control (`--with-remote-control`): also connect outbound to the ChatGPT remote control server derived from `chatgpt_base_url`
You can combine a local transport with remote control in the same process:
```sh
codex app-server --listen stdio:// --with-remote-control
codex app-server --listen ws://127.0.0.1:8080 --with-remote-control
```
Both local and remote-controlled clients share the same in-process app-server state, and remote-controlled clients still use the normal JSON-RPC connection lifecycle, including `initialize`.
For remote-control-only deployments, keep using `codexd`:
```sh
codexd
```
`codexd` runs the same runtime with no local listener and remote control enabled by default.
When running with `--listen ws://IP:PORT`, the same listener also serves basic HTTP health probes:

View File

@@ -29,12 +29,8 @@ pub(crate) fn request_span(
) -> Span {
let initialize_client_info = initialize_client_info(request);
let method = request.method.as_str();
let span = app_server_request_span_template(
method,
transport_name(transport),
&request.id,
connection_id,
);
let span =
app_server_request_span_template(method, transport.as_str(), &request.id, connection_id);
record_client_info(
&span,
@@ -82,13 +78,6 @@ pub(crate) fn typed_request_span(
span
}
fn transport_name(transport: AppServerTransport) -> &'static str {
match transport {
AppServerTransport::Stdio => "stdio",
AppServerTransport::WebSocket { .. } => "websocket",
}
}
fn app_server_request_span_template(
method: &str,
transport: &'static str,

View File

@@ -23,6 +23,7 @@ use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::OutgoingEnvelope;
use crate::outgoing_message::OutgoingMessageSender;
use crate::transport::CHANNEL_CAPACITY;
use crate::transport::ConnectionIdAllocator;
use crate::transport::ConnectionState;
use crate::transport::OutboundConnectionState;
use crate::transport::TransportEvent;
@@ -76,6 +77,8 @@ mod thread_state;
mod thread_status;
mod transport;
mod remote_control;
pub use crate::error_code::INPUT_TOO_LARGE_ERROR_CODE;
pub use crate::error_code::INVALID_PARAMS_ERROR_CODE;
pub use crate::transport::AppServerTransport;
@@ -330,12 +333,32 @@ pub async fn run_main(
loader_overrides: LoaderOverrides,
default_analytics_enabled: bool,
) -> IoResult<()> {
run_main_with_transport(
run_main_with_runtime(
arg0_paths,
cli_config_overrides,
loader_overrides,
default_analytics_enabled,
AppServerTransport::Stdio,
Some(AppServerTransport::Stdio),
false,
)
.await
}
/// Runs the app server with an optional local transport and optional remote control.
///
/// This is a thin public wrapper: every argument is forwarded unchanged to
/// `run_main_with_transport_impl`, and its result is returned as-is.
///
/// * `local_transport` - the local listener to start, or `None` for no local listener.
/// * `with_remote_control` - whether the outbound remote-control connection is started as well.
pub async fn run_main_with_runtime(
arg0_paths: Arg0DispatchPaths,
cli_config_overrides: CliConfigOverrides,
loader_overrides: LoaderOverrides,
default_analytics_enabled: bool,
local_transport: Option<AppServerTransport>,
with_remote_control: bool,
) -> IoResult<()> {
run_main_with_transport_impl(
arg0_paths,
cli_config_overrides,
loader_overrides,
default_analytics_enabled,
local_transport,
with_remote_control,
)
.await
}
@@ -347,43 +370,64 @@ pub async fn run_main_with_transport(
default_analytics_enabled: bool,
transport: AppServerTransport,
) -> IoResult<()> {
match transport {
AppServerTransport::Stdio => {
run_main_with_runtime(
arg0_paths,
cli_config_overrides,
loader_overrides,
default_analytics_enabled,
Some(AppServerTransport::Stdio),
false,
)
.await
}
AppServerTransport::WebSocket { bind_address } => {
run_main_with_runtime(
arg0_paths,
cli_config_overrides,
loader_overrides,
default_analytics_enabled,
Some(AppServerTransport::WebSocket { bind_address }),
false,
)
.await
}
AppServerTransport::RemoteControlled => {
run_main_with_runtime(
arg0_paths,
cli_config_overrides,
loader_overrides,
default_analytics_enabled,
None,
true,
)
.await
}
}
}
async fn run_main_with_transport_impl(
arg0_paths: Arg0DispatchPaths,
cli_config_overrides: CliConfigOverrides,
loader_overrides: LoaderOverrides,
default_analytics_enabled: bool,
local_transport: Option<AppServerTransport>,
with_remote_control: bool,
) -> IoResult<()> {
if matches!(local_transport, Some(AppServerTransport::RemoteControlled)) {
return Err(std::io::Error::new(
ErrorKind::InvalidInput,
"remote-controlled transport cannot be used as a local listener",
));
}
let (transport_event_tx, mut transport_event_rx) =
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingEnvelope>(CHANNEL_CAPACITY);
let (outbound_control_tx, mut outbound_control_rx) =
mpsc::channel::<OutboundControlEvent>(CHANNEL_CAPACITY);
enum TransportRuntime {
Stdio,
WebSocket {
accept_handle: JoinHandle<()>,
shutdown_token: CancellationToken,
},
}
let mut stdio_handles = Vec::<JoinHandle<()>>::new();
let transport_runtime = match transport {
AppServerTransport::Stdio => {
start_stdio_connection(transport_event_tx.clone(), &mut stdio_handles).await?;
TransportRuntime::Stdio
}
AppServerTransport::WebSocket { bind_address } => {
let shutdown_token = CancellationToken::new();
let accept_handle = start_websocket_acceptor(
bind_address,
transport_event_tx.clone(),
shutdown_token.clone(),
)
.await?;
TransportRuntime::WebSocket {
accept_handle,
shutdown_token,
}
}
};
let single_client_mode = matches!(&transport_runtime, TransportRuntime::Stdio);
let shutdown_when_no_connections = single_client_mode;
let graceful_signal_restart_enabled = !single_client_mode;
// Parse CLI overrides once and derive the base Config eagerly so later
// components do not need to work with raw TOML values.
let cli_kv_overrides = cli_config_overrides.parse_overrides().map_err(|e| {
@@ -466,6 +510,49 @@ pub async fn run_main_with_transport(
config_warnings.push(message);
}
let transport_shutdown_token = CancellationToken::new();
let mut transport_accept_handles = Vec::<JoinHandle<()>>::new();
let connection_id_allocator = ConnectionIdAllocator::default();
match local_transport {
Some(AppServerTransport::Stdio) => {
start_stdio_connection(transport_event_tx.clone(), &mut transport_accept_handles)
.await?;
}
Some(AppServerTransport::WebSocket { bind_address }) => {
transport_accept_handles.push(
start_websocket_acceptor(
bind_address,
transport_event_tx.clone(),
transport_shutdown_token.clone(),
connection_id_allocator.clone(),
)
.await?,
);
}
Some(AppServerTransport::RemoteControlled) => unreachable!(),
None => {}
}
if with_remote_control {
transport_accept_handles.push(
remote_control::start_remote_control(
config.chatgpt_base_url.clone(),
config.codex_home.clone(),
AuthManager::shared(
config.codex_home.clone(),
false,
config.cli_auth_credentials_store_mode,
),
transport_event_tx.clone(),
transport_shutdown_token.clone(),
connection_id_allocator,
)
.await?,
);
}
let shutdown_when_stdio_disconnects =
matches!(local_transport, Some(AppServerTransport::Stdio));
let graceful_signal_restart_enabled = !shutdown_when_stdio_disconnects;
if let Some(warning) = project_config_warning(&config) {
config_warnings.push(warning);
}
@@ -619,10 +706,7 @@ pub async fn run_main_with_transport(
let mut thread_created_rx = processor.thread_created_receiver();
let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count();
let mut connections = HashMap::<ConnectionId, ConnectionState>::new();
let websocket_accept_shutdown = match &transport_runtime {
TransportRuntime::WebSocket { shutdown_token, .. } => Some(shutdown_token.clone()),
TransportRuntime::Stdio => None,
};
let transport_shutdown_token = transport_shutdown_token.clone();
async move {
let mut listen_for_threads = true;
let mut shutdown_state = ShutdownState::default();
@@ -635,9 +719,7 @@ pub async fn run_main_with_transport(
shutdown_state.update(running_turn_count, connections.len()),
ShutdownAction::Finish
) {
if let Some(shutdown_token) = &websocket_accept_shutdown {
shutdown_token.cancel();
}
transport_shutdown_token.cancel();
let _ = outbound_control_tx
.send(OutboundControlEvent::DisconnectAll)
.await;
@@ -664,10 +746,16 @@ pub async fn run_main_with_transport(
match event {
TransportEvent::ConnectionOpened {
connection_id,
transport_kind,
writer,
allow_legacy_notifications,
disconnect_sender,
} => {
info!(
connection_id = %connection_id,
transport = transport_kind.as_str(),
"app-server connection opened"
);
let outbound_initialized = Arc::new(AtomicBool::new(false));
let outbound_experimental_api_enabled =
Arc::new(AtomicBool::new(false));
@@ -695,13 +783,22 @@ pub async fn run_main_with_transport(
connections.insert(
connection_id,
ConnectionState::new(
transport_kind,
outbound_initialized,
outbound_experimental_api_enabled,
outbound_opted_out_notification_methods,
),
);
}
TransportEvent::ConnectionClosed { connection_id } => {
TransportEvent::ConnectionClosed {
connection_id,
transport_kind,
} => {
info!(
connection_id = %connection_id,
transport = transport_kind.as_str(),
"app-server connection closed"
);
if connections.remove(&connection_id).is_none() {
continue;
}
@@ -713,7 +810,8 @@ pub async fn run_main_with_transport(
break;
}
processor.connection_closed(connection_id).await;
if shutdown_when_no_connections && connections.is_empty() {
if shutdown_when_stdio_disconnects && connection_id == ConnectionId(0)
{
break;
}
}
@@ -729,7 +827,7 @@ pub async fn run_main_with_transport(
.process_request(
connection_id,
request,
transport,
connection_state.transport_kind,
&mut connection_state.session,
)
.await;
@@ -833,16 +931,8 @@ pub async fn run_main_with_transport(
let _ = processor_handle.await;
let _ = outbound_handle.await;
if let TransportRuntime::WebSocket {
accept_handle,
shutdown_token,
} = transport_runtime
{
shutdown_token.cancel();
let _ = accept_handle.await;
}
for handle in stdio_handles {
transport_shutdown_token.cancel();
for handle in transport_accept_handles {
let _ = handle.await;
}

View File

@@ -1,6 +1,6 @@
use clap::Parser;
use codex_app_server::AppServerTransport;
use codex_app_server::run_main_with_transport;
use codex_app_server::run_main_with_runtime;
use codex_arg0::Arg0DispatchPaths;
use codex_arg0::arg0_dispatch_or_else;
use codex_core::config_loader::LoaderOverrides;
@@ -21,6 +21,11 @@ struct AppServerArgs {
default_value = AppServerTransport::DEFAULT_LISTEN_URL
)]
listen: AppServerTransport,
/// Also connect outbound to the ChatGPT remote control server derived from
/// the configured `chatgpt_base_url`.
#[arg(long = "with-remote-control", default_value_t = false)]
with_remote_control: bool,
}
fn main() -> anyhow::Result<()> {
@@ -31,14 +36,13 @@ fn main() -> anyhow::Result<()> {
managed_config_path,
..Default::default()
};
let transport = args.listen;
run_main_with_transport(
run_main_with_runtime(
arg0_paths,
CliConfigOverrides::default(),
loader_overrides,
false,
transport,
Some(args.listen),
args.with_remote_control,
)
.await?;
Ok(())
@@ -59,3 +63,34 @@ fn managed_config_path_from_debug_env() -> Option<PathBuf> {
None
}
#[cfg(test)]
mod tests {
    use super::AppServerArgs;
    use clap::Parser;
    use pretty_assertions::assert_eq;

    /// With no flags the server listens on stdio and remote control stays off.
    #[test]
    fn app_server_args_default_to_stdio_without_remote_control() {
        let AppServerArgs {
            listen,
            with_remote_control,
            ..
        } = AppServerArgs::parse_from(["codex-app-server"]);

        assert_eq!(listen, super::AppServerTransport::Stdio);
        assert!(!with_remote_control);
    }

    /// `--with-remote-control` can be combined with an explicit websocket listener.
    #[test]
    fn app_server_args_parse_with_remote_control_flag() {
        let argv = [
            "codex-app-server",
            "--listen",
            "ws://127.0.0.1:8080",
            "--with-remote-control",
        ];
        let expected_listen = super::AppServerTransport::WebSocket {
            bind_address: "127.0.0.1:8080".parse().expect("valid socket address"),
        };

        let AppServerArgs {
            listen,
            with_remote_control,
            ..
        } = AppServerArgs::parse_from(argv);

        assert_eq!(listen, expected_listen);
        assert!(with_remote_control);
    }
}

View File

@@ -9,6 +9,7 @@ use app_test_support::create_mock_responses_server_repeating_assistant;
use app_test_support::write_mock_responses_config_toml;
use codex_app_server_protocol::ClientInfo;
use codex_app_server_protocol::ClientRequest;
use codex_app_server_protocol::ConfigReadParams;
use codex_app_server_protocol::InitializeCapabilities;
use codex_app_server_protocol::InitializeParams;
use codex_app_server_protocol::InitializeResponse;
@@ -164,6 +165,19 @@ impl TracingHarness {
}
/// Convenience wrapper around `request_with_transport` that always uses
/// `AppServerTransport::Stdio` and returns that call's result.
async fn request<T>(&mut self, request: ClientRequest, trace: Option<W3cTraceContext>) -> T
where
T: serde::de::DeserializeOwned,
{
self.request_with_transport(request, trace, AppServerTransport::Stdio)
.await
}
async fn request_with_transport<T>(
&mut self,
request: ClientRequest,
trace: Option<W3cTraceContext>,
transport: AppServerTransport,
) -> T
where
T: serde::de::DeserializeOwned,
{
@@ -175,12 +189,7 @@ impl TracingHarness {
request.trace = trace;
self.processor
.process_request(
TEST_CONNECTION_ID,
request,
AppServerTransport::Stdio,
&mut self.session,
)
.process_request(TEST_CONNECTION_ID, request, transport, &mut self.session)
.await;
read_response(&mut self.outgoing_rx, request_id).await
}
@@ -566,6 +575,73 @@ async fn thread_start_jsonrpc_span_exports_server_span_and_parents_children() ->
Ok(())
}
/// Issues one `config/read` request per transport kind and checks that the
/// exported server span carries the matching `rpc.transport` attribute.
#[tokio::test(flavor = "current_thread")]
async fn request_spans_record_per_connection_transport_kind() -> Result<()> {
    let _guard = tracing_test_guard().lock().await;
    let mut harness = TracingHarness::new().await?;

    let websocket_transport = AppServerTransport::WebSocket {
        bind_address: "127.0.0.1:8080".parse().expect("valid socket address"),
    };
    // (request id offset, transport handed to the processor, expected span attribute)
    let cases = [
        (0_i64, AppServerTransport::Stdio, "stdio"),
        (1_i64, websocket_transport, "websocket"),
        (
            2_i64,
            AppServerTransport::RemoteControlled,
            "remote_controlled",
        ),
    ];

    for (offset, transport, expected_transport) in cases {
        harness.reset_tracing();
        let baseline_len = harness
            .tracing
            .exporter
            .get_finished_spans()
            .expect("span export")
            .len();

        let config_read = ClientRequest::ConfigRead {
            request_id: RequestId::Integer(30_100 + offset),
            params: ConfigReadParams {
                include_layers: false,
                cwd: None,
            },
        };
        let _: serde_json::Value = harness
            .request_with_transport(config_read, None, transport)
            .await;

        // Wait until a server span for this request with the expected transport shows up.
        let spans = wait_for_new_exported_spans(harness.tracing, baseline_len, |spans| {
            spans.iter().any(|span| {
                span.span_kind == SpanKind::Server
                    && span_attr(span, "rpc.method") == Some("config/read")
                    && span_attr(span, "rpc.transport") == Some(expected_transport)
            })
        })
        .await;

        let server_span = spans
            .iter()
            .find(|span| {
                span.span_kind == SpanKind::Server
                    && span_attr(span, "rpc.method") == Some("config/read")
            })
            .expect("config/read server span should be exported");
        let recorded_transport = span_attr(server_span, "rpc.transport");
        assert_eq!(recorded_transport, Some(expected_transport));
    }

    harness.shutdown().await;
    Ok(())
}
#[tokio::test(flavor = "current_thread")]
async fn turn_start_jsonrpc_span_parents_core_turn_spans() -> Result<()> {
let _guard = tracing_test_guard().lock().await;

View File

@@ -0,0 +1,255 @@
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::OutgoingMessage;
use crate::transport::AppServerTransport;
use crate::transport::CHANNEL_CAPACITY;
use crate::transport::ConnectionIdAllocator;
use crate::transport::TransportEvent;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCRequest;
use std::collections::HashMap;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio::time::Instant;
use tokio::time::MissedTickBehavior;
use tokio_util::sync::CancellationToken;
use super::ClientEvent;
use super::ClientId;
use super::PongStatus;
use super::ServerEvent;
const REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60);
const REMOTE_CONTROL_IDLE_SWEEP_INTERVAL: Duration = Duration::from_secs(30);
/// Per-client bookkeeping kept by `run` for each remote-control client that
/// currently has an open app-server connection.
struct RemoteControlClientState {
/// Connection id that was announced for this client in `ConnectionOpened`.
connection_id: ConnectionId,
/// Cancelled when the client is closed; the client's outbound task also
/// watches this token and exits when it fires.
disconnect_token: CancellationToken,
/// Refreshed on every client message and ping; compared against
/// `REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT` by the idle sweep.
last_activity_at: Instant,
}
/// Event loop that maps remote-control clients onto app-server connections.
///
/// Keeps one `RemoteControlClientState` per `ClientId` and translates incoming
/// `ClientEvent`s into `TransportEvent`s. The loop ends when `shutdown_token`
/// is cancelled or when any of the channels it depends on is closed; on exit
/// every client that is still tracked is closed.
///
/// * `transport_event_tx` - receives `ConnectionOpened`, `IncomingMessage` and
///   `ConnectionClosed` events for remote-control clients.
/// * `client_event_rx` - source of client messages, pings and close notices.
/// * `server_event_tx` - used for `Pong` replies and cloned into each client's
///   outbound task.
/// * `writer_exited_tx` / `writer_exited_rx` - outbound tasks report their
///   `ClientId` here when they stop, which makes the loop close that client.
/// * `shutdown_token` - cancelling it stops the loop.
/// * `connection_id_allocator` - provides the `ConnectionId` for each new client.
pub(super) async fn run(
transport_event_tx: mpsc::Sender<TransportEvent>,
mut client_event_rx: mpsc::Receiver<ClientEvent>,
server_event_tx: mpsc::Sender<ServerEvent>,
writer_exited_tx: mpsc::Sender<ClientId>,
mut writer_exited_rx: mpsc::Receiver<ClientId>,
shutdown_token: CancellationToken,
connection_id_allocator: ConnectionIdAllocator,
) {
let mut clients = HashMap::<ClientId, RemoteControlClientState>::new();
let mut idle_sweep = tokio::time::interval(REMOTE_CONTROL_IDLE_SWEEP_INTERVAL);
// If a sweep is late, schedule the next one relative to it instead of
// firing a burst of catch-up ticks.
idle_sweep.set_missed_tick_behavior(MissedTickBehavior::Delay);
loop {
tokio::select! {
_ = shutdown_token.cancelled() => {
break;
}
// Periodically drop clients that have been silent for too long.
_ = idle_sweep.tick() => {
if !close_expired_remote_control_clients(&transport_event_tx, &mut clients).await {
break;
}
}
// An outbound task stopped: close the matching client (no-op if it is
// already gone).
writer_exited = writer_exited_rx.recv() => {
let Some(client_id) = writer_exited else {
break;
};
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
break;
}
}
client_event = client_event_rx.recv() => {
let Some(client_event) = client_event else {
break;
};
match client_event {
ClientEvent::ClientMessage { client_id, message } => {
// Known client: refresh its idle timer and forward the message
// on its existing connection.
if let Some(connection_id) = clients.get_mut(&client_id).map(|client| {
client.last_activity_at = Instant::now();
client.connection_id
}) {
if transport_event_tx
.send(TransportEvent::IncomingMessage {
connection_id,
message,
})
.await
.is_err()
{
break;
}
continue;
}
// Unknown client: only an `initialize` request may open a
// connection; anything else is dropped.
if !remote_control_message_starts_connection(&message) {
continue;
}
let connection_id = connection_id_allocator.next_connection_id();
let (writer_tx, writer_rx) =
mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
let disconnect_token = CancellationToken::new();
// Announce the connection before spawning the outbound task and
// before forwarding the message that created it.
if transport_event_tx
.send(TransportEvent::ConnectionOpened {
connection_id,
transport_kind: AppServerTransport::RemoteControlled,
writer: writer_tx,
allow_legacy_notifications: false,
disconnect_sender: Some(disconnect_token.clone()),
})
.await
.is_err()
{
break;
}
tokio::spawn(run_remote_control_client_outbound(
client_id.clone(),
writer_rx,
server_event_tx.clone(),
writer_exited_tx.clone(),
disconnect_token.clone(),
));
clients.insert(
client_id,
RemoteControlClientState {
connection_id,
disconnect_token,
last_activity_at: Instant::now(),
},
);
// Deliver the `initialize` request itself on the new connection.
if transport_event_tx
.send(TransportEvent::IncomingMessage {
connection_id,
message,
})
.await
.is_err()
{
break;
}
}
// A ping counts as activity for known clients; unknown clients are
// told so via `PongStatus::Unknown`.
ClientEvent::Ping { client_id, .. } => {
let status = match clients.get_mut(&client_id) {
Some(client) => {
client.last_activity_at = Instant::now();
PongStatus::Active
}
None => PongStatus::Unknown,
};
if server_event_tx
.send(ServerEvent::Pong { client_id, status })
.await
.is_err()
{
break;
}
}
ClientEvent::ClientClosed { client_id } => {
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id)
.await
{
break;
}
}
}
}
}
}
// Loop exited: close every client that is still tracked, stopping early if
// the transport channel can no longer accept events.
while let Some(client_id) = clients.keys().next().cloned() {
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
break;
}
}
}
/// Returns `true` only for a JSON-RPC request whose method is `initialize`;
/// every other message kind or method yields `false`.
fn remote_control_message_starts_connection(message: &JSONRPCMessage) -> bool {
    let JSONRPCMessage::Request(JSONRPCRequest { method, .. }) = message else {
        return false;
    };
    method == "initialize"
}
/// A client is alive while the time since its last recorded activity, measured
/// at `now`, is strictly below `REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT`.
fn remote_control_client_is_alive(client: &RemoteControlClientState, now: Instant) -> bool {
    let idle_for = now.duration_since(client.last_activity_at);
    idle_for < REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT
}
/// Closes every tracked client that is no longer alive according to
/// `remote_control_client_is_alive`.
///
/// Returns `false` as soon as closing a client reports that the transport
/// channel is gone (remaining expired clients are left untouched in that
/// case), and `true` when every expired client was closed.
async fn close_expired_remote_control_clients(
    transport_event_tx: &mpsc::Sender<TransportEvent>,
    clients: &mut HashMap<ClientId, RemoteControlClientState>,
) -> bool {
    let now = Instant::now();
    // Closing a client mutates `clients`, so the expired ids are collected
    // first. Filtering before cloning means only expired ids are cloned;
    // `then_some(id.clone())` would clone the id of every client on each sweep.
    let expired_client_ids: Vec<ClientId> = clients
        .iter()
        .filter(|(_, client)| !remote_control_client_is_alive(client, now))
        .map(|(client_id, _)| client_id.clone())
        .collect();
    for client_id in &expired_client_ids {
        if !close_remote_control_client(transport_event_tx, clients, client_id).await {
            return false;
        }
    }
    true
}
async fn close_remote_control_client(
transport_event_tx: &mpsc::Sender<TransportEvent>,
clients: &mut HashMap<ClientId, RemoteControlClientState>,
client_id: &ClientId,
) -> bool {
let Some(client) = clients.remove(client_id) else {
return true;
};
client.disconnect_token.cancel();
transport_event_tx
.send(TransportEvent::ConnectionClosed {
connection_id: client.connection_id,
transport_kind: AppServerTransport::RemoteControlled,
})
.await
.is_ok()
}
async fn run_remote_control_client_outbound(
client_id: ClientId,
mut writer_rx: mpsc::Receiver<OutgoingMessage>,
server_event_tx: mpsc::Sender<ServerEvent>,
writer_exited_tx: mpsc::Sender<ClientId>,
disconnect_token: CancellationToken,
) {
loop {
tokio::select! {
_ = disconnect_token.cancelled() => {
break;
}
outgoing_message = writer_rx.recv() => {
let Some(outgoing_message) = outgoing_message else {
break;
};
if server_event_tx
.send(ServerEvent::ServerMessage {
client_id: client_id.clone(),
message: Box::new(outgoing_message),
})
.await
.is_err()
{
break;
}
}
}
}
let _ = writer_exited_tx.send(client_id).await;
}
#[cfg(test)]
#[path = "client_manager_tests.rs"]
mod client_manager_tests;

View File

@@ -0,0 +1,321 @@
use super::super::ClientActivityState;
use super::super::ClientEvent;
use super::super::ClientId;
use super::super::start_remote_control;
use super::super::test_support::accept_http_request;
use super::super::test_support::accept_remote_control_backend_connection;
use super::super::test_support::read_server_event;
use super::super::test_support::remote_control_auth_manager;
use super::super::test_support::respond_with_json;
use super::super::test_support::send_client_event;
use super::*;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::OutgoingNotification;
use crate::transport::AppServerTransport;
use crate::transport::CHANNEL_CAPACITY;
use crate::transport::ConnectionIdAllocator;
use crate::transport::TransportEvent;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::RequestId;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::HashMap;
use tempfile::TempDir;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio::time::timeout;
use tokio_util::sync::CancellationToken;
/// The idle sweep must close the client whose last activity is a full
/// `REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT` old and leave the fresh client alone.
#[tokio::test]
async fn close_expired_remote_control_clients_closes_only_stale_connections() {
    let (transport_event_tx, mut transport_event_rx) =
        mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
    let now = tokio::time::Instant::now();

    let stale_client_id = ClientId("stale-client".to_string());
    let stale_disconnect_token = CancellationToken::new();
    let fresh_client_id = ClientId("fresh-client".to_string());
    let fresh_disconnect_token = CancellationToken::new();

    let mut clients = HashMap::new();
    clients.insert(
        stale_client_id.clone(),
        RemoteControlClientState {
            connection_id: ConnectionId(11),
            disconnect_token: stale_disconnect_token.clone(),
            last_activity_at: now - REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT,
        },
    );
    clients.insert(
        fresh_client_id.clone(),
        RemoteControlClientState {
            connection_id: ConnectionId(12),
            disconnect_token: fresh_disconnect_token.clone(),
            last_activity_at: now,
        },
    );

    let sweep_succeeded =
        close_expired_remote_control_clients(&transport_event_tx, &mut clients).await;
    assert!(sweep_succeeded);

    // Only the stale client is cancelled and forgotten.
    assert!(stale_disconnect_token.is_cancelled());
    assert!(!fresh_disconnect_token.is_cancelled());
    assert!(!clients.contains_key(&stale_client_id));
    assert!(clients.contains_key(&fresh_client_id));

    // Exactly one close event is emitted, and it names the stale connection.
    let close_event = timeout(Duration::from_secs(5), transport_event_rx.recv())
        .await
        .expect("stale client close should arrive in time")
        .expect("stale client close should exist");
    match close_event {
        TransportEvent::ConnectionClosed {
            connection_id,
            transport_kind,
        } => {
            assert_eq!(connection_id, ConnectionId(11));
            assert_eq!(transport_kind, AppServerTransport::RemoteControlled);
        }
        other => panic!("expected stale client close event, got {other:?}"),
    }

    let further_event = timeout(Duration::from_millis(100), transport_event_rx.recv()).await;
    assert!(
        further_event.is_err(),
        "fresh clients should remain connected during stale sweep"
    );
}
/// End-to-end script for the remote-control transport against a local TCP
/// listener that plays the control server.
///
/// Steps, in order: enrollment, ping from an unknown client, a non-`initialize`
/// message that must be ignored, `initialize` opening a connection, a follow-up
/// message routed on that connection, ping from the now-active client, a
/// server-to-client message, client close, and a final ping that is unknown
/// again.
#[tokio::test]
async fn remote_control_transport_manages_virtual_clients_and_routes_messages() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = format!(
"http://{}/api/codex",
listener
.local_addr()
.expect("listener should have a local addr")
);
let codex_home = TempDir::new().expect("temp dir should create");
let (transport_event_tx, mut transport_event_rx) =
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
let shutdown_token = CancellationToken::new();
let remote_handle = start_remote_control(
remote_control_url,
codex_home.path().to_path_buf(),
remote_control_auth_manager(),
transport_event_tx,
shutdown_token.clone(),
ConnectionIdAllocator::default(),
)
.await
.expect("remote control should start");
// The first request seen by the listener is the enrollment POST; answer it
// with a server id.
let enroll_request = accept_http_request(&listener).await;
assert_eq!(
enroll_request.request_line,
"POST /api/codex/remote/control/server/enroll HTTP/1.1"
);
respond_with_json(
enroll_request.stream,
json!({ "server_id": "srv_e_client_manager" }),
)
.await;
// Next the transport connects its websocket to the listener.
let (_handshake_request, mut websocket) =
accept_remote_control_backend_connection(&listener, None).await;
let client_id = ClientId("client-1".to_string());
// Ping before any connection exists: the client is reported as unknown.
send_client_event(
&mut websocket,
ClientEvent::Ping {
client_id: client_id.clone(),
state: Some(ClientActivityState::Foreground),
},
)
.await;
assert_eq!(
read_server_event(&mut websocket).await,
json!({
"type": "pong",
"client_id": "client-1",
"status": "unknown",
})
);
// A message other than `initialize` from an unknown client must not
// produce any transport event.
send_client_event(
&mut websocket,
ClientEvent::ClientMessage {
client_id: client_id.clone(),
message: JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
}),
},
)
.await;
assert!(
timeout(Duration::from_millis(100), transport_event_rx.recv())
.await
.is_err(),
"non-initialize client messages should be ignored before connection creation"
);
// `initialize` opens the connection and is then delivered on it.
let initialize_message = JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(1),
method: "initialize".to_string(),
params: Some(json!({
"clientInfo": {
"name": "remote-test-client",
"version": "0.1.0"
}
})),
trace: None,
});
send_client_event(
&mut websocket,
ClientEvent::ClientMessage {
client_id: client_id.clone(),
message: initialize_message.clone(),
},
)
.await;
let (connection_id, writer) = match timeout(Duration::from_secs(5), transport_event_rx.recv())
.await
.expect("connection open should arrive in time")
.expect("connection open should exist")
{
TransportEvent::ConnectionOpened {
connection_id,
transport_kind,
writer,
..
} => {
assert_eq!(transport_kind, AppServerTransport::RemoteControlled);
(connection_id, writer)
}
other => panic!("expected connection open event, got {other:?}"),
};
match timeout(Duration::from_secs(5), transport_event_rx.recv())
.await
.expect("initialize message should arrive in time")
.expect("initialize message should exist")
{
TransportEvent::IncomingMessage {
connection_id: incoming_connection_id,
message,
} => {
assert_eq!(incoming_connection_id, connection_id);
assert_eq!(message, initialize_message);
}
other => panic!("expected initialize incoming message, got {other:?}"),
}
// Later messages from the same client reuse the same connection id.
let followup_message = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: None,
});
send_client_event(
&mut websocket,
ClientEvent::ClientMessage {
client_id: client_id.clone(),
message: followup_message.clone(),
},
)
.await;
match timeout(Duration::from_secs(5), transport_event_rx.recv())
.await
.expect("followup message should arrive in time")
.expect("followup message should exist")
{
TransportEvent::IncomingMessage {
connection_id: incoming_connection_id,
message,
} => {
assert_eq!(incoming_connection_id, connection_id);
assert_eq!(message, followup_message);
}
other => panic!("expected followup incoming message, got {other:?}"),
}
// With a connection open, the same ping now reports the client as active.
send_client_event(
&mut websocket,
ClientEvent::Ping {
client_id: client_id.clone(),
state: Some(ClientActivityState::Foreground),
},
)
.await;
assert_eq!(
read_server_event(&mut websocket).await,
json!({
"type": "pong",
"client_id": "client-1",
"status": "active",
})
);
// Messages written to the connection's writer come out on the websocket as
// `server_message` events addressed to the client.
writer
.send(OutgoingMessage::Notification(OutgoingNotification {
method: "codex/event/test".to_string(),
params: Some(json!({ "ok": true })),
}))
.await
.expect("remote writer should accept outgoing message");
assert_eq!(
read_server_event(&mut websocket).await,
json!({
"type": "server_message",
"client_id": "client-1",
"message": {
"method": "codex/event/test",
"params": {
"ok": true,
}
}
})
);
// Closing the client closes its connection.
send_client_event(
&mut websocket,
ClientEvent::ClientClosed {
client_id: client_id.clone(),
},
)
.await;
match timeout(Duration::from_secs(5), transport_event_rx.recv())
.await
.expect("connection close should arrive in time")
.expect("connection close should exist")
{
TransportEvent::ConnectionClosed {
connection_id: closed_connection_id,
transport_kind,
} => {
assert_eq!(closed_connection_id, connection_id);
assert_eq!(transport_kind, AppServerTransport::RemoteControlled);
}
other => panic!("expected connection close event, got {other:?}"),
}
// After the close the client is unknown again.
send_client_event(
&mut websocket,
ClientEvent::Ping {
client_id,
state: Some(ClientActivityState::Foreground),
},
)
.await;
assert_eq!(
read_server_event(&mut websocket).await,
json!({
"type": "pong",
"client_id": "client-1",
"status": "unknown",
})
);
shutdown_token.cancel();
let _ = remote_handle.await;
}

View File

@@ -0,0 +1,341 @@
use crate::remote_control::entrollment_manager::EnrollmentManager;
use crate::remote_control::entrollment_manager::preview_remote_control_response_body;
use crate::transport::colorize;
use codex_core::AuthManager;
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
use futures::SinkExt;
use futures::StreamExt;
use owo_colors::Style;
use std::io::ErrorKind;
use std::io::Result as IoResult;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::HeaderValue;
use tokio_tungstenite::tungstenite::{self};
use tokio_util::sync::CancellationToken;
use tracing::error;
use tracing::info;
use tracing::warn;
use super::ClientEvent;
use super::REMOTE_CONTROL_ACCOUNT_ID_HEADER;
use super::REMOTE_CONTROL_REQUEST_ID_HEADER;
use super::RemoteControlConnectionAuth;
use super::RemoteControlEnrollment;
use super::ServerEvent;
use super::load_remote_control_auth;
const REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF: Duration = Duration::from_secs(1);
const REMOTE_CONTROL_RECONNECT_MAX_BACKOFF: Duration = Duration::from_secs(30);
const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2";
const REMOTE_CONTROL_OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id";
const REMOTE_CONTROL_CF_RAY_HEADER: &str = "cf-ray";
/// An established remote-control websocket together with the request id
/// captured from the handshake response headers.
struct RemoteControlWebsocketConnection {
/// The connected websocket stream.
websocket_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
/// Request id read from the handshake response, if one was present.
// The dead_code allowance applies to non-test builds only.
#[cfg_attr(not(test), allow(dead_code))]
request_id: Option<String>,
}
/// Writes a short, colorized connection banner to stderr.
///
/// The title and control server URL are always printed. The attempt number
/// and the delay that preceded it are printed only for reconnects
/// (`reconnect_attempt > 0`), and the reason line only when
/// `reconnect_reason` is `Some`.
#[allow(clippy::print_stderr)]
fn print_remote_control_connection_banner(
    remote_control_target: &super::RemoteControlTarget,
    reconnect_attempt: u64,
    reconnect_backoff: Duration,
    reconnect_reason: Option<&str>,
) {
    eprintln!(
        "{}",
        colorize("app-server remote-control", Style::new().bold().yellow())
    );
    eprintln!(
        " {} {}",
        colorize("control server:", Style::new().dimmed()),
        colorize(
            remote_control_target.websocket_url.as_str(),
            Style::new().green()
        ),
    );

    let is_reconnect = reconnect_attempt > 0;
    if is_reconnect {
        eprintln!(
            " {} {reconnect_attempt}",
            colorize("attempt:", Style::new().dimmed())
        );
        eprintln!(
            " {} {reconnect_backoff:?}",
            colorize("after:", Style::new().dimmed())
        );
    }

    if let Some(reason) = reconnect_reason {
        eprintln!(" {} {reason}", colorize("reason:", Style::new().dimmed()));
    }
}
/// Drives the outbound remote-control websocket until shutdown.
///
/// The outer loop connects to the control server (sleeping for the current
/// backoff before every attempt except the first). The inner loop then
/// forwards `ClientEvent`s parsed from incoming text frames to
/// `client_event_tx` and writes `ServerEvent`s received on `server_event_rx`
/// to the socket as JSON text frames.
///
/// Returns when `shutdown_token` is cancelled, when `client_event_tx` can no
/// longer accept events, or when `server_event_rx` is closed.
#[allow(clippy::print_stderr)]
pub(super) async fn run(
auth_manager: std::sync::Arc<AuthManager>,
remote_control_target: super::RemoteControlTarget,
mut enrollment_manager: EnrollmentManager,
client_event_tx: mpsc::Sender<ClientEvent>,
mut server_event_rx: mpsc::Receiver<ServerEvent>,
shutdown_token: CancellationToken,
) {
let mut reconnect_backoff = REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF;
let mut reconnect_attempt = 0_u64;
let mut reconnect_reason = None::<String>;
// False only for the very first connect, which happens without a delay.
let mut wait_before_connect = false;
// Holds a server event that has been received but not yet written to the
// socket; it survives reconnects so a failed send can be retried.
let mut pending_server_event = None::<ServerEvent>;
loop {
// `connect_delay` is the delay actually waited before this attempt
// (zero for the first one); it is shown in the banner below.
let connect_delay = if wait_before_connect {
tokio::select! {
_ = shutdown_token.cancelled() => {
break;
}
_ = tokio::time::sleep(reconnect_backoff) => {
reconnect_attempt = reconnect_attempt.saturating_add(1);
}
}
reconnect_backoff
} else {
wait_before_connect = true;
Duration::ZERO
};
print_remote_control_connection_banner(
&remote_control_target,
reconnect_attempt,
connect_delay,
reconnect_reason.as_deref(),
);
// Connecting is raced against shutdown so a slow handshake cannot
// delay termination.
let websocket_connection = tokio::select! {
_ = shutdown_token.cancelled() => {
break;
}
connect_result = connect_remote_control_websocket(
auth_manager.as_ref(),
&remote_control_target,
&mut enrollment_manager,
) => {
match connect_result {
Ok(websocket_connection) => {
// A successful connect resets the backoff state.
reconnect_backoff = REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF;
reconnect_attempt = 0;
info!(
"connected to app-server remote control websocket: {}",
remote_control_target.websocket_url
);
websocket_connection
}
Err(err) => {
warn!("{err}");
reconnect_reason = Some(err.to_string());
// The backoff only grows once an attempt that was itself
// delayed has failed, so the first retry uses the initial
// backoff.
if connect_delay != Duration::ZERO {
reconnect_backoff = reconnect_backoff
.saturating_mul(2)
.min(REMOTE_CONTROL_RECONNECT_MAX_BACKOFF);
}
continue;
}
}
}
};
let (mut websocket_writer, mut websocket_reader) =
websocket_connection.websocket_stream.split();
// Per-connection loop: any `break` below falls back to the outer loop and
// triggers a reconnect.
loop {
// Flush the pending server event (if any) before waiting for more
// work.
if let Some(server_event) = pending_server_event.take() {
let payload = match serde_json::to_string(&server_event) {
Ok(payload) => payload,
Err(err) => {
// An event that cannot be serialized is dropped.
error!("failed to serialize remote-control server event: {err}");
continue;
}
};
if let Err(err) = websocket_writer
.send(TungsteniteMessage::Text(payload.into()))
.await
{
warn!("remote control websocket send failed: {err}");
reconnect_reason = Some(format!("send failed: {err}"));
// Keep the event so it is retried on the next connection.
pending_server_event = Some(server_event);
break;
}
continue;
}
tokio::select! {
_ = shutdown_token.cancelled() => {
return;
}
incoming_message = websocket_reader.next() => {
match incoming_message {
Some(Ok(TungsteniteMessage::Text(text))) => {
match serde_json::from_str::<ClientEvent>(&text) {
Ok(client_event) => {
// A send error means the receiving side is gone;
// stop the task.
if client_event_tx.send(client_event).await.is_err() {
return;
}
}
Err(_) => {
// Frames that do not parse are skipped; the
// connection stays open.
warn!("failed to deserialize remote-control client event");
}
}
}
Some(Ok(TungsteniteMessage::Ping(payload))) => {
if let Err(err) = websocket_writer
.send(TungsteniteMessage::Pong(payload))
.await
{
warn!("remote control websocket pong failed: {err}");
reconnect_reason = Some(format!("pong failed: {err}"));
break;
}
}
Some(Ok(TungsteniteMessage::Pong(_))) => {}
Some(Ok(TungsteniteMessage::Binary(_))) => {
warn!("dropping unsupported binary remote-control websocket message");
}
Some(Ok(TungsteniteMessage::Frame(_))) => {}
// A close frame and end-of-stream are handled the same way.
Some(Ok(TungsteniteMessage::Close(_))) | None => {
warn!("remote control websocket disconnected");
reconnect_reason = Some("server closed the websocket".to_string());
break;
}
Some(Err(err)) => {
warn!("remote control websocket receive error: {err}");
reconnect_reason = Some(format!("receive error: {err}"));
break;
}
}
}
server_event = server_event_rx.recv() => {
// `None` means every sender was dropped; nothing more to forward.
let Some(server_event) = server_event else {
return;
};
// Written at the top of the next loop iteration.
pending_server_event = Some(server_event);
}
}
}
}
}
/// Inserts header `name` with `value` into `headers`.
///
/// # Errors
///
/// Returns an `InvalidInput` error naming the header when `value` is not a
/// valid HTTP header value; `headers` is left unchanged in that case.
fn set_remote_control_header(
    headers: &mut tungstenite::http::HeaderMap,
    name: &'static str,
    value: &str,
) -> IoResult<()> {
    match HeaderValue::from_str(value) {
        Ok(parsed) => {
            headers.insert(name, parsed);
            Ok(())
        }
        Err(err) => Err(std::io::Error::new(
            ErrorKind::InvalidInput,
            format!("invalid remote control header `{name}`: {err}"),
        )),
    }
}
/// Builds the websocket handshake request for the remote control server.
///
/// The request targets `websocket_url` and carries the enrollment identity
/// (`x-codex-server-id`, `x-codex-name`), the protocol version, the bearer
/// token, and the account id header when `auth.account_id` is set.
///
/// The `authorization` value is flagged as sensitive so that `Debug` output of
/// the request or its header map does not print the bearer token.
///
/// # Errors
///
/// Returns an `InvalidInput` error when `websocket_url` cannot be turned into
/// a client request or when any header value is not a valid HTTP header value.
fn build_remote_control_websocket_request(
    auth: &RemoteControlConnectionAuth,
    websocket_url: &str,
    enrollment: &RemoteControlEnrollment,
) -> IoResult<tungstenite::http::Request<()>> {
    let mut request = websocket_url.into_client_request().map_err(|err| {
        std::io::Error::new(
            ErrorKind::InvalidInput,
            format!("invalid remote control websocket URL `{websocket_url}`: {err}"),
        )
    })?;
    let headers = request.headers_mut();
    set_remote_control_header(headers, "x-codex-server-id", &enrollment.server_id)?;
    set_remote_control_header(headers, "x-codex-name", &enrollment.server_name)?;
    set_remote_control_header(
        headers,
        "x-codex-protocol-version",
        REMOTE_CONTROL_PROTOCOL_VERSION,
    )?;
    set_remote_control_header(
        headers,
        "authorization",
        &format!("Bearer {}", auth.bearer_token),
    )?;
    // Keep the credential out of `Debug` output of the request/header map.
    if let Some(authorization) = headers.get_mut("authorization") {
        authorization.set_sensitive(true);
    }
    if let Some(account_id) = auth.account_id.as_deref() {
        set_remote_control_header(headers, REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id)?;
    }
    Ok(request)
}
async fn connect_remote_control_websocket(
auth_manager: &AuthManager,
remote_control_target: &super::RemoteControlTarget,
enrollment_manager: &mut EnrollmentManager,
) -> IoResult<RemoteControlWebsocketConnection> {
ensure_rustls_crypto_provider();
let websocket_url = remote_control_target.websocket_url.clone();
let auth = load_remote_control_auth(auth_manager).await?;
let enrollment = enrollment_manager.enroll(&auth).await?;
let request = build_remote_control_websocket_request(&auth, &websocket_url, &enrollment)?;
let (websocket_stream, response) = match connect_async(request).await {
Ok(connection) => connection,
Err(err) => {
return Err(std::io::Error::other(
format_remote_control_websocket_connect_error(&websocket_url, &err),
));
}
};
let request_id = remote_control_request_id(response.headers());
Ok(RemoteControlWebsocketConnection {
websocket_stream,
request_id,
})
}
/// Returns the value of `header_name` as an owned string.
///
/// Yields `None` when the header is missing or when its value cannot be
/// represented as a string by `HeaderValue::to_str`.
fn remote_control_header_value(
    headers: &tungstenite::http::HeaderMap,
    header_name: &str,
) -> Option<String> {
    let raw_value = headers.get(header_name)?;
    let text = raw_value.to_str().ok()?;
    Some(text.to_owned())
}
/// Extracts a request id from `headers`.
///
/// The primary request-id header is preferred; `x-oai-request-id` is used as a
/// fallback when the primary one yields no value.
fn remote_control_request_id(headers: &tungstenite::http::HeaderMap) -> Option<String> {
    [
        REMOTE_CONTROL_REQUEST_ID_HEADER,
        REMOTE_CONTROL_OAI_REQUEST_ID_HEADER,
    ]
    .into_iter()
    .find_map(|header_name| remote_control_header_value(headers, header_name))
}
fn format_remote_control_websocket_connect_error(
websocket_url: &str,
err: &tungstenite::Error,
) -> String {
let mut message =
format!("failed to connect app-server remote control websocket `{websocket_url}`: {err}");
let tungstenite::Error::Http(response) = err else {
return message;
};
if let Some(request_id) = remote_control_request_id(response.headers()) {
message.push_str(&format!(", request id: {request_id}"));
}
if let Some(cf_ray) =
remote_control_header_value(response.headers(), REMOTE_CONTROL_CF_RAY_HEADER)
{
message.push_str(&format!(", cf-ray: {cf_ray}"));
}
if let Some(body) = response.body().as_ref()
&& !body.is_empty()
{
let body_preview = preview_remote_control_response_body(body);
message.push_str(&format!(", body: {body_preview}"));
}
message
}
// Unit tests for this module live in `connection_manager_tests.rs` and are
// compiled for test builds only.
#[cfg(test)]
#[path = "connection_manager_tests.rs"]
mod connection_manager_tests;

View File

@@ -0,0 +1,285 @@
use super::super::ClientEvent;
use super::super::ClientId;
use super::super::REMOTE_CONTROL_REQUEST_ID_HEADER;
use super::super::entrollment_manager::EnrollmentManager;
use super::super::normalize_remote_control_url;
use super::super::start_remote_control;
use super::super::test_support::accept_http_request;
use super::super::test_support::accept_remote_control_backend_connection;
use super::super::test_support::read_server_event;
use super::super::test_support::remote_control_auth_manager;
use super::super::test_support::respond_with_json;
use super::super::test_support::respond_with_status_and_headers;
use super::super::test_support::send_client_event;
use super::*;
use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::OutgoingNotification;
use crate::transport::AppServerTransport;
use crate::transport::CHANNEL_CAPACITY;
use crate::transport::ConnectionIdAllocator;
use crate::transport::TransportEvent;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::RequestId;
use pretty_assertions::assert_eq;
use serde_json::json;
use tempfile::TempDir;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio::time::timeout;
use tokio_util::sync::CancellationToken;
// Verifies that a successful connect enrolls first, then opens the websocket,
// and exposes the request id sent in the handshake response.
#[tokio::test]
async fn connect_remote_control_websocket_captures_handshake_request_id() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = format!(
"http://{}/api/codex",
listener
.local_addr()
.expect("listener should have a local addr")
);
let remote_control_target =
normalize_remote_control_url(&remote_control_url).expect("target should parse");
// Mock backend: answers the enrollment POST, then accepts the websocket.
let accept_task = tokio::spawn(async move {
let enroll_request = accept_http_request(&listener).await;
assert_eq!(
enroll_request.request_line,
"POST /api/codex/remote/control/server/enroll HTTP/1.1"
);
assert_eq!(
enroll_request.headers.get("authorization"),
Some(&"Bearer Access Token".to_string())
);
assert_eq!(
enroll_request.headers.get("chatgpt-account-id"),
Some(&"account_id".to_string())
);
let enroll_body = serde_json::from_str::<serde_json::Value>(&enroll_request.body)
.expect("enroll body should deserialize");
assert_eq!(enroll_body["os"], json!(std::env::consts::OS));
assert_eq!(enroll_body["arch"], json!(std::env::consts::ARCH));
assert_eq!(
enroll_body["app_server_version"],
json!(env!("CARGO_PKG_VERSION"))
);
// The name is only checked for its type, not for a specific value.
assert!(enroll_body["name"].is_string());
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await;
accept_remote_control_backend_connection(&listener, Some("req-control-123")).await
});
// A fresh temp dir is used as the codex home for the enrollment manager.
let codex_home = TempDir::new().expect("temp dir should create");
let auth_manager = remote_control_auth_manager();
let mut enrollment_manager = EnrollmentManager::new(
remote_control_target.clone(),
codex_home.path().to_path_buf(),
);
let connection = connect_remote_control_websocket(
auth_manager.as_ref(),
&remote_control_target,
&mut enrollment_manager,
)
.await
.expect("websocket connection should succeed");
assert_eq!(connection.request_id.as_deref(), Some("req-control-123"));
let (_request, server_websocket) = accept_task.await.expect("accept task should succeed");
drop(server_websocket);
drop(connection.websocket_stream);
}
// Verifies that an HTTP error during the websocket handshake produces an error
// message containing the status, request id, cf-ray header, and response body.
#[tokio::test]
async fn connect_remote_control_websocket_includes_http_error_details() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = format!(
"http://{}/api/codex",
listener
.local_addr()
.expect("listener should have a local addr")
);
let remote_control_target =
normalize_remote_control_url(&remote_control_url).expect("target should parse");
let websocket_url = remote_control_target.websocket_url.clone();
// Built up front because `listener` is moved into the server task below.
let expected_error = format!(
"failed to connect app-server remote control websocket `{websocket_url}`: HTTP error: 503 Service Unavailable, request id: req-503, cf-ray: ray-503, body: upstream unavailable"
);
// Mock backend: enrollment succeeds, the websocket upgrade gets a 503.
let server_task = tokio::spawn(async move {
let enroll_request = accept_http_request(&listener).await;
assert_eq!(
enroll_request.request_line,
"POST /api/codex/remote/control/server/enroll HTTP/1.1"
);
assert_eq!(
enroll_request.headers.get("authorization"),
Some(&"Bearer Access Token".to_string())
);
assert_eq!(
enroll_request.headers.get("chatgpt-account-id"),
Some(&"account_id".to_string())
);
let enroll_body = serde_json::from_str::<serde_json::Value>(&enroll_request.body)
.expect("enroll body should deserialize");
assert_eq!(enroll_body["os"], json!(std::env::consts::OS));
assert_eq!(enroll_body["arch"], json!(std::env::consts::ARCH));
assert_eq!(
enroll_body["app_server_version"],
json!(env!("CARGO_PKG_VERSION"))
);
assert!(enroll_body["name"].is_string());
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await;
// Second request on the listener is the websocket handshake itself.
let request = accept_http_request(&listener).await;
assert_eq!(
request.request_line,
"GET /api/codex/remote/control/server HTTP/1.1"
);
respond_with_status_and_headers(
request.stream,
"503 Service Unavailable",
&[
(REMOTE_CONTROL_REQUEST_ID_HEADER, "req-503"),
(REMOTE_CONTROL_CF_RAY_HEADER, "ray-503"),
],
"upstream unavailable",
)
.await;
});
let codex_home = TempDir::new().expect("temp dir should create");
let auth_manager = remote_control_auth_manager();
let mut enrollment_manager = EnrollmentManager::new(
remote_control_target.clone(),
codex_home.path().to_path_buf(),
);
// `match` is used to extract the error from the connect result.
let err = match connect_remote_control_websocket(
auth_manager.as_ref(),
&remote_control_target,
&mut enrollment_manager,
)
.await
{
Ok(_) => panic!("http error response should fail the websocket connect"),
Err(err) => err,
};
// Awaiting the server task surfaces any assertion failure inside it.
server_task.await.expect("server task should succeed");
assert_eq!(err.to_string(), expected_error);
}
// End-to-end check of the reconnect path: a writer obtained for a client on
// the first websocket must still deliver messages after that websocket is
// closed and the transport has connected a second time.
#[tokio::test]
async fn remote_control_transport_reconnects_and_keeps_virtual_client_writer_alive() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let remote_control_url = format!(
"http://{}/api/codex",
listener
.local_addr()
.expect("listener should have a local addr")
);
let codex_home = TempDir::new().expect("temp dir should create");
let (transport_event_tx, mut transport_event_rx) =
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
let shutdown_token = CancellationToken::new();
let remote_handle = start_remote_control(
remote_control_url,
codex_home.path().to_path_buf(),
remote_control_auth_manager(),
transport_event_tx,
shutdown_token.clone(),
ConnectionIdAllocator::default(),
)
.await
.expect("remote control should start");
// The transport enrolls before opening its first websocket.
let enroll_request = accept_http_request(&listener).await;
assert_eq!(
enroll_request.request_line,
"POST /api/codex/remote/control/server/enroll HTTP/1.1"
);
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await;
let (_first_request, mut first_websocket) =
accept_remote_control_backend_connection(&listener, None).await;
let client_id = ClientId("client-2".to_string());
let initialize_message = JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(2),
method: "initialize".to_string(),
params: Some(json!({
"clientInfo": {
"name": "remote-test-client",
"version": "0.1.0"
}
})),
trace: None,
});
// The first message from a new client id is expected to surface as a
// ConnectionOpened event followed by the message itself.
send_client_event(
&mut first_websocket,
ClientEvent::ClientMessage {
client_id: client_id.clone(),
message: initialize_message.clone(),
},
)
.await;
// Keep the writer from ConnectionOpened; it is reused after the reconnect.
let writer = match timeout(Duration::from_secs(5), transport_event_rx.recv())
.await
.expect("connection open should arrive in time")
.expect("connection open should exist")
{
TransportEvent::ConnectionOpened {
writer,
transport_kind,
..
} => {
assert_eq!(transport_kind, AppServerTransport::RemoteControlled);
writer
}
other => panic!("expected connection open before reconnect, got {other:?}"),
};
match timeout(Duration::from_secs(5), transport_event_rx.recv())
.await
.expect("initialize message should arrive in time")
.expect("initialize message should exist")
{
TransportEvent::IncomingMessage { message, .. } => {
assert_eq!(message, initialize_message);
}
other => panic!("expected initialize incoming message, got {other:?}"),
}
// Close the first websocket from the server side to force a reconnect.
first_websocket
.close(None)
.await
.expect("first websocket should close");
drop(first_websocket);
// No second enrollment request is accepted here: the next connection on
// the listener is handled as the websocket handshake.
let (_second_request, mut second_websocket) =
accept_remote_control_backend_connection(&listener, None).await;
// Send through the writer obtained before the reconnect.
writer
.send(OutgoingMessage::Notification(OutgoingNotification {
method: "codex/event/reconnected".to_string(),
params: Some(json!({ "replayed": true })),
}))
.await
.expect("outgoing message should send after reconnect");
assert_eq!(
read_server_event(&mut second_websocket).await,
json!({
"type": "server_message",
"client_id": client_id.0,
"message": {
"method": "codex/event/reconnected",
"params": {
"replayed": true,
}
}
})
);
shutdown_token.cancel();
let _ = remote_handle.await;
}

View File

@@ -0,0 +1,312 @@
use super::super::RemoteControlConnectionAuth;
use super::super::RemoteControlEnrollment;
use super::super::RemoteControlTarget;
use super::super::load_remote_control_auth;
use super::super::normalize_remote_control_url;
use super::super::test_support::accept_http_request;
use super::super::test_support::respond_with_json;
use super::*;
use codex_core::CodexAuth;
use codex_core::test_support::auth_manager_from_auth;
use pretty_assertions::assert_eq;
use serde_json::json;
use tempfile::TempDir;
use tokio::net::TcpListener;
// API-key auth must be rejected with a message that points the user at
// ChatGPT authentication.
#[tokio::test]
async fn validate_remote_control_auth_rejects_api_key_auth() {
    let api_key_auth = CodexAuth::from_api_key("sk-test");
    let auth_manager = auth_manager_from_auth(api_key_auth);

    let result = load_remote_control_auth(auth_manager.as_ref()).await;
    let message = result
        .expect_err("API key auth should be rejected")
        .to_string();

    assert_eq!(
        message,
        "remote control requires ChatGPT authentication; API key auth is not supported"
    );
}
// Table-driven check of URL normalization: every supported input must map to
// the expected websocket/enroll URL pair, and every unsupported scheme must be
// rejected with the documented message.
#[test]
fn normalize_remote_control_url_handles_supported_and_unsupported_inputs() {
    // (input, expectation message, expected websocket URL, expected enroll URL)
    let supported = [
        (
            "http://example.com/backend-api/wham",
            "valid http prefix",
            "ws://example.com/backend-api/wham/remote/control/server",
            "http://example.com/backend-api/wham/remote/control/server/enroll",
        ),
        (
            "https://example.com/backend-api/wham/remote/control/server",
            "valid https full path",
            "wss://example.com/backend-api/wham/remote/control/server",
            "https://example.com/backend-api/wham/remote/control/server/enroll",
        ),
        (
            "http://example.com/legacy/server",
            "valid legacy http url",
            "ws://example.com/legacy/server",
            "http://example.com/legacy/server/enroll",
        ),
        (
            "https://chatgpt.com/backend-api",
            "chatgpt backend-api base should target the public wham path",
            "wss://chatgpt.com/backend-api/wham/remote/control/server",
            "https://chatgpt.com/backend-api/wham/remote/control/server/enroll",
        ),
        (
            "https://chat.openai.com",
            "chat.openai.com root should normalize",
            "wss://chat.openai.com/backend-api/wham/remote/control/server",
            "https://chat.openai.com/backend-api/wham/remote/control/server/enroll",
        ),
        (
            "https://chatgpt.com/api/codex/remote/control/server",
            "internal chatgpt remote-control path should rewrite to the public wham path",
            "wss://chatgpt.com/backend-api/wham/remote/control/server",
            "https://chatgpt.com/backend-api/wham/remote/control/server/enroll",
        ),
        (
            "https://chatgpt.com/api/codex",
            "explicit chatgpt api/codex base should rewrite to the public wham path",
            "wss://chatgpt.com/backend-api/wham/remote/control/server",
            "https://chatgpt.com/backend-api/wham/remote/control/server/enroll",
        ),
    ];
    for (input, expectation, websocket_url, enroll_url) in supported {
        assert_eq!(
            normalize_remote_control_url(input).expect(expectation),
            RemoteControlTarget {
                websocket_url: websocket_url.to_string(),
                enroll_url: enroll_url.to_string(),
            }
        );
    }

    // (input, expectation message)
    let rejected = [
        ("ftp://example.com/control", "unsupported scheme should fail"),
        ("ws://example.com/control", "ws url should fail"),
    ];
    for (input, expectation) in rejected {
        let err = normalize_remote_control_url(input).expect_err(expectation);
        assert_eq!(
            err.to_string(),
            format!("invalid remote control URL `{input}`; expected http:// or https://")
        );
    }
}
// Persisted enrollments are keyed by (target, account): writing one entry must
// not disturb the others, and clearing one entry must remove exactly that
// entry while leaving the rest loadable.
#[tokio::test]
async fn persisted_remote_control_enrollment_is_scoped_and_selectively_cleared() {
    let codex_home = TempDir::new().expect("temp dir should create");
    let state_path = remote_control_state_path(codex_home.path());
    let first_target = normalize_remote_control_url("http://example.com/remote/control")
        .expect("first target should parse");
    let second_target = normalize_remote_control_url("http://example.com/other/control")
        .expect("second target should parse");
    let first_enrollment = RemoteControlEnrollment {
        server_id: "srv_e_first".to_string(),
        server_name: REMOTE_CONTROL_SERVER_NAME.to_string(),
    };
    let same_target_other_account_enrollment = RemoteControlEnrollment {
        server_id: "srv_e_first_account_b".to_string(),
        server_name: REMOTE_CONTROL_SERVER_NAME.to_string(),
    };
    let second_enrollment = RemoteControlEnrollment {
        server_id: "srv_e_second".to_string(),
        server_name: REMOTE_CONTROL_SERVER_NAME.to_string(),
    };

    // Persist three entries: two targets for account-a, plus a second account
    // on the first target.
    update_persisted_remote_control_enrollment(
        state_path.as_path(),
        &first_target,
        Some("account-a"),
        Some(&first_enrollment),
    )
    .await
    .expect("first enrollment should persist");
    update_persisted_remote_control_enrollment(
        state_path.as_path(),
        &second_target,
        Some("account-a"),
        Some(&second_enrollment),
    )
    .await
    .expect("second enrollment should persist");
    update_persisted_remote_control_enrollment(
        state_path.as_path(),
        &first_target,
        Some("account-b"),
        Some(&same_target_other_account_enrollment),
    )
    .await
    .expect("other-account enrollment should persist");

    // Each (target, account) pair loads its own entry.
    assert_eq!(
        load_persisted_remote_control_enrollment(
            state_path.as_path(),
            &first_target,
            Some("account-a")
        )
        .await,
        Some(first_enrollment)
    );
    assert_eq!(
        load_persisted_remote_control_enrollment(
            state_path.as_path(),
            &first_target,
            Some("account-b")
        )
        .await,
        Some(same_target_other_account_enrollment.clone())
    );
    assert_eq!(
        load_persisted_remote_control_enrollment(
            state_path.as_path(),
            &second_target,
            Some("account-a")
        )
        .await,
        Some(second_enrollment.clone())
    );

    // Clear only the (first target, account-a) entry.
    update_persisted_remote_control_enrollment(
        state_path.as_path(),
        &first_target,
        Some("account-a"),
        None,
    )
    .await
    .expect("matching enrollment should clear");

    // The cleared entry must really be gone; without this check a no-op clear
    // would still pass the test.
    assert_eq!(
        load_persisted_remote_control_enrollment(
            state_path.as_path(),
            &first_target,
            Some("account-a")
        )
        .await,
        None
    );
    // The other two entries are untouched.
    assert_eq!(
        load_persisted_remote_control_enrollment(
            state_path.as_path(),
            &first_target,
            Some("account-b")
        )
        .await,
        Some(same_target_other_account_enrollment)
    );
    assert_eq!(
        load_persisted_remote_control_enrollment(
            state_path.as_path(),
            &second_target,
            Some("account-a")
        )
        .await,
        Some(second_enrollment)
    );
}
// The manager's in-memory cache must not leak across accounts: after serving
// account-a, a request for account-b has to return account-b's enrollment.
#[tokio::test]
async fn enrollment_manager_cache_is_scoped_to_the_current_account() {
    let enrollment_for = |server_id: &str| RemoteControlEnrollment {
        server_id: server_id.to_string(),
        server_name: REMOTE_CONTROL_SERVER_NAME.to_string(),
    };
    let auth_for = |account_id: &str| RemoteControlConnectionAuth {
        bearer_token: "Access Token".to_string(),
        account_id: Some(account_id.to_string()),
    };

    let codex_home = TempDir::new().expect("temp dir should create");
    let state_path = remote_control_state_path(codex_home.path());
    let remote_control_target = normalize_remote_control_url("http://example.com/remote/control")
        .expect("target should parse");
    let account_a_enrollment = enrollment_for("srv_e_account_a");
    let account_b_enrollment = enrollment_for("srv_e_account_b");

    // Seed the state file with one enrollment per account for the same target.
    update_persisted_remote_control_enrollment(
        state_path.as_path(),
        &remote_control_target,
        Some("account-a"),
        Some(&account_a_enrollment),
    )
    .await
    .expect("account-a enrollment should persist");
    update_persisted_remote_control_enrollment(
        state_path.as_path(),
        &remote_control_target,
        Some("account-b"),
        Some(&account_b_enrollment),
    )
    .await
    .expect("account-b enrollment should persist");

    let mut enrollment_manager =
        EnrollmentManager::new(remote_control_target, codex_home.path().to_path_buf());

    let served_to_account_a = enrollment_manager
        .enroll(&auth_for("account-a"))
        .await
        .expect("account-a enrollment should load from the cache or state file");
    assert_eq!(served_to_account_a, account_a_enrollment);

    let served_to_account_b = enrollment_manager
        .enroll(&auth_for("account-b"))
        .await
        .expect("account-b enrollment should replace the cached account-a enrollment");
    assert_eq!(served_to_account_b, account_b_enrollment);
}
// A 200 response whose JSON lacks `server_id` must produce an error that
// includes the URL, the HTTP status, the raw body, and the decode error.
#[tokio::test]
async fn enroll_remote_control_server_parse_failure_includes_response_body() {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("listener should bind");
    let base_url = format!(
        "http://{}/api/codex",
        listener
            .local_addr()
            .expect("listener should have a local addr")
    );
    let remote_control_target =
        normalize_remote_control_url(&base_url).expect("target should parse");
    let enroll_url = remote_control_target.enroll_url.clone();

    // Valid JSON, but without the `server_id` field the client expects.
    let malformed_response = json!({
        "error": "not enrolled",
    });
    let expected_body = malformed_response.to_string();
    let server_task = tokio::spawn(async move {
        let enroll_request = accept_http_request(&listener).await;
        respond_with_json(enroll_request.stream, malformed_response).await;
    });

    let auth = RemoteControlConnectionAuth {
        bearer_token: "Access Token".to_string(),
        account_id: Some("account_id".to_string()),
    };
    let err = enroll_remote_control_server(&remote_control_target, &auth)
        .await
        .expect_err("invalid response should fail to parse");
    server_task.await.expect("server task should succeed");

    // The decode error position is the end of the body, hence its length.
    let expected_message = format!(
        "failed to parse remote control enrollment response from `{enroll_url}`: HTTP 200 OK, body: {expected_body}, decode error: missing field `server_id` at line 1 column {}",
        expected_body.len()
    );
    assert_eq!(err.to_string(), expected_message);
}

View File

@@ -0,0 +1,325 @@
use codex_core::default_client::build_reqwest_client;
use serde::Deserialize;
use serde::Serialize;
use std::io::ErrorKind;
use std::io::Result as IoResult;
use std::path::Path;
use std::path::PathBuf;
use tracing::warn;
use super::REMOTE_CONTROL_ACCOUNT_ID_HEADER;
use super::REMOTE_CONTROL_STATE_FILE;
use super::RemoteControlConnectionAuth;
use super::RemoteControlEnrollment;
use super::RemoteControlTarget;
// Timeout applied to the enrollment HTTP request.
const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
// Longest response body (in bytes, after trimming) that
// `preview_remote_control_response_body` returns without truncating.
const REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES: usize = 4096;
// Server name used when the host name is empty.
const REMOTE_CONTROL_SERVER_NAME: &str = "codex-app-server";
/// In-memory enrollment together with the account id it was obtained for.
struct CachedRemoteControlEnrollment {
/// Account the cached enrollment belongs to; compared against the caller's
/// account id before the cache is reused.
account_id: Option<String>,
/// The cached enrollment itself.
enrollment: RemoteControlEnrollment,
}
/// Resolves the enrollment for one remote control target, using an in-memory
/// cache, then the persisted state file, then a fresh enrollment request.
pub(super) struct EnrollmentManager {
/// Target whose enrollment this manager resolves.
remote_control_target: RemoteControlTarget,
/// Location of the persisted state file inside the codex home directory.
remote_control_state_path: PathBuf,
/// Most recently resolved enrollment, if any.
cached_enrollment: Option<CachedRemoteControlEnrollment>,
}
/// On-disk (TOML) representation of the remote control state file.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
struct RemoteControlStateToml {
/// All persisted enrollments; a missing key deserializes to an empty list.
#[serde(default)]
enrollments: Vec<PersistedRemoteControlEnrollment>,
}
/// One persisted enrollment, identified by its websocket URL and account id.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
struct PersistedRemoteControlEnrollment {
/// Websocket URL of the target the enrollment was created for.
websocket_url: String,
/// Account the enrollment belongs to, if an account id was available.
account_id: Option<String>,
/// Server id of the enrollment.
server_id: String,
/// Server name of the enrollment.
server_name: String,
}
/// JSON body of the enrollment request.
#[derive(Debug, Serialize)]
struct EnrollRemoteServerRequest<'a> {
/// Name under which this server enrolls.
name: &'a str,
/// Operating system identifier (filled from `std::env::consts::OS`).
os: &'a str,
/// CPU architecture identifier (filled from `std::env::consts::ARCH`).
arch: &'a str,
/// Crate version of the app server (filled from `CARGO_PKG_VERSION`).
app_server_version: &'a str,
}
/// JSON body expected in the enrollment response.
#[derive(Debug, Deserialize)]
struct EnrollRemoteServerResponse {
/// Server id from the response; required, so a body without it fails to decode.
server_id: String,
}
impl EnrollmentManager {
    /// Creates a manager for `remote_control_target` whose state file lives
    /// under `codex_home`. Nothing is loaded or contacted until `enroll`.
    pub(super) fn new(remote_control_target: RemoteControlTarget, codex_home: PathBuf) -> Self {
        let remote_control_state_path = remote_control_state_path(codex_home.as_path());
        Self {
            remote_control_target,
            remote_control_state_path,
            cached_enrollment: None,
        }
    }

    /// Returns the enrollment to use for `auth`.
    ///
    /// Lookup order: the in-memory cache (only when it belongs to the same
    /// account), the persisted state file, and finally a new enrollment
    /// request. A newly created enrollment is persisted on a best-effort basis:
    /// a persistence failure is logged and does not fail the call.
    ///
    /// # Errors
    ///
    /// Returns the error from the enrollment request when no cached or
    /// persisted enrollment exists and enrolling fails.
    pub(super) async fn enroll(
        &mut self,
        auth: &RemoteControlConnectionAuth,
    ) -> IoResult<RemoteControlEnrollment> {
        let account_id = auth.account_id.as_deref();

        // A cache entry recorded for a different account must not be reused.
        let cached_account_id = self
            .cached_enrollment
            .as_ref()
            .and_then(|cached| cached.account_id.as_deref());
        if cached_account_id != account_id {
            self.cached_enrollment = None;
        }
        if let Some(cached) = self.cached_enrollment.as_ref() {
            return Ok(cached.enrollment.clone());
        }

        let persisted = load_persisted_remote_control_enrollment(
            self.remote_control_state_path.as_path(),
            &self.remote_control_target,
            account_id,
        )
        .await;
        let enrollment = match persisted {
            Some(enrollment) => enrollment,
            None => {
                let new_enrollment =
                    enroll_remote_control_server(&self.remote_control_target, auth).await?;
                let persist_result = update_persisted_remote_control_enrollment(
                    self.remote_control_state_path.as_path(),
                    &self.remote_control_target,
                    account_id,
                    Some(&new_enrollment),
                )
                .await;
                if let Err(err) = persist_result {
                    warn!(
                        "failed to persist remote control enrollment in `{}`: {err}",
                        self.remote_control_state_path.display()
                    );
                }
                new_enrollment
            }
        };

        self.cached_enrollment = Some(CachedRemoteControlEnrollment {
            account_id: auth.account_id.clone(),
            enrollment: enrollment.clone(),
        });
        Ok(enrollment)
    }
}
/// Returns the name to enroll this server under: the trimmed host name, or
/// `REMOTE_CONTROL_SERVER_NAME` when the host name is empty after trimming.
fn remote_control_server_name() -> String {
    let raw_host_name = gethostname::gethostname();
    let lossy_host_name = raw_host_name.to_string_lossy();
    match lossy_host_name.trim() {
        "" => REMOTE_CONTROL_SERVER_NAME.to_string(),
        trimmed => trimmed.to_owned(),
    }
}
/// Reports whether `entry` was persisted for exactly this target and account.
///
/// Both the websocket URL and the account id (including "no account") must
/// match.
fn matches_persisted_remote_control_enrollment(
    entry: &PersistedRemoteControlEnrollment,
    remote_control_target: &RemoteControlTarget,
    account_id: Option<&str>,
) -> bool {
    if entry.websocket_url != remote_control_target.websocket_url {
        return false;
    }
    entry.account_id.as_deref() == account_id
}
/// Reads and parses the state file at `state_path`.
///
/// A missing file is not an error and yields the default (empty) state.
///
/// # Errors
///
/// Returns the underlying I/O error for any other read failure, and an
/// `InvalidData` error naming the file when its contents are not valid TOML
/// for the state schema.
async fn load_remote_control_state(state_path: &Path) -> IoResult<RemoteControlStateToml> {
    match tokio::fs::read_to_string(state_path).await {
        Ok(contents) => toml::from_str(&contents).map_err(|err| {
            let message = format!(
                "failed to parse remote control state `{}`: {err}",
                state_path.display()
            );
            std::io::Error::new(ErrorKind::InvalidData, message)
        }),
        Err(err) if err.kind() == ErrorKind::NotFound => Ok(RemoteControlStateToml::default()),
        Err(err) => Err(err),
    }
}
/// Returns the path of the remote control state file inside `codex_home`.
fn remote_control_state_path(codex_home: &Path) -> PathBuf {
    let mut state_path = codex_home.to_path_buf();
    state_path.push(REMOTE_CONTROL_STATE_FILE);
    state_path
}
/// Serializes `state` as TOML and writes it to `state_path`, creating the
/// parent directory first when the path has one.
///
/// # Errors
///
/// Returns an error when the directory cannot be created, the state cannot be
/// serialized, or the file cannot be written.
async fn write_remote_control_state(
    state_path: &Path,
    state: &RemoteControlStateToml,
) -> IoResult<()> {
    match state_path.parent() {
        Some(parent_dir) => tokio::fs::create_dir_all(parent_dir).await?,
        None => {}
    }
    let serialized = match toml::to_string(state) {
        Ok(serialized) => serialized,
        Err(err) => return Err(std::io::Error::other(err)),
    };
    tokio::fs::write(state_path, serialized).await
}
pub(super) async fn load_persisted_remote_control_enrollment(
state_path: &Path,
remote_control_target: &RemoteControlTarget,
account_id: Option<&str>,
) -> Option<RemoteControlEnrollment> {
let state = match load_remote_control_state(state_path).await {
Ok(state) => state,
Err(err) => {
warn!("{err}");
return None;
}
};
state
.enrollments
.into_iter()
.find(|entry| {
matches_persisted_remote_control_enrollment(entry, remote_control_target, account_id)
})
.map(|entry| RemoteControlEnrollment {
server_id: entry.server_id,
server_name: entry.server_name,
})
}
/// Replaces or clears the persisted enrollment for `remote_control_target` and
/// `account_id`.
///
/// Any existing entry for that target/account pair is removed. When
/// `enrollment` is `Some`, a new entry is appended; when it is `None`, the pair
/// is simply cleared. If no entries remain afterwards the state file is
/// deleted (a file that is already missing is fine); otherwise the file is
/// rewritten.
///
/// A state file that fails to parse is logged and treated as empty, so it is
/// overwritten by this call.
///
/// # Errors
///
/// Returns I/O errors from reading, writing, or removing the state file, other
/// than the parse-failure and not-found cases described above.
pub(super) async fn update_persisted_remote_control_enrollment(
    state_path: &Path,
    remote_control_target: &RemoteControlTarget,
    account_id: Option<&str>,
    enrollment: Option<&RemoteControlEnrollment>,
) -> IoResult<()> {
    let mut state = match load_remote_control_state(state_path).await {
        Ok(state) => state,
        Err(err) => {
            if err.kind() != ErrorKind::InvalidData {
                return Err(err);
            }
            warn!("{err}");
            RemoteControlStateToml::default()
        }
    };

    // Drop any previous entry for this target/account before adding the new one.
    state.enrollments.retain(|entry| {
        !matches_persisted_remote_control_enrollment(entry, remote_control_target, account_id)
    });
    if let Some(new_enrollment) = enrollment {
        let persisted_entry = PersistedRemoteControlEnrollment {
            websocket_url: remote_control_target.websocket_url.clone(),
            account_id: account_id.map(str::to_owned),
            server_id: new_enrollment.server_id.clone(),
            server_name: new_enrollment.server_name.clone(),
        };
        state.enrollments.push(persisted_entry);
    }

    if !state.enrollments.is_empty() {
        return write_remote_control_state(state_path, &state).await;
    }
    match tokio::fs::remove_file(state_path).await {
        Err(err) if err.kind() != ErrorKind::NotFound => Err(err),
        _ => Ok(()),
    }
}
/// Produces a short, log-friendly rendering of an HTTP response body.
///
/// The bytes are decoded lossily as UTF-8 and trimmed. An empty result is
/// rendered as `<empty>`. Text longer than
/// `REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES` is cut at the nearest character
/// boundary at or below that limit and suffixed with `...`.
pub(super) fn preview_remote_control_response_body(body: &[u8]) -> String {
    let decoded = String::from_utf8_lossy(body);
    let text = decoded.trim();
    if text.is_empty() {
        return "<empty>".to_string();
    }
    if text.len() <= REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES {
        return text.to_string();
    }
    // Index 0 is always a char boundary, so a cut point is always found.
    let end = (0..=REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES)
        .rev()
        .find(|&index| text.is_char_boundary(index))
        .unwrap_or(0);
    let mut preview = String::from(&text[..end]);
    preview.push_str("...");
    preview
}
/// Enrolls this app-server with the remote control backend.
///
/// Sends a JSON `POST` to `remote_control_target.enroll_url` describing this
/// server (name, OS, architecture, crate version), authenticated with the
/// bearer token from `auth` and, when present, the account id header.
///
/// Returns the enrollment built from the `server_id` in the response and the
/// locally computed server name.
///
/// # Errors
///
/// Returns an `std::io::Error` when the request cannot be sent, the response
/// body cannot be read, the status is not a success, or the body does not
/// decode as an enrollment response. Status and decode errors include a
/// preview of the response body.
async fn enroll_remote_control_server(
remote_control_target: &RemoteControlTarget,
auth: &RemoteControlConnectionAuth,
) -> IoResult<RemoteControlEnrollment> {
let server_name = remote_control_server_name();
let request = EnrollRemoteServerRequest {
name: &server_name,
os: std::env::consts::OS,
arch: std::env::consts::ARCH,
app_server_version: env!("CARGO_PKG_VERSION"),
};
let client = build_reqwest_client();
let mut http_request = client
.post(&remote_control_target.enroll_url)
.timeout(REMOTE_CONTROL_ENROLL_TIMEOUT)
.bearer_auth(&auth.bearer_token)
.json(&request);
// The account id header is only attached when the auth carries one.
if let Some(account_id) = auth.account_id.as_deref() {
http_request = http_request.header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id);
}
let response = http_request.send().await.map_err(|err| {
std::io::Error::other(format!(
"failed to enroll remote control server at `{}`: {err}",
remote_control_target.enroll_url
))
})?;
// Capture the status before consuming the response to read its body.
let status = response.status();
let body = response.bytes().await.map_err(|err| {
std::io::Error::other(format!(
"failed to read remote control enrollment response from `{}`: {err}",
remote_control_target.enroll_url
))
})?;
// The preview is computed up front so both failure paths below can use it.
let body_preview = preview_remote_control_response_body(&body);
if !status.is_success() {
return Err(std::io::Error::other(format!(
"remote control server enrollment failed at `{}`: HTTP {status}, body: {body_preview}",
remote_control_target.enroll_url
)));
}
let enrollment = serde_json::from_slice::<EnrollRemoteServerResponse>(&body).map_err(|err| {
std::io::Error::other(format!(
"failed to parse remote control enrollment response from `{}`: HTTP {status}, body: {body_preview}, decode error: {err}",
remote_control_target.enroll_url
))
})?;
Ok(RemoteControlEnrollment {
server_id: enrollment.server_id,
server_name,
})
}
#[cfg(test)]
#[path = "enrollment_tests.rs"]
mod enrollment_tests;

View File

@@ -0,0 +1,251 @@
mod client_manager;
mod connection_manager;
mod entrollment_manager;
#[cfg(test)]
mod test_support;
use self::entrollment_manager::EnrollmentManager;
use crate::outgoing_message::OutgoingMessage;
use crate::transport::CHANNEL_CAPACITY;
use crate::transport::ConnectionIdAllocator;
use crate::transport::TransportEvent;
use codex_app_server_protocol::JSONRPCMessage;
use codex_core::AuthManager;
use serde::Deserialize;
use serde::Serialize;
use std::io::ErrorKind;
use std::io::Result as IoResult;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
/// Name of the HTTP header that carries the ChatGPT account id.
const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
/// Name of the HTTP header that carries a request id.
const REMOTE_CONTROL_REQUEST_ID_HEADER: &str = "x-request-id";
/// File name of the persisted remote control state.
// NOTE(review): presumably resolved relative to the Codex home directory — confirm against the enrollment module.
const REMOTE_CONTROL_STATE_FILE: &str = "remote_control.toml";
/// Pair of endpoint URLs derived from the configured remote control URL by
/// `normalize_remote_control_url`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RemoteControlTarget {
/// `ws://` or `wss://` URL of the remote control server endpoint.
websocket_url: String,
/// `http://` or `https://` URL of the matching `/enroll` endpoint.
enroll_url: String,
}
/// Identity of this app-server as known to the remote control backend.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RemoteControlEnrollment {
/// Server identifier.
// NOTE(review): presumably the id returned by the backend's enroll endpoint — confirm.
server_id: String,
/// Name associated with this server.
server_name: String,
}
/// Credentials attached to remote control requests.
///
/// `Debug` is implemented by hand (instead of derived) so the bearer token is
/// never written to logs, panic messages, or assertion output.
#[derive(Clone, PartialEq, Eq)]
struct RemoteControlConnectionAuth {
    /// Bearer token used to authenticate requests; treated as a secret.
    bearer_token: String,
    /// Account id accompanying the token, when one is available.
    account_id: Option<String>,
}

impl std::fmt::Debug for RemoteControlConnectionAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RemoteControlConnectionAuth")
            // Redact the secret; keep the field visible so the shape is clear.
            .field("bearer_token", &"<redacted>")
            .field("account_id", &self.account_id)
            .finish()
    }
}
/// Identifier of a remote client.
///
/// Serialized transparently, i.e. as a bare JSON string rather than a
/// wrapping object.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
struct ClientId(String);
/// Activity state a client may report alongside a ping.
///
/// Serialized in `snake_case` (`"foreground"` / `"background"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ClientActivityState {
Foreground,
Background,
}
/// Event concerning a single remote client, exchanged as JSON.
///
/// The variant is selected by a `type` field in `snake_case`
/// (`"client_message"`, `"ping"`, `"client_closed"`).
// NOTE(review): appears to flow from the remote control backend to this server — confirm against connection_manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientEvent {
/// A JSON-RPC message sent by the client identified by `client_id`.
ClientMessage {
client_id: ClientId,
message: JSONRPCMessage,
},
/// A ping for `client_id`, optionally carrying the client's activity state.
Ping {
client_id: ClientId,
// Omitted from the serialized form when absent.
#[serde(skip_serializing_if = "Option::is_none")]
state: Option<ClientActivityState>,
},
/// Notification that the client identified by `client_id` has closed.
ClientClosed {
client_id: ClientId,
},
}
/// Event addressed to a single remote client, serialized as JSON.
///
/// The variant is selected by a `type` field in `snake_case`
/// (`"server_message"`, `"pong"`). This type is serialize-only.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerEvent {
/// An outgoing app-server message for `client_id`.
ServerMessage {
client_id: ClientId,
// Boxed so the enum does not carry the full message inline.
message: Box<OutgoingMessage>,
},
/// Reply to a ping for `client_id`.
Pong {
client_id: ClientId,
status: PongStatus,
},
}
/// Status reported in a `Pong` event, serialized in `snake_case`.
// NOTE(review): presumably `Active` means the pinged client is known to this server and `Unknown` means it is not — confirm against client_manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum PongStatus {
Active,
Unknown,
}
/// Starts the remote control transport and returns the handle of its
/// supervising task.
///
/// The configured URL is normalized first; an invalid URL is reported
/// synchronously as an error and nothing is spawned. Otherwise a supervisor
/// task is spawned which runs two child tasks:
///
/// * `connection_manager::run` — owns the outbound websocket connection;
/// * `client_manager::run` — bridges client events to `transport_event_tx`.
///
/// The supervisor finishes when `shutdown_token` is cancelled or when either
/// child exits; in the latter case the other child is cancelled through a
/// child token. It then waits for the remaining child task(s) to finish.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` when `remote_control_url` is not an
/// `http://` or `https://` URL.
pub(crate) async fn start_remote_control(
    remote_control_url: String,
    codex_home: PathBuf,
    auth_manager: Arc<AuthManager>,
    transport_event_tx: mpsc::Sender<TransportEvent>,
    shutdown_token: CancellationToken,
    connection_id_allocator: ConnectionIdAllocator,
) -> IoResult<JoinHandle<()>> {
    let remote_control_target = normalize_remote_control_url(&remote_control_url)?;
    let enrollment_manager = EnrollmentManager::new(remote_control_target.clone(), codex_home);
    Ok(tokio::spawn(async move {
        // Cancelling the child token stops both workers without cancelling
        // the caller's token.
        let local_shutdown_token = shutdown_token.child_token();
        let (client_event_tx, client_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
        let (server_event_tx, server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
        let (writer_exited_tx, writer_exited_rx) = mpsc::channel(CHANNEL_CAPACITY);
        let mut websocket_task = tokio::spawn(connection_manager::run(
            auth_manager,
            remote_control_target,
            enrollment_manager,
            client_event_tx,
            server_event_rx,
            local_shutdown_token.clone(),
        ));
        let mut manager_task = tokio::spawn(client_manager::run(
            transport_event_tx,
            client_event_rx,
            server_event_tx,
            writer_exited_tx,
            writer_exited_rx,
            local_shutdown_token.clone(),
            connection_id_allocator,
        ));
        // Track which handle (if any) already resolved inside `select!`:
        // polling a `JoinHandle` again after it has completed panics, so only
        // the handles that are still pending may be awaited below.
        let (websocket_finished, manager_finished) = tokio::select! {
            _ = local_shutdown_token.cancelled() => (false, false),
            _ = &mut websocket_task => {
                local_shutdown_token.cancel();
                (true, false)
            }
            _ = &mut manager_task => {
                local_shutdown_token.cancel();
                (false, true)
            }
        };
        if !websocket_finished {
            let _ = websocket_task.await;
        }
        if !manager_finished {
            let _ = manager_task.await;
        }
    }))
}
/// Converts a configured remote control URL into its websocket and enroll
/// endpoints.
///
/// Trailing slashes are ignored. `http://` maps to `ws://` and `https://` to
/// `wss://`; any other scheme is rejected with `ErrorKind::InvalidInput`.
fn normalize_remote_control_url(remote_control_url: &str) -> IoResult<RemoteControlTarget> {
    let remote_control_url = remote_control_url.trim_end_matches('/');
    // (HTTP scheme, matching websocket scheme)
    let scheme_pairs = [("http://", "ws://"), ("https://", "wss://")];
    for (http_scheme, websocket_scheme) in scheme_pairs {
        if let Some(rest) = remote_control_url.strip_prefix(http_scheme) {
            return Ok(normalize_http_remote_control_url(
                rest,
                http_scheme,
                websocket_scheme,
            ));
        }
    }
    Err(std::io::Error::new(
        ErrorKind::InvalidInput,
        format!("invalid remote control URL `{remote_control_url}`; expected http:// or https://"),
    ))
}
/// Builds the websocket/enroll URL pair from a scheme-less URL remainder.
///
/// After ChatGPT-host normalization, the path is brought to a form ending in
/// `/server`:
///
/// * `…/server/enroll` → the trailing `/enroll` is dropped;
/// * `…/server` → kept as-is;
/// * anything else → `/remote/control/server` is appended.
///
/// The websocket URL is that path under `websocket_scheme`; the enroll URL is
/// the same path plus `/enroll` under `http_scheme`.
fn normalize_http_remote_control_url(
    rest: &str,
    http_scheme: &str,
    websocket_scheme: &str,
) -> RemoteControlTarget {
    let base = normalize_chatgpt_remote_control_base(rest);
    let without_enroll = base
        .strip_suffix("/enroll")
        .filter(|prefix| prefix.ends_with("/server"))
        .map(str::to_owned);
    let server_path = match without_enroll {
        Some(path) => path,
        None if base.ends_with("/server") => base,
        None => format!("{base}/remote/control/server"),
    };
    let websocket_url = format!("{websocket_scheme}{server_path}");
    let enroll_url = format!("{http_scheme}{server_path}/enroll");
    RemoteControlTarget {
        websocket_url,
        enroll_url,
    }
}
/// Rewrites ChatGPT-hosted base URLs so they point under `backend-api/wham`.
///
/// `rest` is a URL without its scheme. Trailing slashes are removed. Hosts
/// other than `chatgpt.com` / `chat.openai.com` are returned otherwise
/// unchanged. For those two hosts:
///
/// * no path → `{host}/backend-api/wham`;
/// * a path already at or under `backend-api/wham` → unchanged;
/// * a path at or under `api/codex` or `backend-api` → `{host}/backend-api/wham`,
///   preserving a `remote/control/server[/enroll]` suffix when that is the
///   entire remainder;
/// * a bare `remote/control/server[/enroll]` path → placed under
///   `{host}/backend-api/wham/`;
/// * any other path → unchanged.
fn normalize_chatgpt_remote_control_base(rest: &str) -> String {
    let trimmed = rest.trim_end_matches('/');
    let (host, path) = trimmed
        .split_once('/')
        .map_or((trimmed, None), |(host, path)| (host, Some(path)));
    if !matches!(host, "chatgpt.com" | "chat.openai.com") {
        return trimmed.to_string();
    }

    let is_server_path =
        |candidate: &str| matches!(candidate, "remote/control/server" | "remote/control/server/enroll");

    let path = match path {
        Some(path) => path,
        None => return format!("{host}/backend-api/wham"),
    };
    if path == "backend-api/wham" || path.starts_with("backend-api/wham/") {
        return trimmed.to_string();
    }
    for internal_prefix in ["api/codex", "backend-api"] {
        let Some(remainder) = path.strip_prefix(internal_prefix) else {
            continue;
        };
        if remainder.is_empty() {
            return format!("{host}/backend-api/wham");
        }
        // Only a `/`-separated remainder counts as being under the prefix.
        if let Some(suffix) = remainder.strip_prefix('/') {
            return if is_server_path(suffix) {
                format!("{host}/backend-api/wham/{suffix}")
            } else {
                format!("{host}/backend-api/wham")
            };
        }
    }
    if is_server_path(path) {
        return format!("{host}/backend-api/wham/{path}");
    }
    trimmed.to_string()
}
/// Resolves the credentials used for remote control connections.
///
/// If the manager has no auth loaded, it is reloaded once and queried again.
///
/// # Errors
///
/// Returns `ErrorKind::PermissionDenied` when no auth is available or when
/// the auth is not ChatGPT auth, and a generic I/O error when the token
/// cannot be obtained.
async fn load_remote_control_auth(
    auth_manager: &AuthManager,
) -> IoResult<RemoteControlConnectionAuth> {
    let mut current_auth = auth_manager.auth().await;
    if current_auth.is_none() {
        // Nothing cached: reload once before giving up.
        auth_manager.reload();
        current_auth = auth_manager.auth().await;
    }
    let Some(auth) = current_auth else {
        return Err(std::io::Error::new(
            ErrorKind::PermissionDenied,
            "remote control requires ChatGPT authentication",
        ));
    };
    if !auth.is_chatgpt_auth() {
        return Err(std::io::Error::new(
            ErrorKind::PermissionDenied,
            "remote control requires ChatGPT authentication; API key auth is not supported",
        ));
    }
    let bearer_token = auth.get_token().map_err(std::io::Error::other)?;
    let account_id = auth.get_account_id();
    Ok(RemoteControlConnectionAuth {
        bearer_token,
        account_id,
    })
}

View File

@@ -0,0 +1,218 @@
use super::ClientEvent;
use super::REMOTE_CONTROL_REQUEST_ID_HEADER;
use codex_core::AuthManager;
use codex_core::CodexAuth;
use codex_core::test_support::auth_manager_from_auth;
use futures::SinkExt;
use futures::StreamExt;
use std::collections::BTreeMap;
use std::sync::Arc;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::time::Duration;
use tokio::time::timeout;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::accept_hdr_async;
use tokio_tungstenite::tungstenite;
use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
use tokio_tungstenite::tungstenite::http::HeaderValue;
/// Test helper: builds an `AuthManager` backed by dummy ChatGPT auth.
pub(super) fn remote_control_auth_manager() -> Arc<AuthManager> {
auth_manager_from_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
}
/// An HTTP request read off a raw TCP connection by `accept_http_request`.
#[derive(Debug)]
pub(super) struct CapturedHttpRequest {
/// The underlying connection, kept so the test can write a response.
pub(super) stream: TcpStream,
/// First line of the request, without the trailing CRLF.
pub(super) request_line: String,
/// Request headers keyed by lower-cased name, with trimmed values.
pub(super) headers: BTreeMap<String, String>,
/// Request body decoded as UTF-8.
pub(super) body: String,
}
/// Details of a websocket upgrade request captured during the handshake.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) struct CapturedWebSocketRequest {
/// Path component of the request URI.
pub(super) path: String,
/// Request headers keyed by lower-cased name.
pub(super) headers: BTreeMap<String, String>,
}
/// Test helper: accepts one connection and parses a single HTTP/1.1 request.
///
/// Waits up to five seconds for a connection, then reads the request line,
/// the headers up to the blank line, and a body of `content-length` bytes
/// (zero when the header is missing or unparsable).
///
/// # Panics
///
/// Panics when no connection arrives in time, when reading fails, when a
/// header line has no colon, or when the body is not valid UTF-8.
pub(super) async fn accept_http_request(listener: &TcpListener) -> CapturedHttpRequest {
let (stream, _) = timeout(Duration::from_secs(5), listener.accept())
.await
.expect("HTTP request should arrive in time")
.expect("listener accept should succeed");
let mut reader = BufReader::new(stream);
let mut request_line = String::new();
reader
.read_line(&mut request_line)
.await
.expect("request line should read");
let request_line = request_line.trim_end_matches("\r\n").to_string();
let mut headers = BTreeMap::new();
loop {
let mut line = String::new();
reader
.read_line(&mut line)
.await
.expect("header line should read");
// A bare CRLF terminates the header section.
if line == "\r\n" {
break;
}
let line = line.trim_end_matches("\r\n");
let (name, value) = line.split_once(':').expect("header should contain colon");
headers.insert(name.to_ascii_lowercase(), value.trim().to_string());
}
let content_length = headers
.get("content-length")
.and_then(|value| value.parse::<usize>().ok())
.unwrap_or(0);
let mut body = vec![0; content_length];
reader
.read_exact(&mut body)
.await
.expect("request body should read");
CapturedHttpRequest {
// Hand back the raw stream so the caller can respond on it.
stream: reader.into_inner(),
request_line,
headers,
body: String::from_utf8(body).expect("body should be utf-8"),
}
}
/// Test helper: writes a `200 OK` JSON response carrying `body` to `stream`.
///
/// # Panics
///
/// Panics when the response cannot be written or flushed.
pub(super) async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) {
    let payload = body.to_string();
    let content_length = payload.len();
    let response = format!(
        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
        content_length,
        body = payload
    );
    let bytes = response.as_bytes();
    stream.write_all(bytes).await.expect("response should write");
    stream.flush().await.expect("response should flush");
}
/// Test helper: writes a plain-text HTTP response with a custom status line
/// and additional headers.
///
/// `status` is everything after `HTTP/1.1 ` on the status line. Each entry of
/// `headers` is emitted as `name: value` after the fixed headers.
///
/// # Panics
///
/// Panics when the response cannot be written or flushed.
pub(super) async fn respond_with_status_and_headers(
    mut stream: TcpStream,
    status: &str,
    headers: &[(&str, &str)],
    body: &str,
) {
    let mut extra_headers = String::new();
    for (name, value) in headers {
        extra_headers.push_str(&format!("{name}: {value}\r\n"));
    }
    let content_length = body.len();
    let response = format!(
        "HTTP/1.1 {status}\r\ncontent-type: text/plain\r\ncontent-length: {}\r\nconnection: close\r\n{extra_headers}\r\n{body}",
        content_length,
    );
    let bytes = response.as_bytes();
    stream.write_all(bytes).await.expect("response should write");
    stream.flush().await.expect("response should flush");
}
/// Test helper: accepts one connection and completes a websocket handshake,
/// acting as the remote control backend.
///
/// Returns the captured upgrade request (path and lower-cased headers)
/// together with the established websocket. When `request_id` is given it is
/// added to the handshake response under `REMOTE_CONTROL_REQUEST_ID_HEADER`.
///
/// # Panics
///
/// Panics when no connection arrives within five seconds, when the handshake
/// fails, when a request header is not valid UTF-8, or when `request_id` is
/// not a valid header value.
pub(super) async fn accept_remote_control_backend_connection(
listener: &TcpListener,
request_id: Option<&str>,
) -> (CapturedWebSocketRequest, WebSocketStream<TcpStream>) {
let (stream, _) = timeout(Duration::from_secs(5), listener.accept())
.await
.expect("websocket request should arrive in time")
.expect("listener accept should succeed");
// Shared slot the handshake callback fills in; read back after the handshake.
let captured_request = Arc::new(std::sync::Mutex::new(None::<CapturedWebSocketRequest>));
let captured_request_for_callback = captured_request.clone();
// Owned copy so the `move` callback does not borrow from the caller.
let request_id = request_id.map(str::to_owned);
let websocket = accept_hdr_async(
stream,
move |request: &tungstenite::handshake::server::Request,
mut response: tungstenite::handshake::server::Response| {
let headers = request
.headers()
.iter()
.map(|(name, value)| {
(
name.as_str().to_ascii_lowercase(),
value
.to_str()
.expect("header should be valid utf-8")
.to_string(),
)
})
.collect::<BTreeMap<_, _>>();
*captured_request_for_callback
.lock()
.expect("capture lock should acquire") = Some(CapturedWebSocketRequest {
path: request.uri().path().to_string(),
headers,
});
if let Some(request_id) = request_id.as_deref() {
response.headers_mut().insert(
REMOTE_CONTROL_REQUEST_ID_HEADER,
HeaderValue::from_str(request_id)
.expect("request id should be a valid header value"),
);
}
Ok(response)
},
)
.await
.expect("websocket handshake should succeed");
let captured_request = captured_request
.lock()
.expect("capture lock should acquire")
.clone()
.expect("websocket request should be captured");
(captured_request, websocket)
}
/// Test helper: serializes `client_event` to JSON and sends it as one
/// websocket text frame.
///
/// # Panics
///
/// Panics when serialization or sending fails.
pub(super) async fn send_client_event(
    websocket: &mut WebSocketStream<TcpStream>,
    client_event: ClientEvent,
) {
    let frame = TungsteniteMessage::Text(
        serde_json::to_string(&client_event)
            .expect("client event should serialize")
            .into(),
    );
    websocket.send(frame).await.expect("client event should send");
}
/// Test helper: reads frames until a text frame arrives and returns it parsed
/// as JSON.
///
/// Ping frames are answered with a pong; pong and raw frames are ignored.
/// Each frame is awaited for at most five seconds.
///
/// # Panics
///
/// Panics on timeout, when the websocket ends or yields an error, when the
/// text is not valid JSON, or when a close or binary frame is received.
pub(super) async fn read_server_event(
websocket: &mut WebSocketStream<TcpStream>,
) -> serde_json::Value {
loop {
let frame = timeout(Duration::from_secs(5), websocket.next())
.await
.expect("server event should arrive in time")
.expect("websocket should stay open")
.expect("websocket frame should be readable");
match frame {
TungsteniteMessage::Text(text) => {
return serde_json::from_str(text.as_ref())
.expect("server event should deserialize");
}
TungsteniteMessage::Ping(payload) => {
// Reply with a pong carrying the same payload, then keep reading.
websocket
.send(TungsteniteMessage::Pong(payload))
.await
.expect("websocket pong should send");
}
TungsteniteMessage::Pong(_) => {}
TungsteniteMessage::Close(frame) => {
panic!("unexpected websocket close frame: {frame:?}");
}
TungsteniteMessage::Binary(_) => {
panic!("unexpected binary websocket frame");
}
TungsteniteMessage::Frame(_) => {}
}
}
}

View File

@@ -51,7 +51,7 @@ use tracing::warn;
/// plenty for an interactive CLI.
pub(crate) const CHANNEL_CAPACITY: usize = 128;
fn colorize(text: &str, style: Style) -> String {
pub(crate) fn colorize(text: &str, style: Style) -> String {
text.if_supports_color(Stream::Stderr, |value| value.style(style))
.to_string()
}
@@ -84,7 +84,8 @@ fn print_websocket_startup_banner(addr: SocketAddr) {
#[derive(Clone)]
struct WebSocketListenerState {
transport_event_tx: mpsc::Sender<TransportEvent>,
connection_counter: Arc<AtomicU64>,
connection_id_allocator: ConnectionIdAllocator,
transport_kind: AppServerTransport,
}
async fn health_check_handler() -> StatusCode {
@@ -96,10 +97,16 @@ async fn websocket_upgrade_handler(
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
State(state): State<WebSocketListenerState>,
) -> impl IntoResponse {
let connection_id = ConnectionId(state.connection_counter.fetch_add(1, Ordering::Relaxed));
let connection_id = state.connection_id_allocator.next_connection_id();
info!(%peer_addr, "websocket client connected");
websocket.on_upgrade(move |stream| async move {
run_websocket_connection(connection_id, stream, state.transport_event_tx).await;
run_websocket_connection(
connection_id,
stream,
state.transport_event_tx,
state.transport_kind,
)
.await;
})
}
@@ -107,6 +114,36 @@ async fn websocket_upgrade_handler(
/// Transport over which an app-server connection is established.
pub enum AppServerTransport {
/// Standard input/output.
Stdio,
/// Local websocket listener bound to `bind_address`.
WebSocket { bind_address: SocketAddr },
/// Connection reached through remote control.
RemoteControlled,
}
impl AppServerTransport {
/// Listen URL that selects the stdio transport.
pub const DEFAULT_LISTEN_URL: &'static str = "stdio://";
/// Parses a `--listen`-style URL into a transport.
///
/// `stdio://` selects [`Self::Stdio`]; `ws://IP:PORT` selects
/// [`Self::WebSocket`] with the parsed socket address.
///
/// # Errors
///
/// Returns `InvalidWebSocketListenUrl` when the part after `ws://` is not a
/// socket address, and `UnsupportedListenUrl` for any other URL.
pub fn from_listen_url(listen_url: &str) -> Result<Self, AppServerTransportParseError> {
if listen_url == Self::DEFAULT_LISTEN_URL {
return Ok(Self::Stdio);
}
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
})?;
return Ok(Self::WebSocket { bind_address });
}
Err(AppServerTransportParseError::UnsupportedListenUrl(
listen_url.to_string(),
))
}
/// Returns a short, stable label for the transport kind.
pub(crate) fn as_str(self) -> &'static str {
match self {
Self::Stdio => "stdio",
Self::WebSocket { .. } => "websocket",
Self::RemoteControlled => "remote_controlled",
}
}
}
#[derive(Debug, Clone, Eq, PartialEq)]
@@ -132,27 +169,6 @@ impl std::fmt::Display for AppServerTransportParseError {
impl std::error::Error for AppServerTransportParseError {}
impl AppServerTransport {
pub const DEFAULT_LISTEN_URL: &'static str = "stdio://";
pub fn from_listen_url(listen_url: &str) -> Result<Self, AppServerTransportParseError> {
if listen_url == Self::DEFAULT_LISTEN_URL {
return Ok(Self::Stdio);
}
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
})?;
return Ok(Self::WebSocket { bind_address });
}
Err(AppServerTransportParseError::UnsupportedListenUrl(
listen_url.to_string(),
))
}
}
impl FromStr for AppServerTransport {
type Err = AppServerTransportParseError;
@@ -161,16 +177,37 @@ impl FromStr for AppServerTransport {
}
}
/// Hands out connection ids from a shared atomic counter.
///
/// Clones share the same counter (it is held in an `Arc`), so ids handed out
/// through any clone come from one sequence.
#[derive(Clone, Debug)]
pub(crate) struct ConnectionIdAllocator {
next_id: Arc<AtomicU64>,
}
impl ConnectionIdAllocator {
/// Returns the next connection id and advances the counter by one.
pub(crate) fn next_connection_id(&self) -> ConnectionId {
ConnectionId(self.next_id.fetch_add(1, Ordering::Relaxed))
}
}
impl Default for ConnectionIdAllocator {
/// Creates an allocator whose first id is `1`.
fn default() -> Self {
Self {
next_id: Arc::new(AtomicU64::new(1)),
}
}
}
#[derive(Debug)]
pub(crate) enum TransportEvent {
ConnectionOpened {
connection_id: ConnectionId,
transport_kind: AppServerTransport,
writer: mpsc::Sender<OutgoingMessage>,
allow_legacy_notifications: bool,
disconnect_sender: Option<CancellationToken>,
},
ConnectionClosed {
connection_id: ConnectionId,
transport_kind: AppServerTransport,
},
IncomingMessage {
connection_id: ConnectionId,
@@ -179,6 +216,7 @@ pub(crate) enum TransportEvent {
}
pub(crate) struct ConnectionState {
pub(crate) transport_kind: AppServerTransport,
pub(crate) outbound_initialized: Arc<AtomicBool>,
pub(crate) outbound_experimental_api_enabled: Arc<AtomicBool>,
pub(crate) outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
@@ -187,11 +225,13 @@ pub(crate) struct ConnectionState {
impl ConnectionState {
pub(crate) fn new(
transport_kind: AppServerTransport,
outbound_initialized: Arc<AtomicBool>,
outbound_experimental_api_enabled: Arc<AtomicBool>,
outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
) -> Self {
Self {
transport_kind,
outbound_initialized,
outbound_experimental_api_enabled,
outbound_opted_out_notification_methods,
@@ -249,6 +289,7 @@ pub(crate) async fn start_stdio_connection(
transport_event_tx
.send(TransportEvent::ConnectionOpened {
connection_id,
transport_kind: AppServerTransport::Stdio,
writer: writer_tx,
allow_legacy_notifications: false,
disconnect_sender: None,
@@ -285,7 +326,10 @@ pub(crate) async fn start_stdio_connection(
}
let _ = transport_event_tx_for_reader
.send(TransportEvent::ConnectionClosed { connection_id })
.send(TransportEvent::ConnectionClosed {
connection_id,
transport_kind: AppServerTransport::Stdio,
})
.await;
debug!("stdin reader finished (EOF)");
}));
@@ -312,6 +356,7 @@ pub(crate) async fn start_websocket_acceptor(
bind_address: SocketAddr,
transport_event_tx: mpsc::Sender<TransportEvent>,
shutdown_token: CancellationToken,
connection_id_allocator: ConnectionIdAllocator,
) -> IoResult<JoinHandle<()>> {
let listener = TcpListener::bind(bind_address).await?;
let local_addr = listener.local_addr()?;
@@ -324,7 +369,10 @@ pub(crate) async fn start_websocket_acceptor(
.fallback(any(websocket_upgrade_handler))
.with_state(WebSocketListenerState {
transport_event_tx,
connection_counter: Arc::new(AtomicU64::new(1)),
connection_id_allocator,
transport_kind: AppServerTransport::WebSocket {
bind_address: local_addr,
},
});
let server = axum::serve(
listener,
@@ -345,6 +393,7 @@ async fn run_websocket_connection(
connection_id: ConnectionId,
websocket_stream: WebSocket,
transport_event_tx: mpsc::Sender<TransportEvent>,
transport_kind: AppServerTransport,
) {
let (writer_tx, writer_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
let writer_tx_for_reader = writer_tx.clone();
@@ -352,6 +401,7 @@ async fn run_websocket_connection(
if transport_event_tx
.send(TransportEvent::ConnectionOpened {
connection_id,
transport_kind,
writer: writer_tx,
allow_legacy_notifications: false,
disconnect_sender: Some(disconnect_token.clone()),
@@ -392,7 +442,10 @@ async fn run_websocket_connection(
}
let _ = transport_event_tx
.send(TransportEvent::ConnectionClosed { connection_id })
.send(TransportEvent::ConnectionClosed {
connection_id,
transport_kind,
})
.await;
}

View File

@@ -0,0 +1,419 @@
use super::connection_handling_websocket::connect_websocket;
use super::connection_handling_websocket::read_response_for_id;
use super::connection_handling_websocket::send_initialize_request;
use super::connection_handling_websocket::send_request;
use super::connection_handling_websocket::spawn_websocket_server_with_args;
use anyhow::Context;
use anyhow::Result;
use app_test_support::ChatGptAuthFixture;
use app_test_support::create_mock_responses_server_sequence_unchecked;
use app_test_support::write_chatgpt_auth;
use codex_app_server_protocol::ClientInfo;
use codex_app_server_protocol::InitializeParams;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::RequestId;
use codex_core::auth::AuthCredentialsStoreMode;
use futures::SinkExt;
use futures::StreamExt;
use pretty_assertions::assert_eq;
use serde_json::Value;
use serde_json::json;
use std::path::Path;
use std::process::Stdio;
use std::sync::Arc;
use tempfile::TempDir;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::process::Child;
use tokio::process::Command;
use tokio::time::Duration;
use tokio::time::timeout;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::accept_hdr_async;
use tokio_tungstenite::tungstenite;
use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
type BackendWebSocket = WebSocketStream<TcpStream>;
/// Runs the app-server with a local websocket listener plus
/// `--with-remote-control` and checks that a local websocket client and a
/// remote-control client are served independently, even when both use the
/// same JSON-RPC request ids.
#[tokio::test]
async fn websocket_transport_with_remote_control_routes_connections_independently() -> Result<()> {
let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await;
// Fake remote control backend: a plain TCP listener the test drives by hand.
let remote_listener = TcpListener::bind("127.0.0.1:0")
.await
.context("listener should bind")?;
let remote_control_base_url = format!(
"http://{}/backend-api",
remote_listener
.local_addr()
.context("listener should have local addr")?
);
let codex_home = TempDir::new()?;
create_config_toml_with_remote_control(
codex_home.path(),
&server.uri(),
&remote_control_base_url,
"never",
)?;
write_chatgpt_auth(
codex_home.path(),
ChatGptAuthFixture::new("Access Token").account_id("account_id"),
AuthCredentialsStoreMode::File,
)?;
let (mut process, bind_addr) =
spawn_websocket_server_with_args(codex_home.path(), &["--with-remote-control"]).await?;
// The server is expected to enroll first, then open the backend websocket.
let enroll_request = accept_http_request(&remote_listener).await?;
assert_eq!(
enroll_request.request_line,
"POST /backend-api/remote/control/server/enroll HTTP/1.1"
);
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_mixed" })).await?;
let (backend_request, mut backend_websocket) =
accept_remote_control_backend_connection(&remote_listener).await?;
assert_eq!(backend_request.path, "/backend-api/remote/control/server");
// Initialize a local websocket client with request id 11 ...
let mut local_websocket = connect_websocket(bind_addr).await?;
send_initialize_request(&mut local_websocket, 11, "local_ws_client").await?;
assert_eq!(
read_response_for_id(&mut local_websocket, 11).await?.id,
RequestId::Integer(11)
);
// ... and a remote client with the same request id.
send_remote_request(
&mut backend_websocket,
"remote-client-1",
"initialize",
11,
Some(serde_json::to_value(InitializeParams {
client_info: ClientInfo {
name: "remote_control_client".to_string(),
title: Some("Remote Control Test Client".to_string()),
version: "0.1.0".to_string(),
},
capabilities: None,
})?),
)
.await?;
let remote_initialize =
read_remote_response_for_id(&mut backend_websocket, "remote-client-1", 11).await?;
assert_eq!(remote_initialize["id"], json!(11));
// Issue the same request on both connections before reading either
// response; each connection must get its own reply.
send_request(
&mut local_websocket,
"config/read",
77,
Some(json!({ "includeLayers": false })),
)
.await?;
send_remote_request(
&mut backend_websocket,
"remote-client-1",
"config/read",
77,
Some(json!({ "includeLayers": false })),
)
.await?;
let local_response = read_response_for_id(&mut local_websocket, 77).await?;
let remote_response =
read_remote_response_for_id(&mut backend_websocket, "remote-client-1", 77).await?;
assert_eq!(local_response.id, RequestId::Integer(77));
assert!(local_response.result.get("config").is_some());
assert_eq!(remote_response["id"], json!(77));
assert!(remote_response["result"].get("config").is_some());
process
.kill()
.await
.context("failed to stop websocket app-server process")?;
Ok(())
}
/// Runs the app-server over stdio with `--with-remote-control` and checks
/// that closing stdin makes the process exit successfully and ends the
/// remote-control websocket.
#[tokio::test]
async fn stdio_transport_with_remote_control_exits_when_stdio_closes() -> Result<()> {
let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await;
// Fake remote control backend: a plain TCP listener the test drives by hand.
let remote_listener = TcpListener::bind("127.0.0.1:0")
.await
.context("listener should bind")?;
let remote_control_base_url = format!(
"http://{}/backend-api",
remote_listener
.local_addr()
.context("listener should have local addr")?
);
let codex_home = TempDir::new()?;
create_config_toml_with_remote_control(
codex_home.path(),
&server.uri(),
&remote_control_base_url,
"never",
)?;
write_chatgpt_auth(
codex_home.path(),
ChatGptAuthFixture::new("Access Token").account_id("account_id"),
AuthCredentialsStoreMode::File,
)?;
let mut process = spawn_stdio_server_with_remote_control(codex_home.path()).await?;
let enroll_request = accept_http_request(&remote_listener).await?;
assert_eq!(
enroll_request.request_line,
"POST /backend-api/remote/control/server/enroll HTTP/1.1"
);
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_stdio" })).await?;
let (_backend_request, mut backend_websocket) =
accept_remote_control_backend_connection(&remote_listener).await?;
// Dropping the stdin handle closes the child's stdin.
drop(process.stdin.take());
let exit_status = timeout(Duration::from_secs(10), process.wait())
.await
.context("timed out waiting for stdio app-server to exit")?
.context("failed waiting for stdio app-server exit")?;
assert!(exit_status.success());
let close_frame = timeout(Duration::from_secs(5), backend_websocket.next())
.await
.context("timed out waiting for remote-control websocket to close")?;
// A close frame, a read error, or end-of-stream all count as "closed".
match close_frame {
Some(Ok(WebSocketMessage::Close(_))) | Some(Err(_)) | None => {}
Some(Ok(other)) => {
anyhow::bail!("unexpected websocket frame while waiting for shutdown: {other:?}")
}
}
Ok(())
}
/// An HTTP request read off a raw TCP connection by `accept_http_request`.
struct CapturedHttpRequest {
/// The underlying connection, kept so the test can write a response.
stream: TcpStream,
/// First line of the request, without the trailing CRLF.
request_line: String,
}
/// Details of a websocket upgrade request captured during the handshake.
#[derive(Clone, Debug, PartialEq, Eq)]
struct CapturedWebSocketRequest {
/// Path component of the request URI.
path: String,
}
/// Accepts one connection and reads a single HTTP/1.1 request from it.
///
/// Waits up to ten seconds for a connection, then reads the request line, the
/// headers up to the blank line, and a body of `content-length` bytes (zero
/// when the header is missing or unparsable). The body is consumed but not
/// returned.
///
/// # Errors
///
/// Fails when no connection arrives in time, when reading fails, or when a
/// header line has no colon.
async fn accept_http_request(listener: &TcpListener) -> Result<CapturedHttpRequest> {
let (stream, _) = timeout(Duration::from_secs(10), listener.accept())
.await
.context("HTTP request should arrive in time")?
.context("listener accept should succeed")?;
let mut reader = BufReader::new(stream);
let mut request_line = String::new();
reader
.read_line(&mut request_line)
.await
.context("request line should read")?;
let request_line = request_line.trim_end_matches("\r\n").to_string();
let mut headers = std::collections::BTreeMap::new();
loop {
let mut line = String::new();
reader
.read_line(&mut line)
.await
.context("header line should read")?;
// A bare CRLF terminates the header section.
if line == "\r\n" {
break;
}
let line = line.trim_end_matches("\r\n");
let (name, value) = line
.split_once(':')
.context("header should contain colon")?;
headers.insert(name.to_ascii_lowercase(), value.trim().to_string());
}
let content_length = headers
.get("content-length")
.and_then(|value| value.parse::<usize>().ok())
.unwrap_or(0);
// Drain the body so the request is fully read before the caller responds.
let mut body = vec![0; content_length];
reader
.read_exact(&mut body)
.await
.context("request body should read")?;
Ok(CapturedHttpRequest {
stream: reader.into_inner(),
request_line,
})
}
/// Writes a `200 OK` JSON response carrying `body` to `stream` and flushes it.
///
/// # Errors
///
/// Fails when the response cannot be written or flushed.
async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) -> Result<()> {
    let payload = body.to_string();
    let content_length = payload.len();
    let response = format!(
        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
        content_length,
        body = payload
    );
    let bytes = response.as_bytes();
    stream
        .write_all(bytes)
        .await
        .context("response should write")?;
    stream.flush().await.context("response should flush")
}
/// Accepts one connection and completes a websocket handshake, acting as the
/// remote control backend.
///
/// Returns the captured upgrade request path together with the established
/// websocket.
///
/// # Errors
///
/// Fails when no connection arrives within ten seconds, when the handshake
/// fails, or when the handshake callback captured no request.
///
/// # Panics
///
/// Panics when the capture mutex is poisoned.
async fn accept_remote_control_backend_connection(
listener: &TcpListener,
) -> Result<(CapturedWebSocketRequest, BackendWebSocket)> {
let (stream, _) = timeout(Duration::from_secs(10), listener.accept())
.await
.context("websocket request should arrive in time")?
.context("listener accept should succeed")?;
// Shared slot the handshake callback fills in; read back after the handshake.
let captured_request = Arc::new(std::sync::Mutex::new(None::<CapturedWebSocketRequest>));
let captured_request_for_callback = Arc::clone(&captured_request);
let websocket = accept_hdr_async(
stream,
move |request: &tungstenite::handshake::server::Request,
response: tungstenite::handshake::server::Response| {
let mut guard = match captured_request_for_callback.lock() {
Ok(guard) => guard,
Err(err) => panic!("capture lock should acquire: {err}"),
};
*guard = Some(CapturedWebSocketRequest {
path: request.uri().path().to_string(),
});
Ok(response)
},
)
.await
.context("websocket handshake should succeed")?;
let captured_request = match captured_request.lock() {
Ok(guard) => guard.clone(),
Err(err) => panic!("capture lock should acquire: {err}"),
}
.context("websocket request should be captured")?;
Ok((captured_request, websocket))
}
/// Sends a JSON-RPC request on behalf of remote client `client_id`.
///
/// The request is wrapped in a `client_message` envelope and sent over the
/// backend websocket as a single text frame.
///
/// # Errors
///
/// Fails when the request cannot be serialized or the frame cannot be sent.
async fn send_remote_request(
    websocket: &mut BackendWebSocket,
    client_id: &str,
    method: &str,
    id: i64,
    params: Option<Value>,
) -> Result<()> {
    let request = JSONRPCRequest {
        id: RequestId::Integer(id),
        method: method.to_string(),
        params,
        trace: None,
    };
    let message = serde_json::to_value(JSONRPCMessage::Request(request))?;
    let envelope = json!({
        "type": "client_message",
        "client_id": client_id,
        "message": message,
    });
    let frame = WebSocketMessage::Text(envelope.to_string().into());
    websocket
        .send(frame)
        .await
        .context("client event should send")
}
/// Reads server events until one answers request `id` for `client_id`.
///
/// An event matches when its `type` is `"server_message"`, its `client_id` equals
/// `client_id`, and its `message.id` equals `id`. Non-matching events are discarded.
/// Returns a copy of the matching event's `message` value.
///
/// # Errors
///
/// Propagates any error from `read_remote_server_event`.
async fn read_remote_response_for_id(
    websocket: &mut BackendWebSocket,
    client_id: &str,
    id: i64,
) -> Result<Value> {
    // The expected values do not change between events, so build them once.
    let expected_type = json!("server_message");
    let expected_client = json!(client_id);
    let expected_id = json!(id);

    loop {
        let event = read_remote_server_event(websocket).await?;
        let is_match = event["type"] == expected_type
            && event["client_id"] == expected_client
            && event["message"]["id"] == expected_id;
        if !is_match {
            continue;
        }
        return Ok(event["message"].clone());
    }
}
/// Reads the next text frame from `websocket` and parses it as JSON.
///
/// Ping frames are answered with a pong carrying the same payload; pong and raw
/// frames are ignored. Each wait for a frame is limited to 5 seconds.
///
/// # Errors
///
/// Fails if a frame does not arrive in time, the stream ends, a frame cannot be read,
/// a close or binary frame is received, the pong cannot be sent, or the text payload
/// is not valid JSON.
async fn read_remote_server_event(websocket: &mut BackendWebSocket) -> Result<Value> {
    let frame_deadline = Duration::from_secs(5);
    loop {
        let next_frame = timeout(frame_deadline, websocket.next())
            .await
            .context("server event should arrive in time")?
            .context("websocket should stay open")?
            .context("websocket frame should be readable")?;

        match next_frame {
            WebSocketMessage::Text(text) => {
                let event: Value = serde_json::from_str(text.as_ref())
                    .context("server event should deserialize")?;
                return Ok(event);
            }
            WebSocketMessage::Binary(_) => anyhow::bail!("unexpected binary websocket frame"),
            WebSocketMessage::Close(frame) => {
                anyhow::bail!("unexpected websocket close frame: {frame:?}");
            }
            WebSocketMessage::Ping(payload) => {
                // Answer the ping, then keep waiting for a text frame.
                websocket
                    .send(WebSocketMessage::Pong(payload))
                    .await
                    .context("websocket pong should send")?;
            }
            WebSocketMessage::Pong(_) | WebSocketMessage::Frame(_) => {}
        }
    }
}
/// Spawns the `codex-app-server` binary with `--with-remote-control`.
///
/// The child runs with `CODEX_HOME` set to `codex_home` and `RUST_LOG=debug`. Its
/// stdin is piped, its stdout is discarded, and each stderr line is echoed to this
/// process's stderr with a `[stdio app-server stderr]` prefix by a background task.
/// The child is configured to be killed when the returned handle is dropped.
///
/// # Errors
///
/// Fails if the binary cannot be located or the process cannot be spawned.
async fn spawn_stdio_server_with_remote_control(codex_home: &Path) -> Result<Child> {
    let app_server_bin = codex_utils_cargo_bin::cargo_bin("codex-app-server")
        .context("should find app-server binary")?;

    let mut child = Command::new(app_server_bin)
        .arg("--with-remote-control")
        .stdin(Stdio::piped())
        .stdout(Stdio::null())
        .stderr(Stdio::piped())
        .env("CODEX_HOME", codex_home)
        .env("RUST_LOG", "debug")
        .kill_on_drop(true)
        .spawn()
        .context("failed to spawn stdio app-server process")?;

    let Some(child_stderr) = child.stderr.take() else {
        return Ok(child);
    };

    // Forward the child's stderr so its logs show up in the test output.
    let mut stderr_lines = BufReader::new(child_stderr).lines();
    tokio::spawn(async move {
        while let Ok(Some(line)) = stderr_lines.next_line().await {
            eprintln!("[stdio app-server stderr] {line}");
        }
    });

    Ok(child)
}
/// Writes a `config.toml` into `codex_home` for a mock model provider.
///
/// * `server_uri` - base of the mock provider; `/v1` is appended for `base_url`.
/// * `remote_control_base_url` - written as `chatgpt_base_url`.
/// * `approval_policy` - written verbatim as `approval_policy`.
///
/// The values are interpolated as-is, without TOML escaping.
///
/// # Errors
///
/// Returns the I/O error from writing the file.
fn create_config_toml_with_remote_control(
    codex_home: &Path,
    server_uri: &str,
    remote_control_base_url: &str,
    approval_policy: &str,
) -> std::io::Result<()> {
    let contents = format!(
        r#"
model = "mock-model"
approval_policy = "{approval_policy}"
sandbox_mode = "read-only"
chatgpt_base_url = "{remote_control_base_url}"
model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#
    );
    let destination = codex_home.join("config.toml");
    std::fs::write(destination, contents)
}

View File

@@ -108,12 +108,21 @@ async fn websocket_transport_serves_health_endpoints_on_same_listener() -> Resul
}
/// Convenience wrapper around `spawn_websocket_server_with_args` that passes no
/// extra command-line arguments.
///
/// Returns whatever the delegate returns: the spawned child process and a socket
/// address.
pub(super) async fn spawn_websocket_server(codex_home: &Path) -> Result<(Child, SocketAddr)> {
spawn_websocket_server_with_args(codex_home, &[]).await
}
pub(super) async fn spawn_websocket_server_with_args(
codex_home: &Path,
extra_args: &[&str],
) -> Result<(Child, SocketAddr)> {
let program = codex_utils_cargo_bin::cargo_bin("codex-app-server")
.context("should find app-server binary")?;
let mut cmd = Command::new(program);
cmd.arg("--listen")
.arg("ws://127.0.0.1:0")
.stdin(Stdio::null())
cmd.arg("--listen").arg("ws://127.0.0.1:0");
for extra_arg in extra_args {
cmd.arg(extra_arg);
}
cmd.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::piped())
.env("CODEX_HOME", codex_home)
@@ -158,10 +167,11 @@ pub(super) async fn spawn_websocket_server(codex_home: &Path) -> Result<(Child,
stripped
};
if let Some(bind_addr) = stripped_line
.split_whitespace()
.find_map(|token| token.strip_prefix("ws://"))
.and_then(|addr| addr.parse::<SocketAddr>().ok())
if stripped_line.contains("listening on:")
&& let Some(bind_addr) = stripped_line
.split_whitespace()
.find_map(|token| token.strip_prefix("ws://"))
.and_then(|addr| addr.parse::<SocketAddr>().ok())
{
break bind_addr;
}

View File

@@ -6,6 +6,7 @@ mod collaboration_mode_list;
mod command_exec;
mod compaction;
mod config_rpc;
mod connection_handling_mixed_remote_control;
mod connection_handling_websocket;
#[cfg(unix)]
mod connection_handling_websocket_unix;

View File

@@ -8,6 +8,10 @@ license.workspace = true
name = "codex"
path = "src/main.rs"
[[bin]]
name = "codexd"
path = "src/bin/codexd.rs"
[lib]
name = "codex_cli"
path = "src/lib.rs"

View File

@@ -0,0 +1,42 @@
use clap::Parser;
use codex_app_server::run_main_with_runtime;
use codex_arg0::Arg0DispatchPaths;
use codex_arg0::arg0_dispatch_or_else;
use codex_core::config_loader::LoaderOverrides;
use codex_utils_cli::CliConfigOverrides;
use std::io::ErrorKind;
// Command-line arguments accepted by the `codexd` binary.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive macros, `///`
// doc comments on the struct are picked up as help text, which would change the
// generated `--help` output.
#[derive(Debug, Parser)]
#[clap(
author,
version,
bin_name = "codexd",
override_usage = "codexd [OPTIONS]"
)]
struct CodexdCli {
// Configuration overrides; flattened so their flags appear directly on `codexd`.
#[clap(flatten)]
config_overrides: CliConfigOverrides,
}
fn main() -> anyhow::Result<()> {
arg0_dispatch_or_else(|arg0_paths: Arg0DispatchPaths| async move {
let cli = CodexdCli::parse();
cli.config_overrides.parse_overrides().map_err(|err| {
std::io::Error::new(
ErrorKind::InvalidInput,
format!("error parsing -c overrides: {err}"),
)
})?;
run_main_with_runtime(
arg0_paths,
cli.config_overrides,
LoaderOverrides::default(),
true,
None,
true,
)
.await?;
Ok(())
})
}

View File

@@ -331,6 +331,11 @@ struct AppServerCommand {
)]
listen: codex_app_server::AppServerTransport,
/// Also connect outbound to the ChatGPT remote control server derived from
/// the configured `chatgpt_base_url`.
#[arg(long = "with-remote-control", default_value_t = false)]
with_remote_control: bool,
/// Controls whether analytics are enabled by default.
///
/// Analytics are disabled by default for app-server. Users have to explicitly opt in
@@ -630,13 +635,13 @@ async fn cli_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> {
Some(Subcommand::AppServer(app_server_cli)) => match app_server_cli.subcommand {
None => {
reject_remote_mode_for_subcommand(root_remote.as_deref(), "app-server")?;
let transport = app_server_cli.listen;
codex_app_server::run_main_with_transport(
codex_app_server::run_main_with_runtime(
arg0_paths.clone(),
root_config_overrides,
codex_core::config_loader::LoaderOverrides::default(),
app_server_cli.analytics_default_enabled,
transport,
Some(app_server_cli.listen),
app_server_cli.with_remote_control,
)
.await?;
}
@@ -1600,6 +1605,7 @@ mod tests {
app_server.listen,
codex_app_server::AppServerTransport::Stdio
);
assert!(!app_server.with_remote_control);
}
#[test]