mirror of
https://github.com/openai/codex.git
synced 2026-05-05 11:57:33 +00:00
Compare commits
2 Commits
commit-495
...
anton_pana
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c991b670ee | ||
|
|
83056d0474 |
2
codex-rs/Cargo.lock
generated
2
codex-rs/Cargo.lock
generated
@@ -1440,8 +1440,10 @@ dependencies = [
|
||||
"codex-utils-cli",
|
||||
"codex-utils-json-to-toml",
|
||||
"codex-utils-pty",
|
||||
"codex-utils-rustls-provider",
|
||||
"core_test_support",
|
||||
"futures",
|
||||
"gethostname",
|
||||
"opentelemetry",
|
||||
"opentelemetry_sdk",
|
||||
"owo-colors",
|
||||
|
||||
@@ -47,9 +47,11 @@ codex-rmcp-client = { workspace = true }
|
||||
codex-state = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
codex-utils-json-to-toml = { workspace = true }
|
||||
codex-utils-rustls-provider = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
futures = { workspace = true }
|
||||
gethostname = { workspace = true }
|
||||
owo-colors = { workspace = true, features = ["supports-colors"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
`codex app-server` is the interface Codex uses to power rich interfaces such as the [Codex VS Code extension](https://marketplace.visualstudio.com/items?itemName=openai.chatgpt).
|
||||
|
||||
For remote-control-only deployments, use `codexd`. It runs the same app-server runtime in a headless daemon mode, connects outbound to the ChatGPT remote control server using ChatGPT auth, and does not expose a local stdio or websocket transport.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Protocol](#protocol)
|
||||
@@ -25,6 +27,24 @@ Supported transports:
|
||||
|
||||
- stdio (`--listen stdio://`, default): newline-delimited JSON (JSONL)
|
||||
- websocket (`--listen ws://IP:PORT`): one JSON-RPC message per websocket text frame (**experimental / unsupported**)
|
||||
- remote control (`--with-remote-control`): also connect outbound to the ChatGPT remote control server derived from `chatgpt_base_url`
|
||||
|
||||
You can combine a local transport with remote control in the same process:
|
||||
|
||||
```sh
|
||||
codex app-server --listen stdio:// --with-remote-control
|
||||
codex app-server --listen ws://127.0.0.1:8080 --with-remote-control
|
||||
```
|
||||
|
||||
Both local and remote-controlled clients share the same in-process app-server state, and remote-controlled clients still use the normal JSON-RPC connection lifecycle, including `initialize`.
|
||||
|
||||
For remote-control-only deployments, keep using `codexd`:
|
||||
|
||||
```sh
|
||||
codexd
|
||||
```
|
||||
|
||||
`codexd` runs the same runtime with no local listener and remote control enabled by default.
|
||||
|
||||
When running with `--listen ws://IP:PORT`, the same listener also serves basic HTTP health probes:
|
||||
|
||||
|
||||
@@ -29,12 +29,8 @@ pub(crate) fn request_span(
|
||||
) -> Span {
|
||||
let initialize_client_info = initialize_client_info(request);
|
||||
let method = request.method.as_str();
|
||||
let span = app_server_request_span_template(
|
||||
method,
|
||||
transport_name(transport),
|
||||
&request.id,
|
||||
connection_id,
|
||||
);
|
||||
let span =
|
||||
app_server_request_span_template(method, transport.as_str(), &request.id, connection_id);
|
||||
|
||||
record_client_info(
|
||||
&span,
|
||||
@@ -82,13 +78,6 @@ pub(crate) fn typed_request_span(
|
||||
span
|
||||
}
|
||||
|
||||
fn transport_name(transport: AppServerTransport) -> &'static str {
|
||||
match transport {
|
||||
AppServerTransport::Stdio => "stdio",
|
||||
AppServerTransport::WebSocket { .. } => "websocket",
|
||||
}
|
||||
}
|
||||
|
||||
fn app_server_request_span_template(
|
||||
method: &str,
|
||||
transport: &'static str,
|
||||
|
||||
@@ -23,6 +23,7 @@ use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::transport::CHANNEL_CAPACITY;
|
||||
use crate::transport::ConnectionIdAllocator;
|
||||
use crate::transport::ConnectionState;
|
||||
use crate::transport::OutboundConnectionState;
|
||||
use crate::transport::TransportEvent;
|
||||
@@ -76,6 +77,8 @@ mod thread_state;
|
||||
mod thread_status;
|
||||
mod transport;
|
||||
|
||||
mod remote_control;
|
||||
|
||||
pub use crate::error_code::INPUT_TOO_LARGE_ERROR_CODE;
|
||||
pub use crate::error_code::INVALID_PARAMS_ERROR_CODE;
|
||||
pub use crate::transport::AppServerTransport;
|
||||
@@ -330,12 +333,32 @@ pub async fn run_main(
|
||||
loader_overrides: LoaderOverrides,
|
||||
default_analytics_enabled: bool,
|
||||
) -> IoResult<()> {
|
||||
run_main_with_transport(
|
||||
run_main_with_runtime(
|
||||
arg0_paths,
|
||||
cli_config_overrides,
|
||||
loader_overrides,
|
||||
default_analytics_enabled,
|
||||
AppServerTransport::Stdio,
|
||||
Some(AppServerTransport::Stdio),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn run_main_with_runtime(
|
||||
arg0_paths: Arg0DispatchPaths,
|
||||
cli_config_overrides: CliConfigOverrides,
|
||||
loader_overrides: LoaderOverrides,
|
||||
default_analytics_enabled: bool,
|
||||
local_transport: Option<AppServerTransport>,
|
||||
with_remote_control: bool,
|
||||
) -> IoResult<()> {
|
||||
run_main_with_transport_impl(
|
||||
arg0_paths,
|
||||
cli_config_overrides,
|
||||
loader_overrides,
|
||||
default_analytics_enabled,
|
||||
local_transport,
|
||||
with_remote_control,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -347,43 +370,64 @@ pub async fn run_main_with_transport(
|
||||
default_analytics_enabled: bool,
|
||||
transport: AppServerTransport,
|
||||
) -> IoResult<()> {
|
||||
match transport {
|
||||
AppServerTransport::Stdio => {
|
||||
run_main_with_runtime(
|
||||
arg0_paths,
|
||||
cli_config_overrides,
|
||||
loader_overrides,
|
||||
default_analytics_enabled,
|
||||
Some(AppServerTransport::Stdio),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
}
|
||||
AppServerTransport::WebSocket { bind_address } => {
|
||||
run_main_with_runtime(
|
||||
arg0_paths,
|
||||
cli_config_overrides,
|
||||
loader_overrides,
|
||||
default_analytics_enabled,
|
||||
Some(AppServerTransport::WebSocket { bind_address }),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
}
|
||||
AppServerTransport::RemoteControlled => {
|
||||
run_main_with_runtime(
|
||||
arg0_paths,
|
||||
cli_config_overrides,
|
||||
loader_overrides,
|
||||
default_analytics_enabled,
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_main_with_transport_impl(
|
||||
arg0_paths: Arg0DispatchPaths,
|
||||
cli_config_overrides: CliConfigOverrides,
|
||||
loader_overrides: LoaderOverrides,
|
||||
default_analytics_enabled: bool,
|
||||
local_transport: Option<AppServerTransport>,
|
||||
with_remote_control: bool,
|
||||
) -> IoResult<()> {
|
||||
if matches!(local_transport, Some(AppServerTransport::RemoteControlled)) {
|
||||
return Err(std::io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
"remote-controlled transport cannot be used as a local listener",
|
||||
));
|
||||
}
|
||||
|
||||
let (transport_event_tx, mut transport_event_rx) =
|
||||
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingEnvelope>(CHANNEL_CAPACITY);
|
||||
let (outbound_control_tx, mut outbound_control_rx) =
|
||||
mpsc::channel::<OutboundControlEvent>(CHANNEL_CAPACITY);
|
||||
|
||||
enum TransportRuntime {
|
||||
Stdio,
|
||||
WebSocket {
|
||||
accept_handle: JoinHandle<()>,
|
||||
shutdown_token: CancellationToken,
|
||||
},
|
||||
}
|
||||
|
||||
let mut stdio_handles = Vec::<JoinHandle<()>>::new();
|
||||
let transport_runtime = match transport {
|
||||
AppServerTransport::Stdio => {
|
||||
start_stdio_connection(transport_event_tx.clone(), &mut stdio_handles).await?;
|
||||
TransportRuntime::Stdio
|
||||
}
|
||||
AppServerTransport::WebSocket { bind_address } => {
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let accept_handle = start_websocket_acceptor(
|
||||
bind_address,
|
||||
transport_event_tx.clone(),
|
||||
shutdown_token.clone(),
|
||||
)
|
||||
.await?;
|
||||
TransportRuntime::WebSocket {
|
||||
accept_handle,
|
||||
shutdown_token,
|
||||
}
|
||||
}
|
||||
};
|
||||
let single_client_mode = matches!(&transport_runtime, TransportRuntime::Stdio);
|
||||
let shutdown_when_no_connections = single_client_mode;
|
||||
let graceful_signal_restart_enabled = !single_client_mode;
|
||||
// Parse CLI overrides once and derive the base Config eagerly so later
|
||||
// components do not need to work with raw TOML values.
|
||||
let cli_kv_overrides = cli_config_overrides.parse_overrides().map_err(|e| {
|
||||
@@ -466,6 +510,49 @@ pub async fn run_main_with_transport(
|
||||
config_warnings.push(message);
|
||||
}
|
||||
|
||||
let transport_shutdown_token = CancellationToken::new();
|
||||
let mut transport_accept_handles = Vec::<JoinHandle<()>>::new();
|
||||
let connection_id_allocator = ConnectionIdAllocator::default();
|
||||
match local_transport {
|
||||
Some(AppServerTransport::Stdio) => {
|
||||
start_stdio_connection(transport_event_tx.clone(), &mut transport_accept_handles)
|
||||
.await?;
|
||||
}
|
||||
Some(AppServerTransport::WebSocket { bind_address }) => {
|
||||
transport_accept_handles.push(
|
||||
start_websocket_acceptor(
|
||||
bind_address,
|
||||
transport_event_tx.clone(),
|
||||
transport_shutdown_token.clone(),
|
||||
connection_id_allocator.clone(),
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
Some(AppServerTransport::RemoteControlled) => unreachable!(),
|
||||
None => {}
|
||||
}
|
||||
if with_remote_control {
|
||||
transport_accept_handles.push(
|
||||
remote_control::start_remote_control(
|
||||
config.chatgpt_base_url.clone(),
|
||||
config.codex_home.clone(),
|
||||
AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
false,
|
||||
config.cli_auth_credentials_store_mode,
|
||||
),
|
||||
transport_event_tx.clone(),
|
||||
transport_shutdown_token.clone(),
|
||||
connection_id_allocator,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
let shutdown_when_stdio_disconnects =
|
||||
matches!(local_transport, Some(AppServerTransport::Stdio));
|
||||
let graceful_signal_restart_enabled = !shutdown_when_stdio_disconnects;
|
||||
|
||||
if let Some(warning) = project_config_warning(&config) {
|
||||
config_warnings.push(warning);
|
||||
}
|
||||
@@ -619,10 +706,7 @@ pub async fn run_main_with_transport(
|
||||
let mut thread_created_rx = processor.thread_created_receiver();
|
||||
let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count();
|
||||
let mut connections = HashMap::<ConnectionId, ConnectionState>::new();
|
||||
let websocket_accept_shutdown = match &transport_runtime {
|
||||
TransportRuntime::WebSocket { shutdown_token, .. } => Some(shutdown_token.clone()),
|
||||
TransportRuntime::Stdio => None,
|
||||
};
|
||||
let transport_shutdown_token = transport_shutdown_token.clone();
|
||||
async move {
|
||||
let mut listen_for_threads = true;
|
||||
let mut shutdown_state = ShutdownState::default();
|
||||
@@ -635,9 +719,7 @@ pub async fn run_main_with_transport(
|
||||
shutdown_state.update(running_turn_count, connections.len()),
|
||||
ShutdownAction::Finish
|
||||
) {
|
||||
if let Some(shutdown_token) = &websocket_accept_shutdown {
|
||||
shutdown_token.cancel();
|
||||
}
|
||||
transport_shutdown_token.cancel();
|
||||
let _ = outbound_control_tx
|
||||
.send(OutboundControlEvent::DisconnectAll)
|
||||
.await;
|
||||
@@ -664,10 +746,16 @@ pub async fn run_main_with_transport(
|
||||
match event {
|
||||
TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
transport_kind,
|
||||
writer,
|
||||
allow_legacy_notifications,
|
||||
disconnect_sender,
|
||||
} => {
|
||||
info!(
|
||||
connection_id = %connection_id,
|
||||
transport = transport_kind.as_str(),
|
||||
"app-server connection opened"
|
||||
);
|
||||
let outbound_initialized = Arc::new(AtomicBool::new(false));
|
||||
let outbound_experimental_api_enabled =
|
||||
Arc::new(AtomicBool::new(false));
|
||||
@@ -695,13 +783,22 @@ pub async fn run_main_with_transport(
|
||||
connections.insert(
|
||||
connection_id,
|
||||
ConnectionState::new(
|
||||
transport_kind,
|
||||
outbound_initialized,
|
||||
outbound_experimental_api_enabled,
|
||||
outbound_opted_out_notification_methods,
|
||||
),
|
||||
);
|
||||
}
|
||||
TransportEvent::ConnectionClosed { connection_id } => {
|
||||
TransportEvent::ConnectionClosed {
|
||||
connection_id,
|
||||
transport_kind,
|
||||
} => {
|
||||
info!(
|
||||
connection_id = %connection_id,
|
||||
transport = transport_kind.as_str(),
|
||||
"app-server connection closed"
|
||||
);
|
||||
if connections.remove(&connection_id).is_none() {
|
||||
continue;
|
||||
}
|
||||
@@ -713,7 +810,8 @@ pub async fn run_main_with_transport(
|
||||
break;
|
||||
}
|
||||
processor.connection_closed(connection_id).await;
|
||||
if shutdown_when_no_connections && connections.is_empty() {
|
||||
if shutdown_when_stdio_disconnects && connection_id == ConnectionId(0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -729,7 +827,7 @@ pub async fn run_main_with_transport(
|
||||
.process_request(
|
||||
connection_id,
|
||||
request,
|
||||
transport,
|
||||
connection_state.transport_kind,
|
||||
&mut connection_state.session,
|
||||
)
|
||||
.await;
|
||||
@@ -833,16 +931,8 @@ pub async fn run_main_with_transport(
|
||||
let _ = processor_handle.await;
|
||||
let _ = outbound_handle.await;
|
||||
|
||||
if let TransportRuntime::WebSocket {
|
||||
accept_handle,
|
||||
shutdown_token,
|
||||
} = transport_runtime
|
||||
{
|
||||
shutdown_token.cancel();
|
||||
let _ = accept_handle.await;
|
||||
}
|
||||
|
||||
for handle in stdio_handles {
|
||||
transport_shutdown_token.cancel();
|
||||
for handle in transport_accept_handles {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use clap::Parser;
|
||||
use codex_app_server::AppServerTransport;
|
||||
use codex_app_server::run_main_with_transport;
|
||||
use codex_app_server::run_main_with_runtime;
|
||||
use codex_arg0::Arg0DispatchPaths;
|
||||
use codex_arg0::arg0_dispatch_or_else;
|
||||
use codex_core::config_loader::LoaderOverrides;
|
||||
@@ -21,6 +21,11 @@ struct AppServerArgs {
|
||||
default_value = AppServerTransport::DEFAULT_LISTEN_URL
|
||||
)]
|
||||
listen: AppServerTransport,
|
||||
|
||||
/// Also connect outbound to the ChatGPT remote control server derived from
|
||||
/// the configured `chatgpt_base_url`.
|
||||
#[arg(long = "with-remote-control", default_value_t = false)]
|
||||
with_remote_control: bool,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
@@ -31,14 +36,13 @@ fn main() -> anyhow::Result<()> {
|
||||
managed_config_path,
|
||||
..Default::default()
|
||||
};
|
||||
let transport = args.listen;
|
||||
|
||||
run_main_with_transport(
|
||||
run_main_with_runtime(
|
||||
arg0_paths,
|
||||
CliConfigOverrides::default(),
|
||||
loader_overrides,
|
||||
false,
|
||||
transport,
|
||||
Some(args.listen),
|
||||
args.with_remote_control,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
@@ -59,3 +63,34 @@ fn managed_config_path_from_debug_env() -> Option<PathBuf> {
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::AppServerArgs;
|
||||
use clap::Parser;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn app_server_args_default_to_stdio_without_remote_control() {
|
||||
let args = AppServerArgs::parse_from(["codex-app-server"]);
|
||||
assert_eq!(args.listen, super::AppServerTransport::Stdio);
|
||||
assert!(!args.with_remote_control);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_args_parse_with_remote_control_flag() {
|
||||
let args = AppServerArgs::parse_from([
|
||||
"codex-app-server",
|
||||
"--listen",
|
||||
"ws://127.0.0.1:8080",
|
||||
"--with-remote-control",
|
||||
]);
|
||||
assert_eq!(
|
||||
args.listen,
|
||||
super::AppServerTransport::WebSocket {
|
||||
bind_address: "127.0.0.1:8080".parse().expect("valid socket address"),
|
||||
}
|
||||
);
|
||||
assert!(args.with_remote_control);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ use app_test_support::create_mock_responses_server_repeating_assistant;
|
||||
use app_test_support::write_mock_responses_config_toml;
|
||||
use codex_app_server_protocol::ClientInfo;
|
||||
use codex_app_server_protocol::ClientRequest;
|
||||
use codex_app_server_protocol::ConfigReadParams;
|
||||
use codex_app_server_protocol::InitializeCapabilities;
|
||||
use codex_app_server_protocol::InitializeParams;
|
||||
use codex_app_server_protocol::InitializeResponse;
|
||||
@@ -164,6 +165,19 @@ impl TracingHarness {
|
||||
}
|
||||
|
||||
async fn request<T>(&mut self, request: ClientRequest, trace: Option<W3cTraceContext>) -> T
|
||||
where
|
||||
T: serde::de::DeserializeOwned,
|
||||
{
|
||||
self.request_with_transport(request, trace, AppServerTransport::Stdio)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn request_with_transport<T>(
|
||||
&mut self,
|
||||
request: ClientRequest,
|
||||
trace: Option<W3cTraceContext>,
|
||||
transport: AppServerTransport,
|
||||
) -> T
|
||||
where
|
||||
T: serde::de::DeserializeOwned,
|
||||
{
|
||||
@@ -175,12 +189,7 @@ impl TracingHarness {
|
||||
request.trace = trace;
|
||||
|
||||
self.processor
|
||||
.process_request(
|
||||
TEST_CONNECTION_ID,
|
||||
request,
|
||||
AppServerTransport::Stdio,
|
||||
&mut self.session,
|
||||
)
|
||||
.process_request(TEST_CONNECTION_ID, request, transport, &mut self.session)
|
||||
.await;
|
||||
read_response(&mut self.outgoing_rx, request_id).await
|
||||
}
|
||||
@@ -566,6 +575,73 @@ async fn thread_start_jsonrpc_span_exports_server_span_and_parents_children() ->
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn request_spans_record_per_connection_transport_kind() -> Result<()> {
|
||||
let _guard = tracing_test_guard().lock().await;
|
||||
let mut harness = TracingHarness::new().await?;
|
||||
|
||||
for (offset, transport, expected_transport) in [
|
||||
(0_i64, AppServerTransport::Stdio, "stdio"),
|
||||
(
|
||||
1_i64,
|
||||
AppServerTransport::WebSocket {
|
||||
bind_address: "127.0.0.1:8080".parse().expect("valid socket address"),
|
||||
},
|
||||
"websocket",
|
||||
),
|
||||
(
|
||||
2_i64,
|
||||
AppServerTransport::RemoteControlled,
|
||||
"remote_controlled",
|
||||
),
|
||||
] {
|
||||
harness.reset_tracing();
|
||||
let baseline_len = harness
|
||||
.tracing
|
||||
.exporter
|
||||
.get_finished_spans()
|
||||
.expect("span export")
|
||||
.len();
|
||||
|
||||
let _: serde_json::Value = harness
|
||||
.request_with_transport(
|
||||
ClientRequest::ConfigRead {
|
||||
request_id: RequestId::Integer(30_100 + offset),
|
||||
params: ConfigReadParams {
|
||||
include_layers: false,
|
||||
cwd: None,
|
||||
},
|
||||
},
|
||||
None,
|
||||
transport,
|
||||
)
|
||||
.await;
|
||||
|
||||
let spans = wait_for_new_exported_spans(harness.tracing, baseline_len, |spans| {
|
||||
spans.iter().any(|span| {
|
||||
span.span_kind == SpanKind::Server
|
||||
&& span_attr(span, "rpc.method") == Some("config/read")
|
||||
&& span_attr(span, "rpc.transport") == Some(expected_transport)
|
||||
})
|
||||
})
|
||||
.await;
|
||||
let server_span = spans
|
||||
.iter()
|
||||
.find(|span| {
|
||||
span.span_kind == SpanKind::Server
|
||||
&& span_attr(span, "rpc.method") == Some("config/read")
|
||||
})
|
||||
.expect("config/read server span should be exported");
|
||||
assert_eq!(
|
||||
span_attr(server_span, "rpc.transport"),
|
||||
Some(expected_transport)
|
||||
);
|
||||
}
|
||||
|
||||
harness.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn turn_start_jsonrpc_span_parents_core_turn_spans() -> Result<()> {
|
||||
let _guard = tracing_test_guard().lock().await;
|
||||
|
||||
255
codex-rs/app-server/src/remote_control/client_manager.rs
Normal file
255
codex-rs/app-server/src/remote_control/client_manager.rs
Normal file
@@ -0,0 +1,255 @@
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::transport::AppServerTransport;
|
||||
use crate::transport::CHANNEL_CAPACITY;
|
||||
use crate::transport::ConnectionIdAllocator;
|
||||
use crate::transport::TransportEvent;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::MissedTickBehavior;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::ClientEvent;
|
||||
use super::ClientId;
|
||||
use super::PongStatus;
|
||||
use super::ServerEvent;
|
||||
|
||||
const REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60);
|
||||
const REMOTE_CONTROL_IDLE_SWEEP_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
struct RemoteControlClientState {
|
||||
connection_id: ConnectionId,
|
||||
disconnect_token: CancellationToken,
|
||||
last_activity_at: Instant,
|
||||
}
|
||||
|
||||
pub(super) async fn run(
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
mut client_event_rx: mpsc::Receiver<ClientEvent>,
|
||||
server_event_tx: mpsc::Sender<ServerEvent>,
|
||||
writer_exited_tx: mpsc::Sender<ClientId>,
|
||||
mut writer_exited_rx: mpsc::Receiver<ClientId>,
|
||||
shutdown_token: CancellationToken,
|
||||
connection_id_allocator: ConnectionIdAllocator,
|
||||
) {
|
||||
let mut clients = HashMap::<ClientId, RemoteControlClientState>::new();
|
||||
let mut idle_sweep = tokio::time::interval(REMOTE_CONTROL_IDLE_SWEEP_INTERVAL);
|
||||
idle_sweep.set_missed_tick_behavior(MissedTickBehavior::Delay);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
_ = idle_sweep.tick() => {
|
||||
if !close_expired_remote_control_clients(&transport_event_tx, &mut clients).await {
|
||||
break;
|
||||
}
|
||||
}
|
||||
writer_exited = writer_exited_rx.recv() => {
|
||||
let Some(client_id) = writer_exited else {
|
||||
break;
|
||||
};
|
||||
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
|
||||
break;
|
||||
}
|
||||
}
|
||||
client_event = client_event_rx.recv() => {
|
||||
let Some(client_event) = client_event else {
|
||||
break;
|
||||
};
|
||||
match client_event {
|
||||
ClientEvent::ClientMessage { client_id, message } => {
|
||||
if let Some(connection_id) = clients.get_mut(&client_id).map(|client| {
|
||||
client.last_activity_at = Instant::now();
|
||||
client.connection_id
|
||||
}) {
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if !remote_control_message_starts_connection(&message) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let connection_id = connection_id_allocator.next_connection_id();
|
||||
let (writer_tx, writer_rx) =
|
||||
mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let disconnect_token = CancellationToken::new();
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
transport_kind: AppServerTransport::RemoteControlled,
|
||||
writer: writer_tx,
|
||||
allow_legacy_notifications: false,
|
||||
disconnect_sender: Some(disconnect_token.clone()),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::spawn(run_remote_control_client_outbound(
|
||||
client_id.clone(),
|
||||
writer_rx,
|
||||
server_event_tx.clone(),
|
||||
writer_exited_tx.clone(),
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
|
||||
clients.insert(
|
||||
client_id,
|
||||
RemoteControlClientState {
|
||||
connection_id,
|
||||
disconnect_token,
|
||||
last_activity_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
ClientEvent::Ping { client_id, .. } => {
|
||||
let status = match clients.get_mut(&client_id) {
|
||||
Some(client) => {
|
||||
client.last_activity_at = Instant::now();
|
||||
PongStatus::Active
|
||||
}
|
||||
None => PongStatus::Unknown,
|
||||
};
|
||||
|
||||
if server_event_tx
|
||||
.send(ServerEvent::Pong { client_id, status })
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
ClientEvent::ClientClosed { client_id } => {
|
||||
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id)
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(client_id) = clients.keys().next().cloned() {
|
||||
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_control_message_starts_connection(message: &JSONRPCMessage) -> bool {
|
||||
matches!(
|
||||
message,
|
||||
JSONRPCMessage::Request(JSONRPCRequest { method, .. }) if method == "initialize"
|
||||
)
|
||||
}
|
||||
|
||||
fn remote_control_client_is_alive(client: &RemoteControlClientState, now: Instant) -> bool {
|
||||
now.duration_since(client.last_activity_at) < REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT
|
||||
}
|
||||
|
||||
async fn close_expired_remote_control_clients(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
clients: &mut HashMap<ClientId, RemoteControlClientState>,
|
||||
) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
let expired_client_ids: Vec<ClientId> = clients
|
||||
.iter()
|
||||
.filter_map(|(client_id, client)| {
|
||||
(!remote_control_client_is_alive(client, now)).then_some(client_id.clone())
|
||||
})
|
||||
.collect();
|
||||
|
||||
for client_id in expired_client_ids {
|
||||
if !close_remote_control_client(transport_event_tx, clients, &client_id).await {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
async fn close_remote_control_client(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
clients: &mut HashMap<ClientId, RemoteControlClientState>,
|
||||
client_id: &ClientId,
|
||||
) -> bool {
|
||||
let Some(client) = clients.remove(client_id) else {
|
||||
return true;
|
||||
};
|
||||
client.disconnect_token.cancel();
|
||||
transport_event_tx
|
||||
.send(TransportEvent::ConnectionClosed {
|
||||
connection_id: client.connection_id,
|
||||
transport_kind: AppServerTransport::RemoteControlled,
|
||||
})
|
||||
.await
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
async fn run_remote_control_client_outbound(
|
||||
client_id: ClientId,
|
||||
mut writer_rx: mpsc::Receiver<OutgoingMessage>,
|
||||
server_event_tx: mpsc::Sender<ServerEvent>,
|
||||
writer_exited_tx: mpsc::Sender<ClientId>,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
outgoing_message = writer_rx.recv() => {
|
||||
let Some(outgoing_message) = outgoing_message else {
|
||||
break;
|
||||
};
|
||||
if server_event_tx
|
||||
.send(ServerEvent::ServerMessage {
|
||||
client_id: client_id.clone(),
|
||||
message: Box::new(outgoing_message),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = writer_exited_tx.send(client_id).await;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "client_manager_tests.rs"]
|
||||
mod client_manager_tests;
|
||||
321
codex-rs/app-server/src/remote_control/client_manager_tests.rs
Normal file
321
codex-rs/app-server/src/remote_control/client_manager_tests.rs
Normal file
@@ -0,0 +1,321 @@
|
||||
use super::super::ClientActivityState;
|
||||
use super::super::ClientEvent;
|
||||
use super::super::ClientId;
|
||||
use super::super::start_remote_control;
|
||||
use super::super::test_support::accept_http_request;
|
||||
use super::super::test_support::accept_remote_control_backend_connection;
|
||||
use super::super::test_support::read_server_event;
|
||||
use super::super::test_support::remote_control_auth_manager;
|
||||
use super::super::test_support::respond_with_json;
|
||||
use super::super::test_support::send_client_event;
|
||||
use super::*;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::OutgoingNotification;
|
||||
use crate::transport::AppServerTransport;
|
||||
use crate::transport::CHANNEL_CAPACITY;
|
||||
use crate::transport::ConnectionIdAllocator;
|
||||
use crate::transport::TransportEvent;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use tempfile::TempDir;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[tokio::test]
|
||||
async fn close_expired_remote_control_clients_closes_only_stale_connections() {
|
||||
let (transport_event_tx, mut transport_event_rx) =
|
||||
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
|
||||
let stale_disconnect_token = CancellationToken::new();
|
||||
let fresh_disconnect_token = CancellationToken::new();
|
||||
let stale_client_id = ClientId("stale-client".to_string());
|
||||
let fresh_client_id = ClientId("fresh-client".to_string());
|
||||
let now = tokio::time::Instant::now();
|
||||
let mut clients = HashMap::from([
|
||||
(
|
||||
stale_client_id.clone(),
|
||||
RemoteControlClientState {
|
||||
connection_id: ConnectionId(11),
|
||||
disconnect_token: stale_disconnect_token.clone(),
|
||||
last_activity_at: now - REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT,
|
||||
},
|
||||
),
|
||||
(
|
||||
fresh_client_id.clone(),
|
||||
RemoteControlClientState {
|
||||
connection_id: ConnectionId(12),
|
||||
disconnect_token: fresh_disconnect_token.clone(),
|
||||
last_activity_at: now,
|
||||
},
|
||||
),
|
||||
]);
|
||||
|
||||
assert!(close_expired_remote_control_clients(&transport_event_tx, &mut clients).await);
|
||||
assert!(stale_disconnect_token.is_cancelled());
|
||||
assert!(!fresh_disconnect_token.is_cancelled());
|
||||
assert!(!clients.contains_key(&stale_client_id));
|
||||
assert!(clients.contains_key(&fresh_client_id));
|
||||
|
||||
match timeout(Duration::from_secs(5), transport_event_rx.recv())
|
||||
.await
|
||||
.expect("stale client close should arrive in time")
|
||||
.expect("stale client close should exist")
|
||||
{
|
||||
TransportEvent::ConnectionClosed {
|
||||
connection_id,
|
||||
transport_kind,
|
||||
} => {
|
||||
assert_eq!(connection_id, ConnectionId(11));
|
||||
assert_eq!(transport_kind, AppServerTransport::RemoteControlled);
|
||||
}
|
||||
other => panic!("expected stale client close event, got {other:?}"),
|
||||
}
|
||||
assert!(
|
||||
timeout(Duration::from_millis(100), transport_event_rx.recv())
|
||||
.await
|
||||
.is_err(),
|
||||
"fresh clients should remain connected during stale sweep"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_transport_manages_virtual_clients_and_routes_messages() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = format!(
|
||||
"http://{}/api/codex",
|
||||
listener
|
||||
.local_addr()
|
||||
.expect("listener should have a local addr")
|
||||
);
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
let (transport_event_tx, mut transport_event_rx) =
|
||||
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let remote_handle = start_remote_control(
|
||||
remote_control_url,
|
||||
codex_home.path().to_path_buf(),
|
||||
remote_control_auth_manager(),
|
||||
transport_event_tx,
|
||||
shutdown_token.clone(),
|
||||
ConnectionIdAllocator::default(),
|
||||
)
|
||||
.await
|
||||
.expect("remote control should start");
|
||||
let enroll_request = accept_http_request(&listener).await;
|
||||
assert_eq!(
|
||||
enroll_request.request_line,
|
||||
"POST /api/codex/remote/control/server/enroll HTTP/1.1"
|
||||
);
|
||||
respond_with_json(
|
||||
enroll_request.stream,
|
||||
json!({ "server_id": "srv_e_client_manager" }),
|
||||
)
|
||||
.await;
|
||||
let (_handshake_request, mut websocket) =
|
||||
accept_remote_control_backend_connection(&listener, None).await;
|
||||
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
send_client_event(
|
||||
&mut websocket,
|
||||
ClientEvent::Ping {
|
||||
client_id: client_id.clone(),
|
||||
state: Some(ClientActivityState::Foreground),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
assert_eq!(
|
||||
read_server_event(&mut websocket).await,
|
||||
json!({
|
||||
"type": "pong",
|
||||
"client_id": "client-1",
|
||||
"status": "unknown",
|
||||
})
|
||||
);
|
||||
|
||||
send_client_event(
|
||||
&mut websocket,
|
||||
ClientEvent::ClientMessage {
|
||||
client_id: client_id.clone(),
|
||||
message: JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
}),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
assert!(
|
||||
timeout(Duration::from_millis(100), transport_event_rx.recv())
|
||||
.await
|
||||
.is_err(),
|
||||
"non-initialize client messages should be ignored before connection creation"
|
||||
);
|
||||
|
||||
let initialize_message = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(1),
|
||||
method: "initialize".to_string(),
|
||||
params: Some(json!({
|
||||
"clientInfo": {
|
||||
"name": "remote-test-client",
|
||||
"version": "0.1.0"
|
||||
}
|
||||
})),
|
||||
trace: None,
|
||||
});
|
||||
send_client_event(
|
||||
&mut websocket,
|
||||
ClientEvent::ClientMessage {
|
||||
client_id: client_id.clone(),
|
||||
message: initialize_message.clone(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let (connection_id, writer) = match timeout(Duration::from_secs(5), transport_event_rx.recv())
|
||||
.await
|
||||
.expect("connection open should arrive in time")
|
||||
.expect("connection open should exist")
|
||||
{
|
||||
TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
transport_kind,
|
||||
writer,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(transport_kind, AppServerTransport::RemoteControlled);
|
||||
(connection_id, writer)
|
||||
}
|
||||
other => panic!("expected connection open event, got {other:?}"),
|
||||
};
|
||||
|
||||
match timeout(Duration::from_secs(5), transport_event_rx.recv())
|
||||
.await
|
||||
.expect("initialize message should arrive in time")
|
||||
.expect("initialize message should exist")
|
||||
{
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: incoming_connection_id,
|
||||
message,
|
||||
} => {
|
||||
assert_eq!(incoming_connection_id, connection_id);
|
||||
assert_eq!(message, initialize_message);
|
||||
}
|
||||
other => panic!("expected initialize incoming message, got {other:?}"),
|
||||
}
|
||||
|
||||
let followup_message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
send_client_event(
|
||||
&mut websocket,
|
||||
ClientEvent::ClientMessage {
|
||||
client_id: client_id.clone(),
|
||||
message: followup_message.clone(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
match timeout(Duration::from_secs(5), transport_event_rx.recv())
|
||||
.await
|
||||
.expect("followup message should arrive in time")
|
||||
.expect("followup message should exist")
|
||||
{
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: incoming_connection_id,
|
||||
message,
|
||||
} => {
|
||||
assert_eq!(incoming_connection_id, connection_id);
|
||||
assert_eq!(message, followup_message);
|
||||
}
|
||||
other => panic!("expected followup incoming message, got {other:?}"),
|
||||
}
|
||||
|
||||
send_client_event(
|
||||
&mut websocket,
|
||||
ClientEvent::Ping {
|
||||
client_id: client_id.clone(),
|
||||
state: Some(ClientActivityState::Foreground),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
assert_eq!(
|
||||
read_server_event(&mut websocket).await,
|
||||
json!({
|
||||
"type": "pong",
|
||||
"client_id": "client-1",
|
||||
"status": "active",
|
||||
})
|
||||
);
|
||||
|
||||
writer
|
||||
.send(OutgoingMessage::Notification(OutgoingNotification {
|
||||
method: "codex/event/test".to_string(),
|
||||
params: Some(json!({ "ok": true })),
|
||||
}))
|
||||
.await
|
||||
.expect("remote writer should accept outgoing message");
|
||||
assert_eq!(
|
||||
read_server_event(&mut websocket).await,
|
||||
json!({
|
||||
"type": "server_message",
|
||||
"client_id": "client-1",
|
||||
"message": {
|
||||
"method": "codex/event/test",
|
||||
"params": {
|
||||
"ok": true,
|
||||
}
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
send_client_event(
|
||||
&mut websocket,
|
||||
ClientEvent::ClientClosed {
|
||||
client_id: client_id.clone(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
match timeout(Duration::from_secs(5), transport_event_rx.recv())
|
||||
.await
|
||||
.expect("connection close should arrive in time")
|
||||
.expect("connection close should exist")
|
||||
{
|
||||
TransportEvent::ConnectionClosed {
|
||||
connection_id: closed_connection_id,
|
||||
transport_kind,
|
||||
} => {
|
||||
assert_eq!(closed_connection_id, connection_id);
|
||||
assert_eq!(transport_kind, AppServerTransport::RemoteControlled);
|
||||
}
|
||||
other => panic!("expected connection close event, got {other:?}"),
|
||||
}
|
||||
|
||||
send_client_event(
|
||||
&mut websocket,
|
||||
ClientEvent::Ping {
|
||||
client_id,
|
||||
state: Some(ClientActivityState::Foreground),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
assert_eq!(
|
||||
read_server_event(&mut websocket).await,
|
||||
json!({
|
||||
"type": "pong",
|
||||
"client_id": "client-1",
|
||||
"status": "unknown",
|
||||
})
|
||||
);
|
||||
|
||||
shutdown_token.cancel();
|
||||
let _ = remote_handle.await;
|
||||
}
|
||||
341
codex-rs/app-server/src/remote_control/connection_manager.rs
Normal file
341
codex-rs/app-server/src/remote_control/connection_manager.rs
Normal file
@@ -0,0 +1,341 @@
|
||||
use crate::remote_control::entrollment_manager::EnrollmentManager;
|
||||
use crate::remote_control::entrollment_manager::preview_remote_control_response_body;
|
||||
use crate::transport::colorize;
|
||||
use codex_core::AuthManager;
|
||||
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use owo_colors::Style;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Duration;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tokio_tungstenite::tungstenite::http::HeaderValue;
|
||||
use tokio_tungstenite::tungstenite::{self};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use super::ClientEvent;
|
||||
use super::REMOTE_CONTROL_ACCOUNT_ID_HEADER;
|
||||
use super::REMOTE_CONTROL_REQUEST_ID_HEADER;
|
||||
use super::RemoteControlConnectionAuth;
|
||||
use super::RemoteControlEnrollment;
|
||||
use super::ServerEvent;
|
||||
use super::load_remote_control_auth;
|
||||
|
||||
const REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF: Duration = Duration::from_secs(1);
|
||||
const REMOTE_CONTROL_RECONNECT_MAX_BACKOFF: Duration = Duration::from_secs(30);
|
||||
const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2";
|
||||
|
||||
const REMOTE_CONTROL_OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id";
|
||||
const REMOTE_CONTROL_CF_RAY_HEADER: &str = "cf-ray";
|
||||
|
||||
struct RemoteControlWebsocketConnection {
|
||||
websocket_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
request_id: Option<String>,
|
||||
}
|
||||
|
||||
#[allow(clippy::print_stderr)]
|
||||
fn print_remote_control_connection_banner(
|
||||
remote_control_target: &super::RemoteControlTarget,
|
||||
reconnect_attempt: u64,
|
||||
reconnect_backoff: Duration,
|
||||
reconnect_reason: Option<&str>,
|
||||
) {
|
||||
let title = colorize("app-server remote-control", Style::new().bold().yellow());
|
||||
let control_server_label = colorize("control server:", Style::new().dimmed());
|
||||
let control_server_url = remote_control_target.websocket_url.as_str();
|
||||
let control_server_url = colorize(control_server_url, Style::new().green());
|
||||
eprintln!("{title}");
|
||||
eprintln!(" {control_server_label} {control_server_url}");
|
||||
|
||||
if reconnect_attempt > 0 {
|
||||
let attempt_label = colorize("attempt:", Style::new().dimmed());
|
||||
let after_label = colorize("after:", Style::new().dimmed());
|
||||
eprintln!(" {attempt_label} {reconnect_attempt}");
|
||||
eprintln!(" {after_label} {reconnect_backoff:?}");
|
||||
}
|
||||
if let Some(reason) = reconnect_reason {
|
||||
let reason_label = colorize("reason:", Style::new().dimmed());
|
||||
eprintln!(" {reason_label} {reason}");
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::print_stderr)]
|
||||
pub(super) async fn run(
|
||||
auth_manager: std::sync::Arc<AuthManager>,
|
||||
remote_control_target: super::RemoteControlTarget,
|
||||
mut enrollment_manager: EnrollmentManager,
|
||||
client_event_tx: mpsc::Sender<ClientEvent>,
|
||||
mut server_event_rx: mpsc::Receiver<ServerEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
) {
|
||||
let mut reconnect_backoff = REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF;
|
||||
let mut reconnect_attempt = 0_u64;
|
||||
let mut reconnect_reason = None::<String>;
|
||||
let mut wait_before_connect = false;
|
||||
let mut pending_server_event = None::<ServerEvent>;
|
||||
|
||||
loop {
|
||||
let connect_delay = if wait_before_connect {
|
||||
tokio::select! {
|
||||
_ = shutdown_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(reconnect_backoff) => {
|
||||
reconnect_attempt = reconnect_attempt.saturating_add(1);
|
||||
}
|
||||
}
|
||||
reconnect_backoff
|
||||
} else {
|
||||
wait_before_connect = true;
|
||||
Duration::ZERO
|
||||
};
|
||||
|
||||
print_remote_control_connection_banner(
|
||||
&remote_control_target,
|
||||
reconnect_attempt,
|
||||
connect_delay,
|
||||
reconnect_reason.as_deref(),
|
||||
);
|
||||
|
||||
let websocket_connection = tokio::select! {
|
||||
_ = shutdown_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
connect_result = connect_remote_control_websocket(
|
||||
auth_manager.as_ref(),
|
||||
&remote_control_target,
|
||||
&mut enrollment_manager,
|
||||
) => {
|
||||
match connect_result {
|
||||
Ok(websocket_connection) => {
|
||||
reconnect_backoff = REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF;
|
||||
reconnect_attempt = 0;
|
||||
info!(
|
||||
"connected to app-server remote control websocket: {}",
|
||||
remote_control_target.websocket_url
|
||||
);
|
||||
websocket_connection
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("{err}");
|
||||
reconnect_reason = Some(err.to_string());
|
||||
if connect_delay != Duration::ZERO {
|
||||
reconnect_backoff = reconnect_backoff
|
||||
.saturating_mul(2)
|
||||
.min(REMOTE_CONTROL_RECONNECT_MAX_BACKOFF);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let (mut websocket_writer, mut websocket_reader) =
|
||||
websocket_connection.websocket_stream.split();
|
||||
loop {
|
||||
if let Some(server_event) = pending_server_event.take() {
|
||||
let payload = match serde_json::to_string(&server_event) {
|
||||
Ok(payload) => payload,
|
||||
Err(err) => {
|
||||
error!("failed to serialize remote-control server event: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if let Err(err) = websocket_writer
|
||||
.send(TungsteniteMessage::Text(payload.into()))
|
||||
.await
|
||||
{
|
||||
warn!("remote control websocket send failed: {err}");
|
||||
reconnect_reason = Some(format!("send failed: {err}"));
|
||||
pending_server_event = Some(server_event);
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
_ = shutdown_token.cancelled() => {
|
||||
return;
|
||||
}
|
||||
incoming_message = websocket_reader.next() => {
|
||||
match incoming_message {
|
||||
Some(Ok(TungsteniteMessage::Text(text))) => {
|
||||
match serde_json::from_str::<ClientEvent>(&text) {
|
||||
Ok(client_event) => {
|
||||
if client_event_tx.send(client_event).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("failed to deserialize remote-control client event");
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(TungsteniteMessage::Ping(payload))) => {
|
||||
if let Err(err) = websocket_writer
|
||||
.send(TungsteniteMessage::Pong(payload))
|
||||
.await
|
||||
{
|
||||
warn!("remote control websocket pong failed: {err}");
|
||||
reconnect_reason = Some(format!("pong failed: {err}"));
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(TungsteniteMessage::Pong(_))) => {}
|
||||
Some(Ok(TungsteniteMessage::Binary(_))) => {
|
||||
warn!("dropping unsupported binary remote-control websocket message");
|
||||
}
|
||||
Some(Ok(TungsteniteMessage::Frame(_))) => {}
|
||||
Some(Ok(TungsteniteMessage::Close(_))) | None => {
|
||||
warn!("remote control websocket disconnected");
|
||||
reconnect_reason = Some("server closed the websocket".to_string());
|
||||
break;
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
warn!("remote control websocket receive error: {err}");
|
||||
reconnect_reason = Some(format!("receive error: {err}"));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
server_event = server_event_rx.recv() => {
|
||||
let Some(server_event) = server_event else {
|
||||
return;
|
||||
};
|
||||
pending_server_event = Some(server_event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn set_remote_control_header(
|
||||
headers: &mut tungstenite::http::HeaderMap,
|
||||
name: &'static str,
|
||||
value: &str,
|
||||
) -> IoResult<()> {
|
||||
let header_value = HeaderValue::from_str(value).map_err(|err| {
|
||||
std::io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("invalid remote control header `{name}`: {err}"),
|
||||
)
|
||||
})?;
|
||||
headers.insert(name, header_value);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_remote_control_websocket_request(
|
||||
auth: &RemoteControlConnectionAuth,
|
||||
websocket_url: &str,
|
||||
enrollment: &RemoteControlEnrollment,
|
||||
) -> IoResult<tungstenite::http::Request<()>> {
|
||||
let mut request = websocket_url.into_client_request().map_err(|err| {
|
||||
std::io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("invalid remote control websocket URL `{websocket_url}`: {err}"),
|
||||
)
|
||||
})?;
|
||||
let headers = request.headers_mut();
|
||||
set_remote_control_header(headers, "x-codex-server-id", &enrollment.server_id)?;
|
||||
set_remote_control_header(headers, "x-codex-name", &enrollment.server_name)?;
|
||||
set_remote_control_header(
|
||||
headers,
|
||||
"x-codex-protocol-version",
|
||||
REMOTE_CONTROL_PROTOCOL_VERSION,
|
||||
)?;
|
||||
set_remote_control_header(
|
||||
headers,
|
||||
"authorization",
|
||||
&format!("Bearer {}", auth.bearer_token),
|
||||
)?;
|
||||
if let Some(account_id) = auth.account_id.as_deref() {
|
||||
set_remote_control_header(headers, REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id)?;
|
||||
}
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
async fn connect_remote_control_websocket(
|
||||
auth_manager: &AuthManager,
|
||||
remote_control_target: &super::RemoteControlTarget,
|
||||
enrollment_manager: &mut EnrollmentManager,
|
||||
) -> IoResult<RemoteControlWebsocketConnection> {
|
||||
ensure_rustls_crypto_provider();
|
||||
let websocket_url = remote_control_target.websocket_url.clone();
|
||||
let auth = load_remote_control_auth(auth_manager).await?;
|
||||
let enrollment = enrollment_manager.enroll(&auth).await?;
|
||||
let request = build_remote_control_websocket_request(&auth, &websocket_url, &enrollment)?;
|
||||
|
||||
let (websocket_stream, response) = match connect_async(request).await {
|
||||
Ok(connection) => connection,
|
||||
Err(err) => {
|
||||
return Err(std::io::Error::other(
|
||||
format_remote_control_websocket_connect_error(&websocket_url, &err),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let request_id = remote_control_request_id(response.headers());
|
||||
|
||||
Ok(RemoteControlWebsocketConnection {
|
||||
websocket_stream,
|
||||
request_id,
|
||||
})
|
||||
}
|
||||
|
||||
fn remote_control_header_value(
|
||||
headers: &tungstenite::http::HeaderMap,
|
||||
header_name: &str,
|
||||
) -> Option<String> {
|
||||
headers
|
||||
.get(header_name)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(str::to_owned)
|
||||
}
|
||||
|
||||
fn remote_control_request_id(headers: &tungstenite::http::HeaderMap) -> Option<String> {
|
||||
remote_control_header_value(headers, REMOTE_CONTROL_REQUEST_ID_HEADER)
|
||||
.or_else(|| remote_control_header_value(headers, REMOTE_CONTROL_OAI_REQUEST_ID_HEADER))
|
||||
}
|
||||
|
||||
fn format_remote_control_websocket_connect_error(
|
||||
websocket_url: &str,
|
||||
err: &tungstenite::Error,
|
||||
) -> String {
|
||||
let mut message =
|
||||
format!("failed to connect app-server remote control websocket `{websocket_url}`: {err}");
|
||||
let tungstenite::Error::Http(response) = err else {
|
||||
return message;
|
||||
};
|
||||
|
||||
if let Some(request_id) = remote_control_request_id(response.headers()) {
|
||||
message.push_str(&format!(", request id: {request_id}"));
|
||||
}
|
||||
if let Some(cf_ray) =
|
||||
remote_control_header_value(response.headers(), REMOTE_CONTROL_CF_RAY_HEADER)
|
||||
{
|
||||
message.push_str(&format!(", cf-ray: {cf_ray}"));
|
||||
}
|
||||
if let Some(body) = response.body().as_ref()
|
||||
&& !body.is_empty()
|
||||
{
|
||||
let body_preview = preview_remote_control_response_body(body);
|
||||
message.push_str(&format!(", body: {body_preview}"));
|
||||
}
|
||||
|
||||
message
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "connection_manager_tests.rs"]
|
||||
mod connection_manager_tests;
|
||||
@@ -0,0 +1,285 @@
|
||||
use super::super::ClientEvent;
|
||||
use super::super::ClientId;
|
||||
use super::super::REMOTE_CONTROL_REQUEST_ID_HEADER;
|
||||
use super::super::entrollment_manager::EnrollmentManager;
|
||||
use super::super::normalize_remote_control_url;
|
||||
use super::super::start_remote_control;
|
||||
use super::super::test_support::accept_http_request;
|
||||
use super::super::test_support::accept_remote_control_backend_connection;
|
||||
use super::super::test_support::read_server_event;
|
||||
use super::super::test_support::remote_control_auth_manager;
|
||||
use super::super::test_support::respond_with_json;
|
||||
use super::super::test_support::respond_with_status_and_headers;
|
||||
use super::super::test_support::send_client_event;
|
||||
use super::*;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::OutgoingNotification;
|
||||
use crate::transport::AppServerTransport;
|
||||
use crate::transport::CHANNEL_CAPACITY;
|
||||
use crate::transport::ConnectionIdAllocator;
|
||||
use crate::transport::TransportEvent;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_remote_control_websocket_captures_handshake_request_id() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = format!(
|
||||
"http://{}/api/codex",
|
||||
listener
|
||||
.local_addr()
|
||||
.expect("listener should have a local addr")
|
||||
);
|
||||
let remote_control_target =
|
||||
normalize_remote_control_url(&remote_control_url).expect("target should parse");
|
||||
let accept_task = tokio::spawn(async move {
|
||||
let enroll_request = accept_http_request(&listener).await;
|
||||
assert_eq!(
|
||||
enroll_request.request_line,
|
||||
"POST /api/codex/remote/control/server/enroll HTTP/1.1"
|
||||
);
|
||||
assert_eq!(
|
||||
enroll_request.headers.get("authorization"),
|
||||
Some(&"Bearer Access Token".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
enroll_request.headers.get("chatgpt-account-id"),
|
||||
Some(&"account_id".to_string())
|
||||
);
|
||||
let enroll_body = serde_json::from_str::<serde_json::Value>(&enroll_request.body)
|
||||
.expect("enroll body should deserialize");
|
||||
assert_eq!(enroll_body["os"], json!(std::env::consts::OS));
|
||||
assert_eq!(enroll_body["arch"], json!(std::env::consts::ARCH));
|
||||
assert_eq!(
|
||||
enroll_body["app_server_version"],
|
||||
json!(env!("CARGO_PKG_VERSION"))
|
||||
);
|
||||
assert!(enroll_body["name"].is_string());
|
||||
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await;
|
||||
accept_remote_control_backend_connection(&listener, Some("req-control-123")).await
|
||||
});
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
let auth_manager = remote_control_auth_manager();
|
||||
let mut enrollment_manager = EnrollmentManager::new(
|
||||
remote_control_target.clone(),
|
||||
codex_home.path().to_path_buf(),
|
||||
);
|
||||
|
||||
let connection = connect_remote_control_websocket(
|
||||
auth_manager.as_ref(),
|
||||
&remote_control_target,
|
||||
&mut enrollment_manager,
|
||||
)
|
||||
.await
|
||||
.expect("websocket connection should succeed");
|
||||
|
||||
assert_eq!(connection.request_id.as_deref(), Some("req-control-123"));
|
||||
|
||||
let (_request, server_websocket) = accept_task.await.expect("accept task should succeed");
|
||||
drop(server_websocket);
|
||||
drop(connection.websocket_stream);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_remote_control_websocket_includes_http_error_details() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = format!(
|
||||
"http://{}/api/codex",
|
||||
listener
|
||||
.local_addr()
|
||||
.expect("listener should have a local addr")
|
||||
);
|
||||
let remote_control_target =
|
||||
normalize_remote_control_url(&remote_control_url).expect("target should parse");
|
||||
let websocket_url = remote_control_target.websocket_url.clone();
|
||||
let expected_error = format!(
|
||||
"failed to connect app-server remote control websocket `{websocket_url}`: HTTP error: 503 Service Unavailable, request id: req-503, cf-ray: ray-503, body: upstream unavailable"
|
||||
);
|
||||
let server_task = tokio::spawn(async move {
|
||||
let enroll_request = accept_http_request(&listener).await;
|
||||
assert_eq!(
|
||||
enroll_request.request_line,
|
||||
"POST /api/codex/remote/control/server/enroll HTTP/1.1"
|
||||
);
|
||||
assert_eq!(
|
||||
enroll_request.headers.get("authorization"),
|
||||
Some(&"Bearer Access Token".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
enroll_request.headers.get("chatgpt-account-id"),
|
||||
Some(&"account_id".to_string())
|
||||
);
|
||||
let enroll_body = serde_json::from_str::<serde_json::Value>(&enroll_request.body)
|
||||
.expect("enroll body should deserialize");
|
||||
assert_eq!(enroll_body["os"], json!(std::env::consts::OS));
|
||||
assert_eq!(enroll_body["arch"], json!(std::env::consts::ARCH));
|
||||
assert_eq!(
|
||||
enroll_body["app_server_version"],
|
||||
json!(env!("CARGO_PKG_VERSION"))
|
||||
);
|
||||
assert!(enroll_body["name"].is_string());
|
||||
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await;
|
||||
|
||||
let request = accept_http_request(&listener).await;
|
||||
assert_eq!(
|
||||
request.request_line,
|
||||
"GET /api/codex/remote/control/server HTTP/1.1"
|
||||
);
|
||||
respond_with_status_and_headers(
|
||||
request.stream,
|
||||
"503 Service Unavailable",
|
||||
&[
|
||||
(REMOTE_CONTROL_REQUEST_ID_HEADER, "req-503"),
|
||||
(REMOTE_CONTROL_CF_RAY_HEADER, "ray-503"),
|
||||
],
|
||||
"upstream unavailable",
|
||||
)
|
||||
.await;
|
||||
});
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
let auth_manager = remote_control_auth_manager();
|
||||
let mut enrollment_manager = EnrollmentManager::new(
|
||||
remote_control_target.clone(),
|
||||
codex_home.path().to_path_buf(),
|
||||
);
|
||||
|
||||
let err = match connect_remote_control_websocket(
|
||||
auth_manager.as_ref(),
|
||||
&remote_control_target,
|
||||
&mut enrollment_manager,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => panic!("http error response should fail the websocket connect"),
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
server_task.await.expect("server task should succeed");
|
||||
assert_eq!(err.to_string(), expected_error);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_transport_reconnects_and_keeps_virtual_client_writer_alive() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = format!(
|
||||
"http://{}/api/codex",
|
||||
listener
|
||||
.local_addr()
|
||||
.expect("listener should have a local addr")
|
||||
);
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
let (transport_event_tx, mut transport_event_rx) =
|
||||
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let remote_handle = start_remote_control(
|
||||
remote_control_url,
|
||||
codex_home.path().to_path_buf(),
|
||||
remote_control_auth_manager(),
|
||||
transport_event_tx,
|
||||
shutdown_token.clone(),
|
||||
ConnectionIdAllocator::default(),
|
||||
)
|
||||
.await
|
||||
.expect("remote control should start");
|
||||
|
||||
let enroll_request = accept_http_request(&listener).await;
|
||||
assert_eq!(
|
||||
enroll_request.request_line,
|
||||
"POST /api/codex/remote/control/server/enroll HTTP/1.1"
|
||||
);
|
||||
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await;
|
||||
|
||||
let (_first_request, mut first_websocket) =
|
||||
accept_remote_control_backend_connection(&listener, None).await;
|
||||
let client_id = ClientId("client-2".to_string());
|
||||
let initialize_message = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(2),
|
||||
method: "initialize".to_string(),
|
||||
params: Some(json!({
|
||||
"clientInfo": {
|
||||
"name": "remote-test-client",
|
||||
"version": "0.1.0"
|
||||
}
|
||||
})),
|
||||
trace: None,
|
||||
});
|
||||
send_client_event(
|
||||
&mut first_websocket,
|
||||
ClientEvent::ClientMessage {
|
||||
client_id: client_id.clone(),
|
||||
message: initialize_message.clone(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let writer = match timeout(Duration::from_secs(5), transport_event_rx.recv())
|
||||
.await
|
||||
.expect("connection open should arrive in time")
|
||||
.expect("connection open should exist")
|
||||
{
|
||||
TransportEvent::ConnectionOpened {
|
||||
writer,
|
||||
transport_kind,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(transport_kind, AppServerTransport::RemoteControlled);
|
||||
writer
|
||||
}
|
||||
other => panic!("expected connection open before reconnect, got {other:?}"),
|
||||
};
|
||||
match timeout(Duration::from_secs(5), transport_event_rx.recv())
|
||||
.await
|
||||
.expect("initialize message should arrive in time")
|
||||
.expect("initialize message should exist")
|
||||
{
|
||||
TransportEvent::IncomingMessage { message, .. } => {
|
||||
assert_eq!(message, initialize_message);
|
||||
}
|
||||
other => panic!("expected initialize incoming message, got {other:?}"),
|
||||
}
|
||||
first_websocket
|
||||
.close(None)
|
||||
.await
|
||||
.expect("first websocket should close");
|
||||
drop(first_websocket);
|
||||
|
||||
let (_second_request, mut second_websocket) =
|
||||
accept_remote_control_backend_connection(&listener, None).await;
|
||||
writer
|
||||
.send(OutgoingMessage::Notification(OutgoingNotification {
|
||||
method: "codex/event/reconnected".to_string(),
|
||||
params: Some(json!({ "replayed": true })),
|
||||
}))
|
||||
.await
|
||||
.expect("outgoing message should send after reconnect");
|
||||
assert_eq!(
|
||||
read_server_event(&mut second_websocket).await,
|
||||
json!({
|
||||
"type": "server_message",
|
||||
"client_id": client_id.0,
|
||||
"message": {
|
||||
"method": "codex/event/reconnected",
|
||||
"params": {
|
||||
"replayed": true,
|
||||
}
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
shutdown_token.cancel();
|
||||
let _ = remote_handle.await;
|
||||
}
|
||||
312
codex-rs/app-server/src/remote_control/enrollment_tests.rs
Normal file
312
codex-rs/app-server/src/remote_control/enrollment_tests.rs
Normal file
@@ -0,0 +1,312 @@
|
||||
use super::super::RemoteControlConnectionAuth;
|
||||
use super::super::RemoteControlEnrollment;
|
||||
use super::super::RemoteControlTarget;
|
||||
use super::super::load_remote_control_auth;
|
||||
use super::super::normalize_remote_control_url;
|
||||
use super::super::test_support::accept_http_request;
|
||||
use super::super::test_support::respond_with_json;
|
||||
use super::*;
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::test_support::auth_manager_from_auth;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
#[tokio::test]
|
||||
async fn validate_remote_control_auth_rejects_api_key_auth() {
|
||||
let auth_manager = auth_manager_from_auth(CodexAuth::from_api_key("sk-test"));
|
||||
|
||||
let err = load_remote_control_auth(auth_manager.as_ref())
|
||||
.await
|
||||
.expect_err("API key auth should be rejected");
|
||||
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"remote control requires ChatGPT authentication; API key auth is not supported"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_remote_control_url_handles_supported_and_unsupported_inputs() {
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("http://example.com/backend-api/wham")
|
||||
.expect("valid http prefix"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "ws://example.com/backend-api/wham/remote/control/server".to_string(),
|
||||
enroll_url: "http://example.com/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("https://example.com/backend-api/wham/remote/control/server")
|
||||
.expect("valid https full path"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "wss://example.com/backend-api/wham/remote/control/server".to_string(),
|
||||
enroll_url: "https://example.com/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("http://example.com/legacy/server")
|
||||
.expect("valid legacy http url"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "ws://example.com/legacy/server".to_string(),
|
||||
enroll_url: "http://example.com/legacy/server/enroll".to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("https://chatgpt.com/backend-api")
|
||||
.expect("chatgpt backend-api base should target the public wham path"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "wss://chatgpt.com/backend-api/wham/remote/control/server".to_string(),
|
||||
enroll_url: "https://chatgpt.com/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("https://chat.openai.com")
|
||||
.expect("chat.openai.com root should normalize"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "wss://chat.openai.com/backend-api/wham/remote/control/server"
|
||||
.to_string(),
|
||||
enroll_url: "https://chat.openai.com/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("https://chatgpt.com/api/codex/remote/control/server")
|
||||
.expect("internal chatgpt remote-control path should rewrite to the public wham path"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "wss://chatgpt.com/backend-api/wham/remote/control/server".to_string(),
|
||||
enroll_url: "https://chatgpt.com/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("https://chatgpt.com/api/codex")
|
||||
.expect("explicit chatgpt api/codex base should rewrite to the public wham path"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "wss://chatgpt.com/backend-api/wham/remote/control/server".to_string(),
|
||||
enroll_url: "https://chatgpt.com/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
let err = normalize_remote_control_url("ftp://example.com/control")
|
||||
.expect_err("unsupported scheme should fail");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"invalid remote control URL `ftp://example.com/control`; expected http:// or https://"
|
||||
);
|
||||
|
||||
let err =
|
||||
normalize_remote_control_url("ws://example.com/control").expect_err("ws url should fail");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"invalid remote control URL `ws://example.com/control`; expected http:// or https://"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persisted_remote_control_enrollment_is_scoped_and_selectively_cleared() {
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
let state_path = remote_control_state_path(codex_home.path());
|
||||
let first_target = normalize_remote_control_url("http://example.com/remote/control")
|
||||
.expect("first target should parse");
|
||||
let second_target = normalize_remote_control_url("http://example.com/other/control")
|
||||
.expect("second target should parse");
|
||||
let first_enrollment = RemoteControlEnrollment {
|
||||
server_id: "srv_e_first".to_string(),
|
||||
server_name: REMOTE_CONTROL_SERVER_NAME.to_string(),
|
||||
};
|
||||
let same_target_other_account_enrollment = RemoteControlEnrollment {
|
||||
server_id: "srv_e_first_account_b".to_string(),
|
||||
server_name: REMOTE_CONTROL_SERVER_NAME.to_string(),
|
||||
};
|
||||
let second_enrollment = RemoteControlEnrollment {
|
||||
server_id: "srv_e_second".to_string(),
|
||||
server_name: REMOTE_CONTROL_SERVER_NAME.to_string(),
|
||||
};
|
||||
|
||||
update_persisted_remote_control_enrollment(
|
||||
state_path.as_path(),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
Some(&first_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("first enrollment should persist");
|
||||
update_persisted_remote_control_enrollment(
|
||||
state_path.as_path(),
|
||||
&second_target,
|
||||
Some("account-a"),
|
||||
Some(&second_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("second enrollment should persist");
|
||||
|
||||
update_persisted_remote_control_enrollment(
|
||||
state_path.as_path(),
|
||||
&first_target,
|
||||
Some("account-b"),
|
||||
Some(&same_target_other_account_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("other-account enrollment should persist");
|
||||
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
state_path.as_path(),
|
||||
&first_target,
|
||||
Some("account-a")
|
||||
)
|
||||
.await,
|
||||
Some(first_enrollment.clone())
|
||||
);
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
state_path.as_path(),
|
||||
&first_target,
|
||||
Some("account-b")
|
||||
)
|
||||
.await,
|
||||
Some(same_target_other_account_enrollment.clone())
|
||||
);
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
state_path.as_path(),
|
||||
&second_target,
|
||||
Some("account-a")
|
||||
)
|
||||
.await,
|
||||
Some(second_enrollment.clone())
|
||||
);
|
||||
|
||||
update_persisted_remote_control_enrollment(
|
||||
state_path.as_path(),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("matching enrollment should clear");
|
||||
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
state_path.as_path(),
|
||||
&first_target,
|
||||
Some("account-b")
|
||||
)
|
||||
.await,
|
||||
Some(same_target_other_account_enrollment)
|
||||
);
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
state_path.as_path(),
|
||||
&second_target,
|
||||
Some("account-a")
|
||||
)
|
||||
.await,
|
||||
Some(second_enrollment)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enrollment_manager_cache_is_scoped_to_the_current_account() {
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
let state_path = remote_control_state_path(codex_home.path());
|
||||
let remote_control_target = normalize_remote_control_url("http://example.com/remote/control")
|
||||
.expect("target should parse");
|
||||
let account_a_enrollment = RemoteControlEnrollment {
|
||||
server_id: "srv_e_account_a".to_string(),
|
||||
server_name: REMOTE_CONTROL_SERVER_NAME.to_string(),
|
||||
};
|
||||
let account_b_enrollment = RemoteControlEnrollment {
|
||||
server_id: "srv_e_account_b".to_string(),
|
||||
server_name: REMOTE_CONTROL_SERVER_NAME.to_string(),
|
||||
};
|
||||
|
||||
update_persisted_remote_control_enrollment(
|
||||
state_path.as_path(),
|
||||
&remote_control_target,
|
||||
Some("account-a"),
|
||||
Some(&account_a_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("account-a enrollment should persist");
|
||||
update_persisted_remote_control_enrollment(
|
||||
state_path.as_path(),
|
||||
&remote_control_target,
|
||||
Some("account-b"),
|
||||
Some(&account_b_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("account-b enrollment should persist");
|
||||
|
||||
let mut enrollment_manager =
|
||||
EnrollmentManager::new(remote_control_target, codex_home.path().to_path_buf());
|
||||
|
||||
assert_eq!(
|
||||
enrollment_manager
|
||||
.enroll(&RemoteControlConnectionAuth {
|
||||
bearer_token: "Access Token".to_string(),
|
||||
account_id: Some("account-a".to_string()),
|
||||
})
|
||||
.await
|
||||
.expect("account-a enrollment should load from the cache or state file"),
|
||||
account_a_enrollment
|
||||
);
|
||||
assert_eq!(
|
||||
enrollment_manager
|
||||
.enroll(&RemoteControlConnectionAuth {
|
||||
bearer_token: "Access Token".to_string(),
|
||||
account_id: Some("account-b".to_string()),
|
||||
})
|
||||
.await
|
||||
.expect("account-b enrollment should replace the cached account-a enrollment"),
|
||||
account_b_enrollment
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enroll_remote_control_server_parse_failure_includes_response_body() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = format!(
|
||||
"http://{}/api/codex",
|
||||
listener
|
||||
.local_addr()
|
||||
.expect("listener should have a local addr")
|
||||
);
|
||||
let remote_control_target =
|
||||
normalize_remote_control_url(&remote_control_url).expect("target should parse");
|
||||
let enroll_url = remote_control_target.enroll_url.clone();
|
||||
let response_body = json!({
|
||||
"error": "not enrolled",
|
||||
});
|
||||
let expected_body = response_body.to_string();
|
||||
let server_task = tokio::spawn(async move {
|
||||
let enroll_request = accept_http_request(&listener).await;
|
||||
respond_with_json(enroll_request.stream, response_body).await;
|
||||
});
|
||||
|
||||
let err = enroll_remote_control_server(
|
||||
&remote_control_target,
|
||||
&RemoteControlConnectionAuth {
|
||||
bearer_token: "Access Token".to_string(),
|
||||
account_id: Some("account_id".to_string()),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect_err("invalid response should fail to parse");
|
||||
|
||||
server_task.await.expect("server task should succeed");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!(
|
||||
"failed to parse remote control enrollment response from `{enroll_url}`: HTTP 200 OK, body: {expected_body}, decode error: missing field `server_id` at line 1 column {}",
|
||||
expected_body.len()
|
||||
)
|
||||
);
|
||||
}
|
||||
325
codex-rs/app-server/src/remote_control/entrollment_manager.rs
Normal file
325
codex-rs/app-server/src/remote_control/entrollment_manager.rs
Normal file
@@ -0,0 +1,325 @@
|
||||
use codex_core::default_client::build_reqwest_client;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use tracing::warn;
|
||||
|
||||
use super::REMOTE_CONTROL_ACCOUNT_ID_HEADER;
|
||||
use super::REMOTE_CONTROL_STATE_FILE;
|
||||
use super::RemoteControlConnectionAuth;
|
||||
use super::RemoteControlEnrollment;
|
||||
use super::RemoteControlTarget;
|
||||
|
||||
const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
|
||||
const REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES: usize = 4096;
|
||||
|
||||
const REMOTE_CONTROL_SERVER_NAME: &str = "codex-app-server";
|
||||
|
||||
struct CachedRemoteControlEnrollment {
|
||||
account_id: Option<String>,
|
||||
enrollment: RemoteControlEnrollment,
|
||||
}
|
||||
|
||||
pub(super) struct EnrollmentManager {
|
||||
remote_control_target: RemoteControlTarget,
|
||||
remote_control_state_path: PathBuf,
|
||||
cached_enrollment: Option<CachedRemoteControlEnrollment>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||
struct RemoteControlStateToml {
|
||||
#[serde(default)]
|
||||
enrollments: Vec<PersistedRemoteControlEnrollment>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
struct PersistedRemoteControlEnrollment {
|
||||
websocket_url: String,
|
||||
account_id: Option<String>,
|
||||
server_id: String,
|
||||
server_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct EnrollRemoteServerRequest<'a> {
|
||||
name: &'a str,
|
||||
os: &'a str,
|
||||
arch: &'a str,
|
||||
app_server_version: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct EnrollRemoteServerResponse {
|
||||
server_id: String,
|
||||
}
|
||||
|
||||
impl EnrollmentManager {
|
||||
pub(super) fn new(remote_control_target: RemoteControlTarget, codex_home: PathBuf) -> Self {
|
||||
Self {
|
||||
remote_control_target,
|
||||
remote_control_state_path: remote_control_state_path(codex_home.as_path()),
|
||||
cached_enrollment: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn enroll(
|
||||
&mut self,
|
||||
auth: &RemoteControlConnectionAuth,
|
||||
) -> IoResult<RemoteControlEnrollment> {
|
||||
if self
|
||||
.cached_enrollment
|
||||
.as_ref()
|
||||
.and_then(|cached| cached.account_id.as_deref())
|
||||
!= auth.account_id.as_deref()
|
||||
{
|
||||
self.cached_enrollment = None;
|
||||
}
|
||||
|
||||
if self.cached_enrollment.is_none() {
|
||||
self.cached_enrollment = load_persisted_remote_control_enrollment(
|
||||
self.remote_control_state_path.as_path(),
|
||||
&self.remote_control_target,
|
||||
auth.account_id.as_deref(),
|
||||
)
|
||||
.await
|
||||
.map(|enrollment| CachedRemoteControlEnrollment {
|
||||
account_id: auth.account_id.clone(),
|
||||
enrollment,
|
||||
});
|
||||
}
|
||||
|
||||
if self.cached_enrollment.is_none() {
|
||||
let new_enrollment =
|
||||
enroll_remote_control_server(&self.remote_control_target, auth).await?;
|
||||
if let Err(err) = update_persisted_remote_control_enrollment(
|
||||
self.remote_control_state_path.as_path(),
|
||||
&self.remote_control_target,
|
||||
auth.account_id.as_deref(),
|
||||
Some(&new_enrollment),
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
"failed to persist remote control enrollment in `{}`: {err}",
|
||||
self.remote_control_state_path.display()
|
||||
);
|
||||
}
|
||||
self.cached_enrollment = Some(CachedRemoteControlEnrollment {
|
||||
account_id: auth.account_id.clone(),
|
||||
enrollment: new_enrollment,
|
||||
});
|
||||
}
|
||||
|
||||
let enrollment = self
|
||||
.cached_enrollment
|
||||
.as_ref()
|
||||
.map(|cached| cached.enrollment.clone())
|
||||
.ok_or_else(|| {
|
||||
std::io::Error::other("missing remote control enrollment after enrollment step")
|
||||
})?;
|
||||
|
||||
Ok(enrollment)
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_control_server_name() -> String {
|
||||
let host_name = gethostname::gethostname();
|
||||
let host_name = host_name.to_string_lossy();
|
||||
let host_name = host_name.trim();
|
||||
if host_name.is_empty() {
|
||||
REMOTE_CONTROL_SERVER_NAME.to_string()
|
||||
} else {
|
||||
host_name.to_owned()
|
||||
}
|
||||
}
|
||||
|
||||
fn matches_persisted_remote_control_enrollment(
|
||||
entry: &PersistedRemoteControlEnrollment,
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
account_id: Option<&str>,
|
||||
) -> bool {
|
||||
entry.websocket_url == remote_control_target.websocket_url
|
||||
&& entry.account_id.as_deref() == account_id
|
||||
}
|
||||
|
||||
async fn load_remote_control_state(state_path: &Path) -> IoResult<RemoteControlStateToml> {
|
||||
let contents = match tokio::fs::read_to_string(state_path).await {
|
||||
Ok(contents) => contents,
|
||||
Err(err) if err.kind() == ErrorKind::NotFound => {
|
||||
return Ok(RemoteControlStateToml::default());
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
|
||||
toml::from_str(&contents).map_err(|err| {
|
||||
std::io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!(
|
||||
"failed to parse remote control state `{}`: {err}",
|
||||
state_path.display()
|
||||
),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn remote_control_state_path(codex_home: &Path) -> PathBuf {
|
||||
codex_home.join(REMOTE_CONTROL_STATE_FILE)
|
||||
}
|
||||
|
||||
async fn write_remote_control_state(
|
||||
state_path: &Path,
|
||||
state: &RemoteControlStateToml,
|
||||
) -> IoResult<()> {
|
||||
if let Some(parent) = state_path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
let serialized = toml::to_string(state).map_err(std::io::Error::other)?;
|
||||
tokio::fs::write(state_path, serialized).await
|
||||
}
|
||||
|
||||
pub(super) async fn load_persisted_remote_control_enrollment(
|
||||
state_path: &Path,
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
account_id: Option<&str>,
|
||||
) -> Option<RemoteControlEnrollment> {
|
||||
let state = match load_remote_control_state(state_path).await {
|
||||
Ok(state) => state,
|
||||
Err(err) => {
|
||||
warn!("{err}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
state
|
||||
.enrollments
|
||||
.into_iter()
|
||||
.find(|entry| {
|
||||
matches_persisted_remote_control_enrollment(entry, remote_control_target, account_id)
|
||||
})
|
||||
.map(|entry| RemoteControlEnrollment {
|
||||
server_id: entry.server_id,
|
||||
server_name: entry.server_name,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) async fn update_persisted_remote_control_enrollment(
|
||||
state_path: &Path,
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
account_id: Option<&str>,
|
||||
enrollment: Option<&RemoteControlEnrollment>,
|
||||
) -> IoResult<()> {
|
||||
let mut state = match load_remote_control_state(state_path).await {
|
||||
Ok(state) => state,
|
||||
Err(err) if err.kind() == ErrorKind::InvalidData => {
|
||||
warn!("{err}");
|
||||
RemoteControlStateToml::default()
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
|
||||
state.enrollments.retain(|entry| {
|
||||
!matches_persisted_remote_control_enrollment(entry, remote_control_target, account_id)
|
||||
});
|
||||
|
||||
if let Some(enrollment) = enrollment {
|
||||
state.enrollments.push(PersistedRemoteControlEnrollment {
|
||||
websocket_url: remote_control_target.websocket_url.clone(),
|
||||
account_id: account_id.map(str::to_owned),
|
||||
server_id: enrollment.server_id.clone(),
|
||||
server_name: enrollment.server_name.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
if state.enrollments.is_empty() {
|
||||
match tokio::fs::remove_file(state_path).await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(err) if err.kind() == ErrorKind::NotFound => Ok(()),
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
} else {
|
||||
write_remote_control_state(state_path, &state).await
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn preview_remote_control_response_body(body: &[u8]) -> String {
|
||||
let body = String::from_utf8_lossy(body);
|
||||
let trimmed = body.trim();
|
||||
if trimmed.is_empty() {
|
||||
return "<empty>".to_string();
|
||||
}
|
||||
if trimmed.len() <= REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
|
||||
let mut cut = REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES;
|
||||
while !trimmed.is_char_boundary(cut) {
|
||||
cut = cut.saturating_sub(1);
|
||||
}
|
||||
let mut truncated = trimmed[..cut].to_string();
|
||||
truncated.push_str("...");
|
||||
truncated
|
||||
}
|
||||
|
||||
async fn enroll_remote_control_server(
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
auth: &RemoteControlConnectionAuth,
|
||||
) -> IoResult<RemoteControlEnrollment> {
|
||||
let server_name = remote_control_server_name();
|
||||
let request = EnrollRemoteServerRequest {
|
||||
name: &server_name,
|
||||
os: std::env::consts::OS,
|
||||
arch: std::env::consts::ARCH,
|
||||
app_server_version: env!("CARGO_PKG_VERSION"),
|
||||
};
|
||||
let client = build_reqwest_client();
|
||||
|
||||
let mut http_request = client
|
||||
.post(&remote_control_target.enroll_url)
|
||||
.timeout(REMOTE_CONTROL_ENROLL_TIMEOUT)
|
||||
.bearer_auth(&auth.bearer_token)
|
||||
.json(&request);
|
||||
|
||||
if let Some(account_id) = auth.account_id.as_deref() {
|
||||
http_request = http_request.header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id);
|
||||
}
|
||||
|
||||
let response = http_request.send().await.map_err(|err| {
|
||||
std::io::Error::other(format!(
|
||||
"failed to enroll remote control server at `{}`: {err}",
|
||||
remote_control_target.enroll_url
|
||||
))
|
||||
})?;
|
||||
let status = response.status();
|
||||
let body = response.bytes().await.map_err(|err| {
|
||||
std::io::Error::other(format!(
|
||||
"failed to read remote control enrollment response from `{}`: {err}",
|
||||
remote_control_target.enroll_url
|
||||
))
|
||||
})?;
|
||||
let body_preview = preview_remote_control_response_body(&body);
|
||||
if !status.is_success() {
|
||||
return Err(std::io::Error::other(format!(
|
||||
"remote control server enrollment failed at `{}`: HTTP {status}, body: {body_preview}",
|
||||
remote_control_target.enroll_url
|
||||
)));
|
||||
}
|
||||
|
||||
let enrollment = serde_json::from_slice::<EnrollRemoteServerResponse>(&body).map_err(|err| {
|
||||
std::io::Error::other(format!(
|
||||
"failed to parse remote control enrollment response from `{}`: HTTP {status}, body: {body_preview}, decode error: {err}",
|
||||
remote_control_target.enroll_url
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(RemoteControlEnrollment {
|
||||
server_id: enrollment.server_id,
|
||||
server_name,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "enrollment_tests.rs"]
|
||||
mod enrollment_tests;
|
||||
251
codex-rs/app-server/src/remote_control/mod.rs
Normal file
251
codex-rs/app-server/src/remote_control/mod.rs
Normal file
@@ -0,0 +1,251 @@
|
||||
mod client_manager;
|
||||
mod connection_manager;
|
||||
mod entrollment_manager;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_support;
|
||||
|
||||
use self::entrollment_manager::EnrollmentManager;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::transport::CHANNEL_CAPACITY;
|
||||
use crate::transport::ConnectionIdAllocator;
|
||||
use crate::transport::TransportEvent;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_core::AuthManager;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
|
||||
const REMOTE_CONTROL_REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
const REMOTE_CONTROL_STATE_FILE: &str = "remote_control.toml";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct RemoteControlTarget {
|
||||
websocket_url: String,
|
||||
enroll_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct RemoteControlEnrollment {
|
||||
server_id: String,
|
||||
server_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct RemoteControlConnectionAuth {
|
||||
bearer_token: String,
|
||||
account_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
struct ClientId(String);
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum ClientActivityState {
|
||||
Foreground,
|
||||
Background,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum ClientEvent {
|
||||
ClientMessage {
|
||||
client_id: ClientId,
|
||||
message: JSONRPCMessage,
|
||||
},
|
||||
Ping {
|
||||
client_id: ClientId,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
state: Option<ClientActivityState>,
|
||||
},
|
||||
ClientClosed {
|
||||
client_id: ClientId,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum ServerEvent {
|
||||
ServerMessage {
|
||||
client_id: ClientId,
|
||||
message: Box<OutgoingMessage>,
|
||||
},
|
||||
Pong {
|
||||
client_id: ClientId,
|
||||
status: PongStatus,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum PongStatus {
|
||||
Active,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
pub(crate) async fn start_remote_control(
|
||||
remote_control_url: String,
|
||||
codex_home: PathBuf,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
connection_id_allocator: ConnectionIdAllocator,
|
||||
) -> IoResult<JoinHandle<()>> {
|
||||
let remote_control_target = normalize_remote_control_url(&remote_control_url)?;
|
||||
let enrollment_manager = EnrollmentManager::new(remote_control_target.clone(), codex_home);
|
||||
|
||||
Ok(tokio::spawn(async move {
|
||||
let local_shutdown_token = shutdown_token.child_token();
|
||||
let (client_event_tx, client_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (server_event_tx, server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (writer_exited_tx, writer_exited_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
|
||||
let mut websocket_task = tokio::spawn(connection_manager::run(
|
||||
auth_manager,
|
||||
remote_control_target,
|
||||
enrollment_manager,
|
||||
client_event_tx,
|
||||
server_event_rx,
|
||||
local_shutdown_token.clone(),
|
||||
));
|
||||
let mut manager_task = tokio::spawn(client_manager::run(
|
||||
transport_event_tx,
|
||||
client_event_rx,
|
||||
server_event_tx,
|
||||
writer_exited_tx,
|
||||
writer_exited_rx,
|
||||
local_shutdown_token.clone(),
|
||||
connection_id_allocator,
|
||||
));
|
||||
|
||||
tokio::select! {
|
||||
_ = local_shutdown_token.cancelled() => {}
|
||||
_ = &mut websocket_task => {
|
||||
local_shutdown_token.cancel();
|
||||
}
|
||||
_ = &mut manager_task => {
|
||||
local_shutdown_token.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
let _ = websocket_task.await;
|
||||
let _ = manager_task.await;
|
||||
}))
|
||||
}
|
||||
|
||||
fn normalize_remote_control_url(remote_control_url: &str) -> IoResult<RemoteControlTarget> {
|
||||
let remote_control_url = remote_control_url.trim_end_matches('/');
|
||||
|
||||
if let Some(rest) = remote_control_url.strip_prefix("http://") {
|
||||
return Ok(normalize_http_remote_control_url(rest, "http://", "ws://"));
|
||||
}
|
||||
if let Some(rest) = remote_control_url.strip_prefix("https://") {
|
||||
return Ok(normalize_http_remote_control_url(
|
||||
rest, "https://", "wss://",
|
||||
));
|
||||
}
|
||||
|
||||
Err(std::io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("invalid remote control URL `{remote_control_url}`; expected http:// or https://"),
|
||||
))
|
||||
}
|
||||
|
||||
fn normalize_http_remote_control_url(
|
||||
rest: &str,
|
||||
http_scheme: &str,
|
||||
websocket_scheme: &str,
|
||||
) -> RemoteControlTarget {
|
||||
let rest = normalize_chatgpt_remote_control_base(rest);
|
||||
let rest = if let Some(rest) = rest.strip_suffix("/remote/control/server/enroll") {
|
||||
format!("{rest}/remote/control/server")
|
||||
} else if rest.ends_with("/remote/control/server") {
|
||||
rest
|
||||
} else if let Some(rest) = rest.strip_suffix("/server/enroll") {
|
||||
format!("{rest}/server")
|
||||
} else if rest.ends_with("/server") {
|
||||
rest
|
||||
} else {
|
||||
format!("{rest}/remote/control/server")
|
||||
};
|
||||
|
||||
RemoteControlTarget {
|
||||
websocket_url: format!("{websocket_scheme}{rest}"),
|
||||
enroll_url: format!("{http_scheme}{rest}/enroll"),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_chatgpt_remote_control_base(rest: &str) -> String {
|
||||
let trimmed = rest.trim_end_matches('/');
|
||||
let (host, path) = match trimmed.split_once('/') {
|
||||
Some((host, path)) => (host, Some(path)),
|
||||
None => (trimmed, None),
|
||||
};
|
||||
if host != "chatgpt.com" && host != "chat.openai.com" {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
|
||||
let Some(path) = path else {
|
||||
return format!("{host}/backend-api/wham");
|
||||
};
|
||||
|
||||
if path == "backend-api/wham" || path.starts_with("backend-api/wham/") {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
|
||||
for internal_prefix in ["api/codex", "backend-api"] {
|
||||
if path == internal_prefix {
|
||||
return format!("{host}/backend-api/wham");
|
||||
}
|
||||
if let Some(suffix) = path.strip_prefix(&format!("{internal_prefix}/")) {
|
||||
if suffix == "remote/control/server" || suffix == "remote/control/server/enroll" {
|
||||
return format!("{host}/backend-api/wham/{suffix}");
|
||||
}
|
||||
return format!("{host}/backend-api/wham");
|
||||
}
|
||||
}
|
||||
|
||||
if path == "remote/control/server" || path == "remote/control/server/enroll" {
|
||||
return format!("{host}/backend-api/wham/{path}");
|
||||
}
|
||||
|
||||
trimmed.to_string()
|
||||
}
|
||||
|
||||
async fn load_remote_control_auth(
|
||||
auth_manager: &AuthManager,
|
||||
) -> IoResult<RemoteControlConnectionAuth> {
|
||||
let auth = match auth_manager.auth().await {
|
||||
Some(auth) => auth,
|
||||
None => {
|
||||
auth_manager.reload();
|
||||
auth_manager.auth().await.ok_or_else(|| {
|
||||
std::io::Error::new(
|
||||
ErrorKind::PermissionDenied,
|
||||
"remote control requires ChatGPT authentication",
|
||||
)
|
||||
})?
|
||||
}
|
||||
};
|
||||
|
||||
if !auth.is_chatgpt_auth() {
|
||||
return Err(std::io::Error::new(
|
||||
ErrorKind::PermissionDenied,
|
||||
"remote control requires ChatGPT authentication; API key auth is not supported",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(RemoteControlConnectionAuth {
|
||||
bearer_token: auth.get_token().map_err(std::io::Error::other)?,
|
||||
account_id: auth.get_account_id(),
|
||||
})
|
||||
}
|
||||
218
codex-rs/app-server/src/remote_control/test_support.rs
Normal file
218
codex-rs/app-server/src/remote_control/test_support.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
use super::ClientEvent;
|
||||
use super::REMOTE_CONTROL_REQUEST_ID_HEADER;
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::test_support::auth_manager_from_auth;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::accept_hdr_async;
|
||||
use tokio_tungstenite::tungstenite;
|
||||
use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
|
||||
use tokio_tungstenite::tungstenite::http::HeaderValue;
|
||||
|
||||
pub(super) fn remote_control_auth_manager() -> Arc<AuthManager> {
|
||||
auth_manager_from_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct CapturedHttpRequest {
|
||||
pub(super) stream: TcpStream,
|
||||
pub(super) request_line: String,
|
||||
pub(super) headers: BTreeMap<String, String>,
|
||||
pub(super) body: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub(super) struct CapturedWebSocketRequest {
|
||||
pub(super) path: String,
|
||||
pub(super) headers: BTreeMap<String, String>,
|
||||
}
|
||||
|
||||
pub(super) async fn accept_http_request(listener: &TcpListener) -> CapturedHttpRequest {
|
||||
let (stream, _) = timeout(Duration::from_secs(5), listener.accept())
|
||||
.await
|
||||
.expect("HTTP request should arrive in time")
|
||||
.expect("listener accept should succeed");
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
let mut request_line = String::new();
|
||||
reader
|
||||
.read_line(&mut request_line)
|
||||
.await
|
||||
.expect("request line should read");
|
||||
let request_line = request_line.trim_end_matches("\r\n").to_string();
|
||||
|
||||
let mut headers = BTreeMap::new();
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
reader
|
||||
.read_line(&mut line)
|
||||
.await
|
||||
.expect("header line should read");
|
||||
if line == "\r\n" {
|
||||
break;
|
||||
}
|
||||
let line = line.trim_end_matches("\r\n");
|
||||
let (name, value) = line.split_once(':').expect("header should contain colon");
|
||||
headers.insert(name.to_ascii_lowercase(), value.trim().to_string());
|
||||
}
|
||||
|
||||
let content_length = headers
|
||||
.get("content-length")
|
||||
.and_then(|value| value.parse::<usize>().ok())
|
||||
.unwrap_or(0);
|
||||
let mut body = vec![0; content_length];
|
||||
reader
|
||||
.read_exact(&mut body)
|
||||
.await
|
||||
.expect("request body should read");
|
||||
|
||||
CapturedHttpRequest {
|
||||
stream: reader.into_inner(),
|
||||
request_line,
|
||||
headers,
|
||||
body: String::from_utf8(body).expect("body should be utf-8"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) {
|
||||
let body = body.to_string();
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||
body.len()
|
||||
);
|
||||
stream
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.expect("response should write");
|
||||
stream.flush().await.expect("response should flush");
|
||||
}
|
||||
|
||||
pub(super) async fn respond_with_status_and_headers(
|
||||
mut stream: TcpStream,
|
||||
status: &str,
|
||||
headers: &[(&str, &str)],
|
||||
body: &str,
|
||||
) {
|
||||
let extra_headers = headers
|
||||
.iter()
|
||||
.map(|(name, value)| format!("{name}: {value}\r\n"))
|
||||
.collect::<String>();
|
||||
let response = format!(
|
||||
"HTTP/1.1 {status}\r\ncontent-type: text/plain\r\ncontent-length: {}\r\nconnection: close\r\n{extra_headers}\r\n{body}",
|
||||
body.len(),
|
||||
);
|
||||
stream
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.expect("response should write");
|
||||
stream.flush().await.expect("response should flush");
|
||||
}
|
||||
|
||||
pub(super) async fn accept_remote_control_backend_connection(
|
||||
listener: &TcpListener,
|
||||
request_id: Option<&str>,
|
||||
) -> (CapturedWebSocketRequest, WebSocketStream<TcpStream>) {
|
||||
let (stream, _) = timeout(Duration::from_secs(5), listener.accept())
|
||||
.await
|
||||
.expect("websocket request should arrive in time")
|
||||
.expect("listener accept should succeed");
|
||||
let captured_request = Arc::new(std::sync::Mutex::new(None::<CapturedWebSocketRequest>));
|
||||
let captured_request_for_callback = captured_request.clone();
|
||||
let request_id = request_id.map(str::to_owned);
|
||||
let websocket = accept_hdr_async(
|
||||
stream,
|
||||
move |request: &tungstenite::handshake::server::Request,
|
||||
mut response: tungstenite::handshake::server::Response| {
|
||||
let headers = request
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(name, value)| {
|
||||
(
|
||||
name.as_str().to_ascii_lowercase(),
|
||||
value
|
||||
.to_str()
|
||||
.expect("header should be valid utf-8")
|
||||
.to_string(),
|
||||
)
|
||||
})
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
*captured_request_for_callback
|
||||
.lock()
|
||||
.expect("capture lock should acquire") = Some(CapturedWebSocketRequest {
|
||||
path: request.uri().path().to_string(),
|
||||
headers,
|
||||
});
|
||||
if let Some(request_id) = request_id.as_deref() {
|
||||
response.headers_mut().insert(
|
||||
REMOTE_CONTROL_REQUEST_ID_HEADER,
|
||||
HeaderValue::from_str(request_id)
|
||||
.expect("request id should be a valid header value"),
|
||||
);
|
||||
}
|
||||
Ok(response)
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("websocket handshake should succeed");
|
||||
let captured_request = captured_request
|
||||
.lock()
|
||||
.expect("capture lock should acquire")
|
||||
.clone()
|
||||
.expect("websocket request should be captured");
|
||||
(captured_request, websocket)
|
||||
}
|
||||
|
||||
pub(super) async fn send_client_event(
|
||||
websocket: &mut WebSocketStream<TcpStream>,
|
||||
client_event: ClientEvent,
|
||||
) {
|
||||
let payload = serde_json::to_string(&client_event).expect("client event should serialize");
|
||||
websocket
|
||||
.send(TungsteniteMessage::Text(payload.into()))
|
||||
.await
|
||||
.expect("client event should send");
|
||||
}
|
||||
|
||||
pub(super) async fn read_server_event(
|
||||
websocket: &mut WebSocketStream<TcpStream>,
|
||||
) -> serde_json::Value {
|
||||
loop {
|
||||
let frame = timeout(Duration::from_secs(5), websocket.next())
|
||||
.await
|
||||
.expect("server event should arrive in time")
|
||||
.expect("websocket should stay open")
|
||||
.expect("websocket frame should be readable");
|
||||
match frame {
|
||||
TungsteniteMessage::Text(text) => {
|
||||
return serde_json::from_str(text.as_ref())
|
||||
.expect("server event should deserialize");
|
||||
}
|
||||
TungsteniteMessage::Ping(payload) => {
|
||||
websocket
|
||||
.send(TungsteniteMessage::Pong(payload))
|
||||
.await
|
||||
.expect("websocket pong should send");
|
||||
}
|
||||
TungsteniteMessage::Pong(_) => {}
|
||||
TungsteniteMessage::Close(frame) => {
|
||||
panic!("unexpected websocket close frame: {frame:?}");
|
||||
}
|
||||
TungsteniteMessage::Binary(_) => {
|
||||
panic!("unexpected binary websocket frame");
|
||||
}
|
||||
TungsteniteMessage::Frame(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -51,7 +51,7 @@ use tracing::warn;
|
||||
/// plenty for an interactive CLI.
|
||||
pub(crate) const CHANNEL_CAPACITY: usize = 128;
|
||||
|
||||
fn colorize(text: &str, style: Style) -> String {
|
||||
pub(crate) fn colorize(text: &str, style: Style) -> String {
|
||||
text.if_supports_color(Stream::Stderr, |value| value.style(style))
|
||||
.to_string()
|
||||
}
|
||||
@@ -84,7 +84,8 @@ fn print_websocket_startup_banner(addr: SocketAddr) {
|
||||
#[derive(Clone)]
|
||||
struct WebSocketListenerState {
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
connection_counter: Arc<AtomicU64>,
|
||||
connection_id_allocator: ConnectionIdAllocator,
|
||||
transport_kind: AppServerTransport,
|
||||
}
|
||||
|
||||
async fn health_check_handler() -> StatusCode {
|
||||
@@ -96,10 +97,16 @@ async fn websocket_upgrade_handler(
|
||||
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
||||
State(state): State<WebSocketListenerState>,
|
||||
) -> impl IntoResponse {
|
||||
let connection_id = ConnectionId(state.connection_counter.fetch_add(1, Ordering::Relaxed));
|
||||
let connection_id = state.connection_id_allocator.next_connection_id();
|
||||
info!(%peer_addr, "websocket client connected");
|
||||
websocket.on_upgrade(move |stream| async move {
|
||||
run_websocket_connection(connection_id, stream, state.transport_event_tx).await;
|
||||
run_websocket_connection(
|
||||
connection_id,
|
||||
stream,
|
||||
state.transport_event_tx,
|
||||
state.transport_kind,
|
||||
)
|
||||
.await;
|
||||
})
|
||||
}
|
||||
|
||||
@@ -107,6 +114,36 @@ async fn websocket_upgrade_handler(
|
||||
pub enum AppServerTransport {
|
||||
Stdio,
|
||||
WebSocket { bind_address: SocketAddr },
|
||||
RemoteControlled,
|
||||
}
|
||||
|
||||
impl AppServerTransport {
|
||||
pub const DEFAULT_LISTEN_URL: &'static str = "stdio://";
|
||||
|
||||
pub fn from_listen_url(listen_url: &str) -> Result<Self, AppServerTransportParseError> {
|
||||
if listen_url == Self::DEFAULT_LISTEN_URL {
|
||||
return Ok(Self::Stdio);
|
||||
}
|
||||
|
||||
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
|
||||
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
|
||||
})?;
|
||||
return Ok(Self::WebSocket { bind_address });
|
||||
}
|
||||
|
||||
Err(AppServerTransportParseError::UnsupportedListenUrl(
|
||||
listen_url.to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Stdio => "stdio",
|
||||
Self::WebSocket { .. } => "websocket",
|
||||
Self::RemoteControlled => "remote_controlled",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
@@ -132,27 +169,6 @@ impl std::fmt::Display for AppServerTransportParseError {
|
||||
|
||||
impl std::error::Error for AppServerTransportParseError {}
|
||||
|
||||
impl AppServerTransport {
|
||||
pub const DEFAULT_LISTEN_URL: &'static str = "stdio://";
|
||||
|
||||
pub fn from_listen_url(listen_url: &str) -> Result<Self, AppServerTransportParseError> {
|
||||
if listen_url == Self::DEFAULT_LISTEN_URL {
|
||||
return Ok(Self::Stdio);
|
||||
}
|
||||
|
||||
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
|
||||
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
|
||||
})?;
|
||||
return Ok(Self::WebSocket { bind_address });
|
||||
}
|
||||
|
||||
Err(AppServerTransportParseError::UnsupportedListenUrl(
|
||||
listen_url.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for AppServerTransport {
|
||||
type Err = AppServerTransportParseError;
|
||||
|
||||
@@ -161,16 +177,37 @@ impl FromStr for AppServerTransport {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ConnectionIdAllocator {
|
||||
next_id: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl ConnectionIdAllocator {
|
||||
pub(crate) fn next_connection_id(&self) -> ConnectionId {
|
||||
ConnectionId(self.next_id.fetch_add(1, Ordering::Relaxed))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ConnectionIdAllocator {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
next_id: Arc::new(AtomicU64::new(1)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum TransportEvent {
|
||||
ConnectionOpened {
|
||||
connection_id: ConnectionId,
|
||||
transport_kind: AppServerTransport,
|
||||
writer: mpsc::Sender<OutgoingMessage>,
|
||||
allow_legacy_notifications: bool,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
},
|
||||
ConnectionClosed {
|
||||
connection_id: ConnectionId,
|
||||
transport_kind: AppServerTransport,
|
||||
},
|
||||
IncomingMessage {
|
||||
connection_id: ConnectionId,
|
||||
@@ -179,6 +216,7 @@ pub(crate) enum TransportEvent {
|
||||
}
|
||||
|
||||
pub(crate) struct ConnectionState {
|
||||
pub(crate) transport_kind: AppServerTransport,
|
||||
pub(crate) outbound_initialized: Arc<AtomicBool>,
|
||||
pub(crate) outbound_experimental_api_enabled: Arc<AtomicBool>,
|
||||
pub(crate) outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
@@ -187,11 +225,13 @@ pub(crate) struct ConnectionState {
|
||||
|
||||
impl ConnectionState {
|
||||
pub(crate) fn new(
|
||||
transport_kind: AppServerTransport,
|
||||
outbound_initialized: Arc<AtomicBool>,
|
||||
outbound_experimental_api_enabled: Arc<AtomicBool>,
|
||||
outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
transport_kind,
|
||||
outbound_initialized,
|
||||
outbound_experimental_api_enabled,
|
||||
outbound_opted_out_notification_methods,
|
||||
@@ -249,6 +289,7 @@ pub(crate) async fn start_stdio_connection(
|
||||
transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
transport_kind: AppServerTransport::Stdio,
|
||||
writer: writer_tx,
|
||||
allow_legacy_notifications: false,
|
||||
disconnect_sender: None,
|
||||
@@ -285,7 +326,10 @@ pub(crate) async fn start_stdio_connection(
|
||||
}
|
||||
|
||||
let _ = transport_event_tx_for_reader
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.send(TransportEvent::ConnectionClosed {
|
||||
connection_id,
|
||||
transport_kind: AppServerTransport::Stdio,
|
||||
})
|
||||
.await;
|
||||
debug!("stdin reader finished (EOF)");
|
||||
}));
|
||||
@@ -312,6 +356,7 @@ pub(crate) async fn start_websocket_acceptor(
|
||||
bind_address: SocketAddr,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
connection_id_allocator: ConnectionIdAllocator,
|
||||
) -> IoResult<JoinHandle<()>> {
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let local_addr = listener.local_addr()?;
|
||||
@@ -324,7 +369,10 @@ pub(crate) async fn start_websocket_acceptor(
|
||||
.fallback(any(websocket_upgrade_handler))
|
||||
.with_state(WebSocketListenerState {
|
||||
transport_event_tx,
|
||||
connection_counter: Arc::new(AtomicU64::new(1)),
|
||||
connection_id_allocator,
|
||||
transport_kind: AppServerTransport::WebSocket {
|
||||
bind_address: local_addr,
|
||||
},
|
||||
});
|
||||
let server = axum::serve(
|
||||
listener,
|
||||
@@ -345,6 +393,7 @@ async fn run_websocket_connection(
|
||||
connection_id: ConnectionId,
|
||||
websocket_stream: WebSocket,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
transport_kind: AppServerTransport,
|
||||
) {
|
||||
let (writer_tx, writer_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
@@ -352,6 +401,7 @@ async fn run_websocket_connection(
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
transport_kind,
|
||||
writer: writer_tx,
|
||||
allow_legacy_notifications: false,
|
||||
disconnect_sender: Some(disconnect_token.clone()),
|
||||
@@ -392,7 +442,10 @@ async fn run_websocket_connection(
|
||||
}
|
||||
|
||||
let _ = transport_event_tx
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.send(TransportEvent::ConnectionClosed {
|
||||
connection_id,
|
||||
transport_kind,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,419 @@
|
||||
use super::connection_handling_websocket::connect_websocket;
|
||||
use super::connection_handling_websocket::read_response_for_id;
|
||||
use super::connection_handling_websocket::send_initialize_request;
|
||||
use super::connection_handling_websocket::send_request;
|
||||
use super::connection_handling_websocket::spawn_websocket_server_with_args;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use app_test_support::ChatGptAuthFixture;
|
||||
use app_test_support::create_mock_responses_server_sequence_unchecked;
|
||||
use app_test_support::write_chatgpt_auth;
|
||||
use codex_app_server_protocol::ClientInfo;
|
||||
use codex_app_server_protocol::InitializeParams;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_core::auth::AuthCredentialsStoreMode;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use std::path::Path;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::accept_hdr_async;
|
||||
use tokio_tungstenite::tungstenite;
|
||||
use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
|
||||
|
||||
type BackendWebSocket = WebSocketStream<TcpStream>;
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_transport_with_remote_control_routes_connections_independently() -> Result<()> {
|
||||
let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await;
|
||||
let remote_listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.context("listener should bind")?;
|
||||
let remote_control_base_url = format!(
|
||||
"http://{}/backend-api",
|
||||
remote_listener
|
||||
.local_addr()
|
||||
.context("listener should have local addr")?
|
||||
);
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml_with_remote_control(
|
||||
codex_home.path(),
|
||||
&server.uri(),
|
||||
&remote_control_base_url,
|
||||
"never",
|
||||
)?;
|
||||
write_chatgpt_auth(
|
||||
codex_home.path(),
|
||||
ChatGptAuthFixture::new("Access Token").account_id("account_id"),
|
||||
AuthCredentialsStoreMode::File,
|
||||
)?;
|
||||
|
||||
let (mut process, bind_addr) =
|
||||
spawn_websocket_server_with_args(codex_home.path(), &["--with-remote-control"]).await?;
|
||||
|
||||
let enroll_request = accept_http_request(&remote_listener).await?;
|
||||
assert_eq!(
|
||||
enroll_request.request_line,
|
||||
"POST /backend-api/remote/control/server/enroll HTTP/1.1"
|
||||
);
|
||||
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_mixed" })).await?;
|
||||
let (backend_request, mut backend_websocket) =
|
||||
accept_remote_control_backend_connection(&remote_listener).await?;
|
||||
assert_eq!(backend_request.path, "/backend-api/remote/control/server");
|
||||
|
||||
let mut local_websocket = connect_websocket(bind_addr).await?;
|
||||
send_initialize_request(&mut local_websocket, 11, "local_ws_client").await?;
|
||||
assert_eq!(
|
||||
read_response_for_id(&mut local_websocket, 11).await?.id,
|
||||
RequestId::Integer(11)
|
||||
);
|
||||
|
||||
send_remote_request(
|
||||
&mut backend_websocket,
|
||||
"remote-client-1",
|
||||
"initialize",
|
||||
11,
|
||||
Some(serde_json::to_value(InitializeParams {
|
||||
client_info: ClientInfo {
|
||||
name: "remote_control_client".to_string(),
|
||||
title: Some("Remote Control Test Client".to_string()),
|
||||
version: "0.1.0".to_string(),
|
||||
},
|
||||
capabilities: None,
|
||||
})?),
|
||||
)
|
||||
.await?;
|
||||
let remote_initialize =
|
||||
read_remote_response_for_id(&mut backend_websocket, "remote-client-1", 11).await?;
|
||||
assert_eq!(remote_initialize["id"], json!(11));
|
||||
|
||||
send_request(
|
||||
&mut local_websocket,
|
||||
"config/read",
|
||||
77,
|
||||
Some(json!({ "includeLayers": false })),
|
||||
)
|
||||
.await?;
|
||||
send_remote_request(
|
||||
&mut backend_websocket,
|
||||
"remote-client-1",
|
||||
"config/read",
|
||||
77,
|
||||
Some(json!({ "includeLayers": false })),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let local_response = read_response_for_id(&mut local_websocket, 77).await?;
|
||||
let remote_response =
|
||||
read_remote_response_for_id(&mut backend_websocket, "remote-client-1", 77).await?;
|
||||
|
||||
assert_eq!(local_response.id, RequestId::Integer(77));
|
||||
assert!(local_response.result.get("config").is_some());
|
||||
assert_eq!(remote_response["id"], json!(77));
|
||||
assert!(remote_response["result"].get("config").is_some());
|
||||
|
||||
process
|
||||
.kill()
|
||||
.await
|
||||
.context("failed to stop websocket app-server process")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stdio_transport_with_remote_control_exits_when_stdio_closes() -> Result<()> {
|
||||
let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await;
|
||||
let remote_listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.context("listener should bind")?;
|
||||
let remote_control_base_url = format!(
|
||||
"http://{}/backend-api",
|
||||
remote_listener
|
||||
.local_addr()
|
||||
.context("listener should have local addr")?
|
||||
);
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml_with_remote_control(
|
||||
codex_home.path(),
|
||||
&server.uri(),
|
||||
&remote_control_base_url,
|
||||
"never",
|
||||
)?;
|
||||
write_chatgpt_auth(
|
||||
codex_home.path(),
|
||||
ChatGptAuthFixture::new("Access Token").account_id("account_id"),
|
||||
AuthCredentialsStoreMode::File,
|
||||
)?;
|
||||
|
||||
let mut process = spawn_stdio_server_with_remote_control(codex_home.path()).await?;
|
||||
let enroll_request = accept_http_request(&remote_listener).await?;
|
||||
assert_eq!(
|
||||
enroll_request.request_line,
|
||||
"POST /backend-api/remote/control/server/enroll HTTP/1.1"
|
||||
);
|
||||
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_stdio" })).await?;
|
||||
let (_backend_request, mut backend_websocket) =
|
||||
accept_remote_control_backend_connection(&remote_listener).await?;
|
||||
|
||||
drop(process.stdin.take());
|
||||
let exit_status = timeout(Duration::from_secs(10), process.wait())
|
||||
.await
|
||||
.context("timed out waiting for stdio app-server to exit")?
|
||||
.context("failed waiting for stdio app-server exit")?;
|
||||
assert!(exit_status.success());
|
||||
|
||||
let close_frame = timeout(Duration::from_secs(5), backend_websocket.next())
|
||||
.await
|
||||
.context("timed out waiting for remote-control websocket to close")?;
|
||||
match close_frame {
|
||||
Some(Ok(WebSocketMessage::Close(_))) | Some(Err(_)) | None => {}
|
||||
Some(Ok(other)) => {
|
||||
anyhow::bail!("unexpected websocket frame while waiting for shutdown: {other:?}")
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct CapturedHttpRequest {
|
||||
stream: TcpStream,
|
||||
request_line: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
struct CapturedWebSocketRequest {
|
||||
path: String,
|
||||
}
|
||||
|
||||
async fn accept_http_request(listener: &TcpListener) -> Result<CapturedHttpRequest> {
|
||||
let (stream, _) = timeout(Duration::from_secs(10), listener.accept())
|
||||
.await
|
||||
.context("HTTP request should arrive in time")?
|
||||
.context("listener accept should succeed")?;
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
let mut request_line = String::new();
|
||||
reader
|
||||
.read_line(&mut request_line)
|
||||
.await
|
||||
.context("request line should read")?;
|
||||
let request_line = request_line.trim_end_matches("\r\n").to_string();
|
||||
|
||||
let mut headers = std::collections::BTreeMap::new();
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
reader
|
||||
.read_line(&mut line)
|
||||
.await
|
||||
.context("header line should read")?;
|
||||
if line == "\r\n" {
|
||||
break;
|
||||
}
|
||||
let line = line.trim_end_matches("\r\n");
|
||||
let (name, value) = line
|
||||
.split_once(':')
|
||||
.context("header should contain colon")?;
|
||||
headers.insert(name.to_ascii_lowercase(), value.trim().to_string());
|
||||
}
|
||||
|
||||
let content_length = headers
|
||||
.get("content-length")
|
||||
.and_then(|value| value.parse::<usize>().ok())
|
||||
.unwrap_or(0);
|
||||
let mut body = vec![0; content_length];
|
||||
reader
|
||||
.read_exact(&mut body)
|
||||
.await
|
||||
.context("request body should read")?;
|
||||
|
||||
Ok(CapturedHttpRequest {
|
||||
stream: reader.into_inner(),
|
||||
request_line,
|
||||
})
|
||||
}
|
||||
|
||||
async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) -> Result<()> {
|
||||
let body = body.to_string();
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||
body.len()
|
||||
);
|
||||
stream
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.context("response should write")?;
|
||||
stream.flush().await.context("response should flush")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept_remote_control_backend_connection(
|
||||
listener: &TcpListener,
|
||||
) -> Result<(CapturedWebSocketRequest, BackendWebSocket)> {
|
||||
let (stream, _) = timeout(Duration::from_secs(10), listener.accept())
|
||||
.await
|
||||
.context("websocket request should arrive in time")?
|
||||
.context("listener accept should succeed")?;
|
||||
let captured_request = Arc::new(std::sync::Mutex::new(None::<CapturedWebSocketRequest>));
|
||||
let captured_request_for_callback = Arc::clone(&captured_request);
|
||||
let websocket = accept_hdr_async(
|
||||
stream,
|
||||
move |request: &tungstenite::handshake::server::Request,
|
||||
response: tungstenite::handshake::server::Response| {
|
||||
let mut guard = match captured_request_for_callback.lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(err) => panic!("capture lock should acquire: {err}"),
|
||||
};
|
||||
*guard = Some(CapturedWebSocketRequest {
|
||||
path: request.uri().path().to_string(),
|
||||
});
|
||||
Ok(response)
|
||||
},
|
||||
)
|
||||
.await
|
||||
.context("websocket handshake should succeed")?;
|
||||
let captured_request = match captured_request.lock() {
|
||||
Ok(guard) => guard.clone(),
|
||||
Err(err) => panic!("capture lock should acquire: {err}"),
|
||||
}
|
||||
.context("websocket request should be captured")?;
|
||||
Ok((captured_request, websocket))
|
||||
}
|
||||
|
||||
async fn send_remote_request(
|
||||
websocket: &mut BackendWebSocket,
|
||||
client_id: &str,
|
||||
method: &str,
|
||||
id: i64,
|
||||
params: Option<Value>,
|
||||
) -> Result<()> {
|
||||
let message = serde_json::to_value(JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(id),
|
||||
method: method.to_string(),
|
||||
params,
|
||||
trace: None,
|
||||
}))?;
|
||||
let payload = json!({
|
||||
"type": "client_message",
|
||||
"client_id": client_id,
|
||||
"message": message,
|
||||
});
|
||||
websocket
|
||||
.send(WebSocketMessage::Text(payload.to_string().into()))
|
||||
.await
|
||||
.context("client event should send")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_remote_response_for_id(
|
||||
websocket: &mut BackendWebSocket,
|
||||
client_id: &str,
|
||||
id: i64,
|
||||
) -> Result<Value> {
|
||||
loop {
|
||||
let event = read_remote_server_event(websocket).await?;
|
||||
if event["type"] == json!("server_message")
|
||||
&& event["client_id"] == json!(client_id)
|
||||
&& event["message"]["id"] == json!(id)
|
||||
{
|
||||
return Ok(event["message"].clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_remote_server_event(websocket: &mut BackendWebSocket) -> Result<Value> {
|
||||
loop {
|
||||
let frame = timeout(Duration::from_secs(5), websocket.next())
|
||||
.await
|
||||
.context("server event should arrive in time")?
|
||||
.context("websocket should stay open")?
|
||||
.context("websocket frame should be readable")?;
|
||||
match frame {
|
||||
WebSocketMessage::Text(text) => {
|
||||
return serde_json::from_str(text.as_ref())
|
||||
.context("server event should deserialize");
|
||||
}
|
||||
WebSocketMessage::Ping(payload) => {
|
||||
websocket
|
||||
.send(WebSocketMessage::Pong(payload))
|
||||
.await
|
||||
.context("websocket pong should send")?;
|
||||
}
|
||||
WebSocketMessage::Pong(_) => {}
|
||||
WebSocketMessage::Close(frame) => {
|
||||
anyhow::bail!("unexpected websocket close frame: {frame:?}");
|
||||
}
|
||||
WebSocketMessage::Binary(_) => anyhow::bail!("unexpected binary websocket frame"),
|
||||
WebSocketMessage::Frame(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_stdio_server_with_remote_control(codex_home: &Path) -> Result<Child> {
|
||||
let program = codex_utils_cargo_bin::cargo_bin("codex-app-server")
|
||||
.context("should find app-server binary")?;
|
||||
let mut command = Command::new(program);
|
||||
command
|
||||
.arg("--with-remote-control")
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::piped())
|
||||
.env("CODEX_HOME", codex_home)
|
||||
.env("RUST_LOG", "debug");
|
||||
let mut process = command
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
.context("failed to spawn stdio app-server process")?;
|
||||
if let Some(stderr) = process.stderr.take() {
|
||||
let mut stderr_reader = BufReader::new(stderr).lines();
|
||||
tokio::spawn(async move {
|
||||
while let Ok(Some(line)) = stderr_reader.next_line().await {
|
||||
eprintln!("[stdio app-server stderr] {line}");
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(process)
|
||||
}
|
||||
|
||||
fn create_config_toml_with_remote_control(
|
||||
codex_home: &Path,
|
||||
server_uri: &str,
|
||||
remote_control_base_url: &str,
|
||||
approval_policy: &str,
|
||||
) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "{approval_policy}"
|
||||
sandbox_mode = "read-only"
|
||||
chatgpt_base_url = "{remote_control_base_url}"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "responses"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -108,12 +108,21 @@ async fn websocket_transport_serves_health_endpoints_on_same_listener() -> Resul
|
||||
}
|
||||
|
||||
pub(super) async fn spawn_websocket_server(codex_home: &Path) -> Result<(Child, SocketAddr)> {
|
||||
spawn_websocket_server_with_args(codex_home, &[]).await
|
||||
}
|
||||
|
||||
pub(super) async fn spawn_websocket_server_with_args(
|
||||
codex_home: &Path,
|
||||
extra_args: &[&str],
|
||||
) -> Result<(Child, SocketAddr)> {
|
||||
let program = codex_utils_cargo_bin::cargo_bin("codex-app-server")
|
||||
.context("should find app-server binary")?;
|
||||
let mut cmd = Command::new(program);
|
||||
cmd.arg("--listen")
|
||||
.arg("ws://127.0.0.1:0")
|
||||
.stdin(Stdio::null())
|
||||
cmd.arg("--listen").arg("ws://127.0.0.1:0");
|
||||
for extra_arg in extra_args {
|
||||
cmd.arg(extra_arg);
|
||||
}
|
||||
cmd.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::piped())
|
||||
.env("CODEX_HOME", codex_home)
|
||||
@@ -158,10 +167,11 @@ pub(super) async fn spawn_websocket_server(codex_home: &Path) -> Result<(Child,
|
||||
stripped
|
||||
};
|
||||
|
||||
if let Some(bind_addr) = stripped_line
|
||||
.split_whitespace()
|
||||
.find_map(|token| token.strip_prefix("ws://"))
|
||||
.and_then(|addr| addr.parse::<SocketAddr>().ok())
|
||||
if stripped_line.contains("listening on:")
|
||||
&& let Some(bind_addr) = stripped_line
|
||||
.split_whitespace()
|
||||
.find_map(|token| token.strip_prefix("ws://"))
|
||||
.and_then(|addr| addr.parse::<SocketAddr>().ok())
|
||||
{
|
||||
break bind_addr;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ mod collaboration_mode_list;
|
||||
mod command_exec;
|
||||
mod compaction;
|
||||
mod config_rpc;
|
||||
mod connection_handling_mixed_remote_control;
|
||||
mod connection_handling_websocket;
|
||||
#[cfg(unix)]
|
||||
mod connection_handling_websocket_unix;
|
||||
|
||||
@@ -8,6 +8,10 @@ license.workspace = true
|
||||
name = "codex"
|
||||
path = "src/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "codexd"
|
||||
path = "src/bin/codexd.rs"
|
||||
|
||||
[lib]
|
||||
name = "codex_cli"
|
||||
path = "src/lib.rs"
|
||||
|
||||
42
codex-rs/cli/src/bin/codexd.rs
Normal file
42
codex-rs/cli/src/bin/codexd.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use clap::Parser;
|
||||
use codex_app_server::run_main_with_runtime;
|
||||
use codex_arg0::Arg0DispatchPaths;
|
||||
use codex_arg0::arg0_dispatch_or_else;
|
||||
use codex_core::config_loader::LoaderOverrides;
|
||||
use codex_utils_cli::CliConfigOverrides;
|
||||
use std::io::ErrorKind;
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
#[clap(
|
||||
author,
|
||||
version,
|
||||
bin_name = "codexd",
|
||||
override_usage = "codexd [OPTIONS]"
|
||||
)]
|
||||
struct CodexdCli {
|
||||
#[clap(flatten)]
|
||||
config_overrides: CliConfigOverrides,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|arg0_paths: Arg0DispatchPaths| async move {
|
||||
let cli = CodexdCli::parse();
|
||||
cli.config_overrides.parse_overrides().map_err(|err| {
|
||||
std::io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("error parsing -c overrides: {err}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
run_main_with_runtime(
|
||||
arg0_paths,
|
||||
cli.config_overrides,
|
||||
LoaderOverrides::default(),
|
||||
true,
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
@@ -331,6 +331,11 @@ struct AppServerCommand {
|
||||
)]
|
||||
listen: codex_app_server::AppServerTransport,
|
||||
|
||||
/// Also connect outbound to the ChatGPT remote control server derived from
|
||||
/// the configured `chatgpt_base_url`.
|
||||
#[arg(long = "with-remote-control", default_value_t = false)]
|
||||
with_remote_control: bool,
|
||||
|
||||
/// Controls whether analytics are enabled by default.
|
||||
///
|
||||
/// Analytics are disabled by default for app-server. Users have to explicitly opt in
|
||||
@@ -630,13 +635,13 @@ async fn cli_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> {
|
||||
Some(Subcommand::AppServer(app_server_cli)) => match app_server_cli.subcommand {
|
||||
None => {
|
||||
reject_remote_mode_for_subcommand(root_remote.as_deref(), "app-server")?;
|
||||
let transport = app_server_cli.listen;
|
||||
codex_app_server::run_main_with_transport(
|
||||
codex_app_server::run_main_with_runtime(
|
||||
arg0_paths.clone(),
|
||||
root_config_overrides,
|
||||
codex_core::config_loader::LoaderOverrides::default(),
|
||||
app_server_cli.analytics_default_enabled,
|
||||
transport,
|
||||
Some(app_server_cli.listen),
|
||||
app_server_cli.with_remote_control,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
@@ -1600,6 +1605,7 @@ mod tests {
|
||||
app_server.listen,
|
||||
codex_app_server::AppServerTransport::Stdio
|
||||
);
|
||||
assert!(!app_server.with_remote_control);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user