mirror of
https://github.com/openai/codex.git
synced 2026-03-26 16:43:58 +00:00
Compare commits
4 Commits
main
...
codex/webs
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82be4e88b0 | ||
|
|
685d77d2c1 | ||
|
|
8d62dd3257 | ||
|
|
8c644a154b |
3
codex-rs/Cargo.lock
generated
3
codex-rs/Cargo.lock
generated
@@ -1473,9 +1473,11 @@ dependencies = [
|
||||
"codex-utils-cli",
|
||||
"codex-utils-json-to-toml",
|
||||
"codex-utils-pty",
|
||||
"codex-utils-rustls-provider",
|
||||
"constant_time_eq",
|
||||
"core_test_support",
|
||||
"futures",
|
||||
"gethostname",
|
||||
"hmac",
|
||||
"jsonwebtoken",
|
||||
"opentelemetry",
|
||||
@@ -1498,6 +1500,7 @@ dependencies = [
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
"uuid",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
@@ -51,10 +51,12 @@ codex-sandboxing = { workspace = true }
|
||||
codex-state = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
codex-utils-json-to-toml = { workspace = true }
|
||||
codex-utils-rustls-provider = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
constant_time_eq = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
gethostname = { workspace = true }
|
||||
hmac = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
owo-colors = { workspace = true, features = ["supports-colors"] }
|
||||
@@ -75,6 +77,7 @@ tokio-util = { workspace = true }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] }
|
||||
url = { workspace = true }
|
||||
uuid = { workspace = true, features = ["serde", "v7"] }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -2,6 +2,17 @@
|
||||
|
||||
`codex app-server` is the interface Codex uses to power rich interfaces such as the [Codex VS Code extension](https://marketplace.visualstudio.com/items?itemName=openai.chatgpt).
|
||||
|
||||
For remote-control-only deployments, use `codexd`. It runs the same app-server runtime in a headless daemon mode, connects outbound to the ChatGPT remote control server using ChatGPT auth, and does not expose a local stdio or websocket transport.
|
||||
|
||||
Remote control is configured in `~/.codex/config.toml`:
|
||||
|
||||
```toml
|
||||
chatgpt_base_url = "https://chatgpt.com/backend-api/"
|
||||
|
||||
[features]
|
||||
remote_control = true
|
||||
```
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Protocol](#protocol)
|
||||
|
||||
@@ -86,6 +86,7 @@ fn transport_name(transport: AppServerTransport) -> &'static str {
|
||||
match transport {
|
||||
AppServerTransport::Stdio => "stdio",
|
||||
AppServerTransport::WebSocket { .. } => "websocket",
|
||||
AppServerTransport::Headless => "headless",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ use codex_core::config::ConfigBuilder;
|
||||
use codex_core::config_loader::CloudRequirementsLoader;
|
||||
use codex_core::config_loader::ConfigLayerStackOrdering;
|
||||
use codex_core::config_loader::LoaderOverrides;
|
||||
use codex_features::Feature;
|
||||
use codex_utils_cli::CliConfigOverrides;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
@@ -29,8 +30,10 @@ use crate::transport::OutboundConnectionState;
|
||||
use crate::transport::TransportEvent;
|
||||
use crate::transport::auth::policy_from_settings;
|
||||
use crate::transport::route_outgoing_envelope;
|
||||
use crate::transport::start_remote_control;
|
||||
use crate::transport::start_stdio_connection;
|
||||
use crate::transport::start_websocket_acceptor;
|
||||
use crate::transport::validate_remote_control_auth;
|
||||
use codex_app_server_protocol::ConfigLayerSource;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
@@ -94,6 +97,37 @@ enum LogFormat {
|
||||
Json,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
struct TransportRuntimeMode {
|
||||
single_client_mode: bool,
|
||||
shutdown_when_no_connections: bool,
|
||||
graceful_ctrl_c_restart_enabled: bool,
|
||||
ctrl_c_shutdown_enabled: bool,
|
||||
}
|
||||
|
||||
fn transport_runtime_mode(transport: AppServerTransport) -> TransportRuntimeMode {
|
||||
match transport {
|
||||
AppServerTransport::Stdio => TransportRuntimeMode {
|
||||
single_client_mode: true,
|
||||
shutdown_when_no_connections: true,
|
||||
graceful_ctrl_c_restart_enabled: false,
|
||||
ctrl_c_shutdown_enabled: false,
|
||||
},
|
||||
AppServerTransport::WebSocket { .. } => TransportRuntimeMode {
|
||||
single_client_mode: false,
|
||||
shutdown_when_no_connections: false,
|
||||
graceful_ctrl_c_restart_enabled: true,
|
||||
ctrl_c_shutdown_enabled: false,
|
||||
},
|
||||
AppServerTransport::Headless => TransportRuntimeMode {
|
||||
single_client_mode: false,
|
||||
shutdown_when_no_connections: false,
|
||||
graceful_ctrl_c_restart_enabled: false,
|
||||
ctrl_c_shutdown_enabled: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type StderrLogLayer = Box<dyn Layer<Registry> + Send + Sync + 'static>;
|
||||
|
||||
/// Control-plane messages from the processor/transport side to the outbound router task.
|
||||
@@ -361,38 +395,6 @@ pub async fn run_main_with_transport(
|
||||
let (outbound_control_tx, mut outbound_control_rx) =
|
||||
mpsc::channel::<OutboundControlEvent>(CHANNEL_CAPACITY);
|
||||
|
||||
enum TransportRuntime {
|
||||
Stdio,
|
||||
WebSocket {
|
||||
accept_handle: JoinHandle<()>,
|
||||
shutdown_token: CancellationToken,
|
||||
},
|
||||
}
|
||||
|
||||
let mut stdio_handles = Vec::<JoinHandle<()>>::new();
|
||||
let transport_runtime = match transport {
|
||||
AppServerTransport::Stdio => {
|
||||
start_stdio_connection(transport_event_tx.clone(), &mut stdio_handles).await?;
|
||||
TransportRuntime::Stdio
|
||||
}
|
||||
AppServerTransport::WebSocket { bind_address } => {
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let accept_handle = start_websocket_acceptor(
|
||||
bind_address,
|
||||
transport_event_tx.clone(),
|
||||
shutdown_token.clone(),
|
||||
policy_from_settings(&auth)?,
|
||||
)
|
||||
.await?;
|
||||
TransportRuntime::WebSocket {
|
||||
accept_handle,
|
||||
shutdown_token,
|
||||
}
|
||||
}
|
||||
};
|
||||
let single_client_mode = matches!(&transport_runtime, TransportRuntime::Stdio);
|
||||
let shutdown_when_no_connections = single_client_mode;
|
||||
let graceful_signal_restart_enabled = !single_client_mode;
|
||||
// Parse CLI overrides once and derive the base Config eagerly so later
|
||||
// components do not need to work with raw TOML values.
|
||||
let cli_kv_overrides = cli_config_overrides.parse_overrides().map_err(|e| {
|
||||
@@ -529,13 +531,13 @@ pub async fn run_main_with_transport(
|
||||
|
||||
let feedback_layer = feedback.logger_layer();
|
||||
let feedback_metadata_layer = feedback.metadata_layer();
|
||||
let log_db = codex_state::StateRuntime::init(
|
||||
let state_db = codex_state::StateRuntime::init(
|
||||
config.sqlite_home.clone(),
|
||||
config.model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.ok()
|
||||
.map(log_db::start);
|
||||
.ok();
|
||||
let log_db = state_db.clone().map(log_db::start);
|
||||
let log_db_layer = log_db
|
||||
.clone()
|
||||
.map(|layer| layer.with_filter(Targets::new().with_default(Level::TRACE)));
|
||||
@@ -556,6 +558,57 @@ pub async fn run_main_with_transport(
|
||||
}
|
||||
}
|
||||
|
||||
let transport_shutdown_token = CancellationToken::new();
|
||||
let mut transport_accept_handles = Vec::<JoinHandle<()>>::new();
|
||||
let runtime_mode = transport_runtime_mode(transport);
|
||||
|
||||
match transport {
|
||||
AppServerTransport::Stdio => {
|
||||
start_stdio_connection(transport_event_tx.clone(), &mut transport_accept_handles)
|
||||
.await?;
|
||||
}
|
||||
AppServerTransport::WebSocket { bind_address } => {
|
||||
let accept_handle = start_websocket_acceptor(
|
||||
bind_address,
|
||||
transport_event_tx.clone(),
|
||||
transport_shutdown_token.clone(),
|
||||
policy_from_settings(&auth)?,
|
||||
)
|
||||
.await?;
|
||||
transport_accept_handles.push(accept_handle);
|
||||
}
|
||||
AppServerTransport::Headless => {}
|
||||
}
|
||||
let shutdown_when_no_connections = runtime_mode.shutdown_when_no_connections;
|
||||
let graceful_ctrl_c_restart_enabled = runtime_mode.graceful_ctrl_c_restart_enabled;
|
||||
let graceful_signal_restart_enabled = runtime_mode.graceful_ctrl_c_restart_enabled;
|
||||
|
||||
let auth_manager = AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
/*enable_codex_api_key_env*/ false,
|
||||
config.cli_auth_credentials_store_mode,
|
||||
);
|
||||
auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone());
|
||||
|
||||
if config.features.enabled(Feature::RemoteControl) {
|
||||
validate_remote_control_auth(auth_manager.as_ref()).await?;
|
||||
let accept_handle = start_remote_control(
|
||||
config.chatgpt_base_url.clone(),
|
||||
state_db.clone(),
|
||||
auth_manager.clone(),
|
||||
transport_event_tx.clone(),
|
||||
transport_shutdown_token.clone(),
|
||||
)
|
||||
.await?;
|
||||
transport_accept_handles.push(accept_handle);
|
||||
}
|
||||
if transport_accept_handles.is_empty() {
|
||||
return Err(std::io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
"no transport configured; use --listen or enable remote control",
|
||||
));
|
||||
}
|
||||
|
||||
let outbound_handle = tokio::spawn(async move {
|
||||
let mut outbound_connections = HashMap::<ConnectionId, OutboundConnectionState>::new();
|
||||
loop {
|
||||
@@ -632,10 +685,7 @@ pub async fn run_main_with_transport(
|
||||
let mut thread_created_rx = processor.thread_created_receiver();
|
||||
let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count();
|
||||
let mut connections = HashMap::<ConnectionId, ConnectionState>::new();
|
||||
let websocket_accept_shutdown = match &transport_runtime {
|
||||
TransportRuntime::WebSocket { shutdown_token, .. } => Some(shutdown_token.clone()),
|
||||
TransportRuntime::Stdio => None,
|
||||
};
|
||||
let transport_shutdown_token = transport_shutdown_token.clone();
|
||||
async move {
|
||||
let mut listen_for_threads = true;
|
||||
let mut shutdown_state = ShutdownState::default();
|
||||
@@ -648,9 +698,7 @@ pub async fn run_main_with_transport(
|
||||
shutdown_state.update(running_turn_count, connections.len()),
|
||||
ShutdownAction::Finish
|
||||
) {
|
||||
if let Some(shutdown_token) = &websocket_accept_shutdown {
|
||||
shutdown_token.cancel();
|
||||
}
|
||||
transport_shutdown_token.cancel();
|
||||
let _ = outbound_control_tx
|
||||
.send(OutboundControlEvent::DisconnectAll)
|
||||
.await;
|
||||
@@ -665,6 +713,24 @@ pub async fn run_main_with_transport(
|
||||
let running_turn_count = *running_turn_count_rx.borrow();
|
||||
shutdown_state.on_signal(connections.len(), running_turn_count);
|
||||
}
|
||||
ctrl_c_result = tokio::signal::ctrl_c(), if runtime_mode.ctrl_c_shutdown_enabled => {
|
||||
if let Err(err) = ctrl_c_result {
|
||||
warn!("failed to listen for Ctrl-C during daemon shutdown: {err}");
|
||||
}
|
||||
info!("received Ctrl-C; shutting down codexd remote-control daemon");
|
||||
transport_shutdown_token.cancel();
|
||||
let _ = outbound_control_tx
|
||||
.send(OutboundControlEvent::DisconnectAll)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
ctrl_c_result = tokio::signal::ctrl_c(), if graceful_ctrl_c_restart_enabled && !shutdown_state.forced() => {
|
||||
if let Err(err) = ctrl_c_result {
|
||||
warn!("failed to listen for Ctrl-C during graceful restart drain: {err}");
|
||||
}
|
||||
let running_turn_count = *running_turn_count_rx.borrow();
|
||||
shutdown_state.on_signal(connections.len(), running_turn_count);
|
||||
}
|
||||
changed = running_turn_count_rx.changed(), if graceful_signal_restart_enabled && shutdown_state.requested() => {
|
||||
if changed.is_err() {
|
||||
warn!("running-turn watcher closed during graceful restart drain");
|
||||
@@ -844,16 +910,8 @@ pub async fn run_main_with_transport(
|
||||
let _ = processor_handle.await;
|
||||
let _ = outbound_handle.await;
|
||||
|
||||
if let TransportRuntime::WebSocket {
|
||||
accept_handle,
|
||||
shutdown_token,
|
||||
} = transport_runtime
|
||||
{
|
||||
shutdown_token.cancel();
|
||||
let _ = accept_handle.await;
|
||||
}
|
||||
|
||||
for handle in stdio_handles {
|
||||
transport_shutdown_token.cancel();
|
||||
for handle in transport_accept_handles {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
||||
@@ -867,6 +925,9 @@ pub async fn run_main_with_transport(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::LogFormat;
|
||||
use super::TransportRuntimeMode;
|
||||
use super::transport_runtime_mode;
|
||||
use crate::AppServerTransport;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
@@ -883,4 +944,17 @@ mod tests {
|
||||
assert_eq!(LogFormat::from_env_value(Some("text")), LogFormat::Default);
|
||||
assert_eq!(LogFormat::from_env_value(Some("jsonl")), LogFormat::Default);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn headless_transport_runtime_mode_uses_daemon_shutdown_behavior() {
|
||||
assert_eq!(
|
||||
transport_runtime_mode(AppServerTransport::Headless),
|
||||
TransportRuntimeMode {
|
||||
single_client_mode: false,
|
||||
shutdown_when_no_connections: false,
|
||||
graceful_ctrl_c_restart_enabled: false,
|
||||
ctrl_c_shutdown_enabled: true,
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
385
codex-rs/app-server/src/transport/mod.rs
Normal file
385
codex-rs/app-server/src/transport/mod.rs
Normal file
@@ -0,0 +1,385 @@
|
||||
use crate::error_code::OVERLOADED_ERROR_CODE;
|
||||
use crate::message_processor::ConnectionSessionState;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingError;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::ServerRequest;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::net::SocketAddr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::error;
|
||||
use tracing::warn;
|
||||
|
||||
/// Size of the bounded channels used to communicate between tasks. The value
|
||||
/// is a balance between throughput and memory usage - 128 messages should be
|
||||
/// plenty for an interactive CLI.
|
||||
pub(crate) const CHANNEL_CAPACITY: usize = 128;
|
||||
|
||||
pub(crate) mod auth;
|
||||
mod remote_control;
|
||||
mod stdio;
|
||||
mod websocket;
|
||||
|
||||
pub(crate) use remote_control::start_remote_control;
|
||||
pub(crate) use remote_control::validate_remote_control_auth;
|
||||
pub(crate) use stdio::start_stdio_connection;
|
||||
pub(crate) use websocket::start_websocket_acceptor;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub enum AppServerTransport {
|
||||
Stdio,
|
||||
WebSocket { bind_address: SocketAddr },
|
||||
Headless,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub enum AppServerTransportParseError {
|
||||
UnsupportedListenUrl(String),
|
||||
InvalidWebSocketListenUrl(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AppServerTransportParseError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!(
|
||||
f,
|
||||
"unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`"
|
||||
),
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!(
|
||||
f,
|
||||
"invalid websocket --listen URL `{listen_url}`; expected `ws://IP:PORT`"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for AppServerTransportParseError {}
|
||||
|
||||
impl AppServerTransport {
|
||||
pub const DEFAULT_LISTEN_URL: &'static str = "stdio://";
|
||||
|
||||
pub fn from_listen_url(listen_url: &str) -> Result<Self, AppServerTransportParseError> {
|
||||
if listen_url == Self::DEFAULT_LISTEN_URL {
|
||||
return Ok(Self::Stdio);
|
||||
}
|
||||
|
||||
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
|
||||
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
|
||||
})?;
|
||||
return Ok(Self::WebSocket { bind_address });
|
||||
}
|
||||
|
||||
Err(AppServerTransportParseError::UnsupportedListenUrl(
|
||||
listen_url.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for AppServerTransport {
|
||||
type Err = AppServerTransportParseError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Self::from_listen_url(s)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum TransportEvent {
|
||||
ConnectionOpened {
|
||||
connection_id: ConnectionId,
|
||||
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
},
|
||||
ConnectionClosed {
|
||||
connection_id: ConnectionId,
|
||||
},
|
||||
IncomingMessage {
|
||||
connection_id: ConnectionId,
|
||||
message: JSONRPCMessage,
|
||||
},
|
||||
}
|
||||
|
||||
pub(crate) struct ConnectionState {
|
||||
pub(crate) outbound_initialized: Arc<AtomicBool>,
|
||||
pub(crate) outbound_experimental_api_enabled: Arc<AtomicBool>,
|
||||
pub(crate) outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
pub(crate) session: ConnectionSessionState,
|
||||
}
|
||||
|
||||
impl ConnectionState {
|
||||
pub(crate) fn new(
|
||||
outbound_initialized: Arc<AtomicBool>,
|
||||
outbound_experimental_api_enabled: Arc<AtomicBool>,
|
||||
outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
outbound_initialized,
|
||||
outbound_experimental_api_enabled,
|
||||
outbound_opted_out_notification_methods,
|
||||
session: ConnectionSessionState::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct OutboundConnectionState {
|
||||
pub(crate) initialized: Arc<AtomicBool>,
|
||||
pub(crate) experimental_api_enabled: Arc<AtomicBool>,
|
||||
pub(crate) opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
pub(crate) writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
}
|
||||
|
||||
impl OutboundConnectionState {
|
||||
pub(crate) fn new(
|
||||
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
initialized: Arc<AtomicBool>,
|
||||
experimental_api_enabled: Arc<AtomicBool>,
|
||||
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
) -> Self {
|
||||
Self {
|
||||
initialized,
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
writer,
|
||||
disconnect_sender,
|
||||
}
|
||||
}
|
||||
|
||||
fn can_disconnect(&self) -> bool {
|
||||
self.disconnect_sender.is_some()
|
||||
}
|
||||
|
||||
pub(crate) fn request_disconnect(&self) {
|
||||
if let Some(disconnect_sender) = &self.disconnect_sender {
|
||||
disconnect_sender.cancel();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(1);
|
||||
|
||||
fn next_connection_id() -> ConnectionId {
|
||||
ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed))
|
||||
}
|
||||
async fn forward_incoming_message(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
writer: &mpsc::Sender<QueuedOutgoingMessage>,
|
||||
connection_id: ConnectionId,
|
||||
payload: &str,
|
||||
) -> bool {
|
||||
match serde_json::from_str::<JSONRPCMessage>(payload) {
|
||||
Ok(message) => {
|
||||
enqueue_incoming_message(transport_event_tx, writer, connection_id, message).await
|
||||
}
|
||||
Err(err) => {
|
||||
error!("Failed to deserialize JSONRPCMessage: {err}");
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn enqueue_incoming_message(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
writer: &mpsc::Sender<QueuedOutgoingMessage>,
|
||||
connection_id: ConnectionId,
|
||||
message: JSONRPCMessage,
|
||||
) -> bool {
|
||||
let event = TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
};
|
||||
match transport_event_tx.try_send(event) {
|
||||
Ok(()) => true,
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => false,
|
||||
Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message: JSONRPCMessage::Request(request),
|
||||
})) => {
|
||||
let overload_error = OutgoingMessage::Error(OutgoingError {
|
||||
id: request.id,
|
||||
error: JSONRPCErrorError {
|
||||
code: OVERLOADED_ERROR_CODE,
|
||||
message: "Server overloaded; retry later.".to_string(),
|
||||
data: None,
|
||||
},
|
||||
});
|
||||
match writer.try_send(QueuedOutgoingMessage::new(overload_error)) {
|
||||
Ok(()) => true,
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => false,
|
||||
Err(mpsc::error::TrySendError::Full(_overload_error)) => {
|
||||
warn!(
|
||||
"dropping overload response for connection {:?}: outbound queue is full",
|
||||
connection_id
|
||||
);
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(event)) => transport_event_tx.send(event).await.is_ok(),
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option<String> {
|
||||
let value = match serde_json::to_value(outgoing_message) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
error!("Failed to convert OutgoingMessage to JSON value: {err}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
match serde_json::to_string(&value) {
|
||||
Ok(json) => Some(json),
|
||||
Err(err) => {
|
||||
error!("Failed to serialize JSONRPCMessage: {err}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn should_skip_notification_for_connection(
|
||||
connection_state: &OutboundConnectionState,
|
||||
message: &OutgoingMessage,
|
||||
) -> bool {
|
||||
let Ok(opted_out_notification_methods) = connection_state.opted_out_notification_methods.read()
|
||||
else {
|
||||
warn!("failed to read outbound opted-out notifications");
|
||||
return false;
|
||||
};
|
||||
match message {
|
||||
OutgoingMessage::AppServerNotification(notification) => {
|
||||
let method = notification.to_string();
|
||||
opted_out_notification_methods.contains(method.as_str())
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn disconnect_connection(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
connection_id: ConnectionId,
|
||||
) -> bool {
|
||||
if let Some(connection_state) = connections.remove(&connection_id) {
|
||||
connection_state.request_disconnect();
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
async fn send_message_to_connection(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
connection_id: ConnectionId,
|
||||
message: OutgoingMessage,
|
||||
write_complete_tx: Option<tokio::sync::oneshot::Sender<()>>,
|
||||
) -> bool {
|
||||
let Some(connection_state) = connections.get(&connection_id) else {
|
||||
warn!("dropping message for disconnected connection: {connection_id:?}");
|
||||
return false;
|
||||
};
|
||||
let message = filter_outgoing_message_for_connection(connection_state, message);
|
||||
if should_skip_notification_for_connection(connection_state, &message) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let writer = connection_state.writer.clone();
|
||||
let queued_message = QueuedOutgoingMessage {
|
||||
message,
|
||||
write_complete_tx,
|
||||
};
|
||||
if connection_state.can_disconnect() {
|
||||
match writer.try_send(queued_message) {
|
||||
Ok(()) => false,
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
warn!(
|
||||
"disconnecting slow connection after outbound queue filled: {connection_id:?}"
|
||||
);
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
}
|
||||
} else if writer.send(queued_message).await.is_err() {
|
||||
disconnect_connection(connections, connection_id)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_outgoing_message_for_connection(
|
||||
connection_state: &OutboundConnectionState,
|
||||
message: OutgoingMessage,
|
||||
) -> OutgoingMessage {
|
||||
let experimental_api_enabled = connection_state
|
||||
.experimental_api_enabled
|
||||
.load(Ordering::Acquire);
|
||||
match message {
|
||||
OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||
request_id,
|
||||
mut params,
|
||||
}) => {
|
||||
if !experimental_api_enabled {
|
||||
params.strip_experimental_fields();
|
||||
}
|
||||
OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||
request_id,
|
||||
params,
|
||||
})
|
||||
}
|
||||
_ => message,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn route_outgoing_envelope(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
envelope: OutgoingEnvelope,
|
||||
) {
|
||||
match envelope {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
write_complete_tx,
|
||||
} => {
|
||||
let _ =
|
||||
send_message_to_connection(connections, connection_id, message, write_complete_tx)
|
||||
.await;
|
||||
}
|
||||
OutgoingEnvelope::Broadcast { message } => {
|
||||
let target_connections: Vec<ConnectionId> = connections
|
||||
.iter()
|
||||
.filter_map(|(connection_id, connection_state)| {
|
||||
if connection_state.initialized.load(Ordering::Acquire)
|
||||
&& !should_skip_notification_for_connection(connection_state, &message)
|
||||
{
|
||||
Some(*connection_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
for connection_id in target_connections {
|
||||
let _ = send_message_to_connection(
|
||||
connections,
|
||||
connection_id,
|
||||
message.clone(),
|
||||
/*write_complete_tx*/ None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
370
codex-rs/app-server/src/transport/remote_control/enroll.rs
Normal file
370
codex-rs/app-server/src/transport/remote_control/enroll.rs
Normal file
@@ -0,0 +1,370 @@
|
||||
use super::protocol::EnrollRemoteServerRequest;
|
||||
use super::protocol::EnrollRemoteServerResponse;
|
||||
use super::protocol::RemoteControlTarget;
|
||||
use axum::http::HeaderMap;
|
||||
use base64::Engine;
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::default_client::build_reqwest_client;
|
||||
use codex_state::StateRuntime;
|
||||
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
|
||||
use gethostname::gethostname;
|
||||
use io::ErrorKind;
|
||||
use std::io;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tokio_tungstenite::tungstenite;
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tokio_tungstenite::tungstenite::http::HeaderValue;
|
||||
use tracing::warn;
|
||||
|
||||
const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
|
||||
const REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES: usize = 4096;
|
||||
pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2";
|
||||
pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
|
||||
const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(super) struct RemoteControlEnrollment {
|
||||
pub(super) server_id: String,
|
||||
pub(super) server_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(super) struct RemoteControlConnectionAuth {
|
||||
pub(super) bearer_token: String,
|
||||
pub(super) account_id: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) struct RemoteControlWebsocketConnection {
|
||||
pub(super) websocket_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
}
|
||||
|
||||
pub(super) async fn load_persisted_remote_control_enrollment(
|
||||
state_db: Option<&StateRuntime>,
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
account_id: Option<&str>,
|
||||
) -> Option<RemoteControlEnrollment> {
|
||||
let state_db = state_db?;
|
||||
let enrollment = match state_db
|
||||
.get_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
|
||||
.await
|
||||
{
|
||||
Ok(enrollment) => enrollment,
|
||||
Err(err) => {
|
||||
warn!("{err}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
enrollment.map(|(server_id, server_name)| RemoteControlEnrollment {
|
||||
server_id,
|
||||
server_name,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) async fn update_persisted_remote_control_enrollment(
|
||||
state_db: Option<&StateRuntime>,
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
account_id: Option<&str>,
|
||||
enrollment: Option<&RemoteControlEnrollment>,
|
||||
) -> io::Result<()> {
|
||||
let Some(state_db) = state_db else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if let Some(enrollment) = enrollment {
|
||||
state_db
|
||||
.upsert_remote_control_enrollment(
|
||||
&remote_control_target.websocket_url,
|
||||
account_id,
|
||||
&enrollment.server_id,
|
||||
&enrollment.server_name,
|
||||
)
|
||||
.await
|
||||
.map_err(io::Error::other)
|
||||
} else {
|
||||
state_db
|
||||
.delete_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(io::Error::other)
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn load_remote_control_auth(
|
||||
auth_manager: &AuthManager,
|
||||
) -> io::Result<RemoteControlConnectionAuth> {
|
||||
let auth = match auth_manager.auth().await {
|
||||
Some(auth) => auth,
|
||||
None => {
|
||||
auth_manager.reload();
|
||||
auth_manager.auth().await.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
ErrorKind::PermissionDenied,
|
||||
"remote control requires ChatGPT authentication",
|
||||
)
|
||||
})?
|
||||
}
|
||||
};
|
||||
|
||||
if !auth.is_chatgpt_auth() {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::PermissionDenied,
|
||||
"remote control requires ChatGPT authentication; API key auth is not supported",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(RemoteControlConnectionAuth {
|
||||
bearer_token: auth.get_token().map_err(io::Error::other)?,
|
||||
account_id: auth.get_account_id(),
|
||||
})
|
||||
}
|
||||
|
||||
fn preview_remote_control_response_body(body: &[u8]) -> String {
|
||||
let body = String::from_utf8_lossy(body);
|
||||
let trimmed = body.trim();
|
||||
if trimmed.is_empty() {
|
||||
return "<empty>".to_string();
|
||||
}
|
||||
if trimmed.len() <= REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
|
||||
let mut cut = REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES;
|
||||
while !trimmed.is_char_boundary(cut) {
|
||||
cut = cut.saturating_sub(1);
|
||||
}
|
||||
let mut truncated = trimmed[..cut].to_string();
|
||||
truncated.push_str("...");
|
||||
truncated
|
||||
}
|
||||
|
||||
fn format_headers(headers: &HeaderMap) -> String {
|
||||
let mut headers = headers
|
||||
.iter()
|
||||
.map(|(name, value)| {
|
||||
format!(
|
||||
"{}: {}",
|
||||
name.as_str(),
|
||||
value.to_str().unwrap_or("<invalid utf-8>")
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
headers.sort();
|
||||
format!("{{{}}}", headers.join(", "))
|
||||
}
|
||||
|
||||
fn format_remote_control_websocket_connect_error(
|
||||
websocket_url: &str,
|
||||
err: &tungstenite::Error,
|
||||
) -> String {
|
||||
let mut message =
|
||||
format!("failed to connect app-server remote control websocket `{websocket_url}`: {err}");
|
||||
let tungstenite::Error::Http(response) = err else {
|
||||
return message;
|
||||
};
|
||||
|
||||
message.push_str(&format!(
|
||||
", headers: {}",
|
||||
format_headers(response.headers())
|
||||
));
|
||||
if let Some(body) = response.body().as_ref()
|
||||
&& !body.is_empty()
|
||||
{
|
||||
let body_preview = preview_remote_control_response_body(body);
|
||||
message.push_str(&format!(", body: {body_preview}"));
|
||||
}
|
||||
|
||||
message
|
||||
}
|
||||
|
||||
pub(super) async fn enroll_remote_control_server(
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
auth: &RemoteControlConnectionAuth,
|
||||
) -> io::Result<RemoteControlEnrollment> {
|
||||
let enroll_url = &remote_control_target.enroll_url;
|
||||
let server_name = gethostname().to_string_lossy().trim().to_string();
|
||||
let request = EnrollRemoteServerRequest {
|
||||
name: server_name.clone(),
|
||||
os: std::env::consts::OS,
|
||||
arch: std::env::consts::ARCH,
|
||||
app_server_version: env!("CARGO_PKG_VERSION"),
|
||||
};
|
||||
let client = build_reqwest_client();
|
||||
let mut http_request = client
|
||||
.post(enroll_url)
|
||||
.timeout(REMOTE_CONTROL_ENROLL_TIMEOUT)
|
||||
.bearer_auth(&auth.bearer_token)
|
||||
.json(&request);
|
||||
if let Some(account_id) = auth.account_id.as_deref() {
|
||||
http_request = http_request.header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id);
|
||||
}
|
||||
|
||||
let response = http_request.send().await.map_err(|err| {
|
||||
io::Error::other(format!(
|
||||
"failed to enroll remote control server at `{enroll_url}`: {err}"
|
||||
))
|
||||
})?;
|
||||
let headers = response.headers().clone();
|
||||
let status = response.status();
|
||||
let body = response.bytes().await.map_err(|err| {
|
||||
io::Error::other(format!(
|
||||
"failed to read remote control enrollment response from `{enroll_url}`: {err}"
|
||||
))
|
||||
})?;
|
||||
let body_preview = preview_remote_control_response_body(&body);
|
||||
if !status.is_success() {
|
||||
let headers_str = format_headers(&headers);
|
||||
return Err(io::Error::other(format!(
|
||||
"remote control server enrollment failed at `{enroll_url}`: HTTP {status}, headers: {headers_str}, body: {body_preview}"
|
||||
)));
|
||||
}
|
||||
|
||||
let enrollment = serde_json::from_slice::<EnrollRemoteServerResponse>(&body).map_err(|err| {
|
||||
let headers_str = format_headers(&headers);
|
||||
io::Error::other(format!(
|
||||
"failed to parse remote control enrollment response from `{enroll_url}`: HTTP {status}, headers: {headers_str}, body: {body_preview}, decode error: {err}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(RemoteControlEnrollment {
|
||||
server_id: enrollment.server_id,
|
||||
server_name,
|
||||
})
|
||||
}
|
||||
|
||||
fn set_remote_control_header(
|
||||
headers: &mut tungstenite::http::HeaderMap,
|
||||
name: &'static str,
|
||||
value: &str,
|
||||
) -> io::Result<()> {
|
||||
let header_value = HeaderValue::from_str(value).map_err(|err| {
|
||||
io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("invalid remote control header `{name}`: {err}"),
|
||||
)
|
||||
})?;
|
||||
headers.insert(name, header_value);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_remote_control_websocket_request(
|
||||
websocket_url: &str,
|
||||
enrollment: &RemoteControlEnrollment,
|
||||
auth: &RemoteControlConnectionAuth,
|
||||
subscribe_cursor: Option<&str>,
|
||||
) -> io::Result<tungstenite::http::Request<()>> {
|
||||
let mut request = websocket_url.into_client_request().map_err(|err| {
|
||||
io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("invalid remote control websocket URL `{websocket_url}`: {err}"),
|
||||
)
|
||||
})?;
|
||||
let headers = request.headers_mut();
|
||||
set_remote_control_header(headers, "x-codex-server-id", &enrollment.server_id)?;
|
||||
set_remote_control_header(
|
||||
headers,
|
||||
"x-codex-name",
|
||||
&base64::engine::general_purpose::STANDARD.encode(&enrollment.server_name),
|
||||
)?;
|
||||
set_remote_control_header(
|
||||
headers,
|
||||
"x-codex-protocol-version",
|
||||
REMOTE_CONTROL_PROTOCOL_VERSION,
|
||||
)?;
|
||||
set_remote_control_header(
|
||||
headers,
|
||||
"authorization",
|
||||
&format!("Bearer {}", auth.bearer_token),
|
||||
)?;
|
||||
if let Some(account_id) = auth.account_id.as_deref() {
|
||||
set_remote_control_header(headers, REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id)?;
|
||||
}
|
||||
if let Some(subscribe_cursor) = subscribe_cursor {
|
||||
set_remote_control_header(
|
||||
headers,
|
||||
REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER,
|
||||
subscribe_cursor,
|
||||
)?;
|
||||
}
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
pub(super) async fn connect_remote_control_websocket(
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
state_db: Option<&StateRuntime>,
|
||||
auth_manager: &AuthManager,
|
||||
enrollment: &mut Option<RemoteControlEnrollment>,
|
||||
subscribe_cursor: Option<&str>,
|
||||
) -> io::Result<RemoteControlWebsocketConnection> {
|
||||
ensure_rustls_crypto_provider();
|
||||
|
||||
let auth = load_remote_control_auth(auth_manager).await?;
|
||||
if enrollment.is_none() {
|
||||
*enrollment = load_persisted_remote_control_enrollment(
|
||||
state_db,
|
||||
remote_control_target,
|
||||
auth.account_id.as_deref(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
if enrollment.is_none() {
|
||||
let new_enrollment = enroll_remote_control_server(remote_control_target, &auth).await?;
|
||||
if let Err(err) = update_persisted_remote_control_enrollment(
|
||||
state_db,
|
||||
remote_control_target,
|
||||
auth.account_id.as_deref(),
|
||||
Some(&new_enrollment),
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!("failed to persist remote control enrollment in sqlite state db: {err}");
|
||||
}
|
||||
*enrollment = Some(new_enrollment);
|
||||
}
|
||||
|
||||
let enrollment_ref = enrollment.as_ref().ok_or_else(|| {
|
||||
io::Error::other("missing remote control enrollment after enrollment step")
|
||||
})?;
|
||||
let request = build_remote_control_websocket_request(
|
||||
&remote_control_target.websocket_url,
|
||||
enrollment_ref,
|
||||
&auth,
|
||||
subscribe_cursor,
|
||||
)?;
|
||||
|
||||
let (websocket_stream, _response) = match connect_async(request).await {
|
||||
Ok((websocket_stream, response)) => (websocket_stream, response),
|
||||
Err(err) => {
|
||||
if matches!(
|
||||
&err,
|
||||
tungstenite::Error::Http(response) if response.status().as_u16() == 404
|
||||
) {
|
||||
if let Err(clear_err) = update_persisted_remote_control_enrollment(
|
||||
state_db,
|
||||
remote_control_target,
|
||||
auth.account_id.as_deref(),
|
||||
/*enrollment*/ None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
"failed to clear stale remote control enrollment in sqlite state db: {clear_err}"
|
||||
);
|
||||
}
|
||||
*enrollment = None;
|
||||
}
|
||||
return Err(io::Error::other(
|
||||
format_remote_control_websocket_connect_error(
|
||||
&remote_control_target.websocket_url,
|
||||
&err,
|
||||
),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(RemoteControlWebsocketConnection { websocket_stream })
|
||||
}
|
||||
346
codex-rs/app-server/src/transport/remote_control/mod.rs
Normal file
346
codex-rs/app-server/src/transport/remote_control/mod.rs
Normal file
@@ -0,0 +1,346 @@
|
||||
mod enroll;
|
||||
mod protocol;
|
||||
mod websocket;
|
||||
|
||||
use self::enroll::load_remote_control_auth;
|
||||
use self::protocol::ClientEnvelope;
|
||||
pub use self::protocol::ClientEvent;
|
||||
pub use self::protocol::ClientId;
|
||||
use self::protocol::PongStatus;
|
||||
use self::protocol::ServerEnvelope;
|
||||
use self::protocol::ServerEvent;
|
||||
use self::protocol::normalize_remote_control_url;
|
||||
use self::websocket::run_remote_control_websocket_loop;
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::next_connection_id;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_core::AuthManager;
|
||||
use codex_state::StateRuntime;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::MissedTickBehavior;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
const REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60);
|
||||
const REMOTE_CONTROL_IDLE_SWEEP_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
struct RemoteControlClientState {
|
||||
connection_id: ConnectionId,
|
||||
disconnect_token: CancellationToken,
|
||||
last_activity_at: Instant,
|
||||
last_inbound_seq_id: Option<u64>,
|
||||
}
|
||||
|
||||
pub(super) struct RemoteControlQueuedServerEnvelope {
|
||||
pub(super) envelope: ServerEnvelope,
|
||||
pub(super) write_complete_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
pub(crate) async fn start_remote_control(
|
||||
remote_control_url: String,
|
||||
state_db: Option<Arc<StateRuntime>>,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
) -> io::Result<JoinHandle<()>> {
|
||||
let remote_control_url = normalize_remote_control_url(&remote_control_url)?;
|
||||
Ok(tokio::spawn(async move {
|
||||
let local_shutdown_token = shutdown_token.child_token();
|
||||
let (client_event_tx, client_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (server_event_tx, server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (writer_exited_tx, writer_exited_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
join_set.spawn(run_remote_control_websocket_loop(
|
||||
remote_control_url,
|
||||
state_db,
|
||||
auth_manager,
|
||||
client_event_tx,
|
||||
server_event_rx,
|
||||
local_shutdown_token.clone(),
|
||||
));
|
||||
join_set.spawn(run_remote_control_manager(
|
||||
transport_event_tx,
|
||||
client_event_rx,
|
||||
server_event_tx,
|
||||
writer_exited_tx,
|
||||
writer_exited_rx,
|
||||
local_shutdown_token.clone(),
|
||||
));
|
||||
|
||||
tokio::select! {
|
||||
_ = local_shutdown_token.cancelled() => {}
|
||||
_ = join_set.join_next() => local_shutdown_token.cancel(),
|
||||
}
|
||||
|
||||
join_set.shutdown().await;
|
||||
}))
|
||||
}
|
||||
|
||||
async fn run_remote_control_manager(
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
mut client_event_rx: mpsc::Receiver<ClientEnvelope>,
|
||||
server_event_tx: mpsc::Sender<RemoteControlQueuedServerEnvelope>,
|
||||
writer_exited_tx: mpsc::Sender<ClientId>,
|
||||
mut writer_exited_rx: mpsc::Receiver<ClientId>,
|
||||
shutdown_token: CancellationToken,
|
||||
) {
|
||||
let mut clients = HashMap::<ClientId, RemoteControlClientState>::new();
|
||||
let mut idle_sweep = tokio::time::interval(REMOTE_CONTROL_IDLE_SWEEP_INTERVAL);
|
||||
idle_sweep.set_missed_tick_behavior(MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
_ = idle_sweep.tick() => {
|
||||
if !close_expired_remote_control_clients(&transport_event_tx, &mut clients).await {
|
||||
break;
|
||||
}
|
||||
}
|
||||
writer_exited = writer_exited_rx.recv() => {
|
||||
let Some(client_id) = writer_exited else {
|
||||
break;
|
||||
};
|
||||
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
|
||||
break;
|
||||
}
|
||||
}
|
||||
client_event = client_event_rx.recv() => {
|
||||
let Some(client_event) = client_event else {
|
||||
break;
|
||||
};
|
||||
match client_event.event {
|
||||
ClientEvent::ClientMessage { message } => {
|
||||
let client_id = client_event.client_id;
|
||||
let is_initialize = remote_control_message_starts_connection(&message);
|
||||
if let Some(seq_id) = client_event.seq_id
|
||||
&& let Some(client) = clients.get(&client_id)
|
||||
&& client.last_inbound_seq_id.is_some_and(|last_seq_id| last_seq_id >= seq_id)
|
||||
&& !is_initialize
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if is_initialize && clients.contains_key(&client_id)
|
||||
&& !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(connection_id) = clients.get_mut(&client_id).map(|client| {
|
||||
client.last_activity_at = Instant::now();
|
||||
if let Some(seq_id) = client_event.seq_id {
|
||||
client.last_inbound_seq_id = Some(seq_id);
|
||||
}
|
||||
client.connection_id
|
||||
}) {
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if !is_initialize {
|
||||
continue;
|
||||
}
|
||||
|
||||
let connection_id = next_connection_id();
|
||||
let (writer_tx, writer_rx) =
|
||||
mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let disconnect_token = CancellationToken::new();
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: Some(disconnect_token.clone()),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::spawn(run_remote_control_client_outbound(
|
||||
client_id.clone(),
|
||||
writer_rx,
|
||||
server_event_tx.clone(),
|
||||
writer_exited_tx.clone(),
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
clients.insert(
|
||||
client_id,
|
||||
RemoteControlClientState {
|
||||
connection_id,
|
||||
disconnect_token,
|
||||
last_activity_at: Instant::now(),
|
||||
last_inbound_seq_id: client_event.seq_id,
|
||||
},
|
||||
);
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
ClientEvent::Ack { .. } => continue,
|
||||
ClientEvent::Ping => {
|
||||
let client_id = client_event.client_id;
|
||||
let status = match clients.get_mut(&client_id) {
|
||||
Some(client) => {
|
||||
client.last_activity_at = Instant::now();
|
||||
PongStatus::Active
|
||||
}
|
||||
None => PongStatus::Unknown,
|
||||
};
|
||||
|
||||
if server_event_tx
|
||||
.send(RemoteControlQueuedServerEnvelope {
|
||||
envelope: ServerEnvelope {
|
||||
event: ServerEvent::Pong { status },
|
||||
client_id,
|
||||
seq_id: None,
|
||||
},
|
||||
write_complete_tx: None,
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
ClientEvent::ClientClosed => {
|
||||
let client_id = client_event.client_id;
|
||||
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(client_id) = clients.keys().next().cloned() {
|
||||
if !close_remote_control_client(&transport_event_tx, &mut clients, &client_id).await {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_control_message_starts_connection(message: &JSONRPCMessage) -> bool {
|
||||
matches!(
|
||||
message,
|
||||
JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { method, .. })
|
||||
if method == "initialize"
|
||||
)
|
||||
}
|
||||
|
||||
fn remote_control_client_is_alive(client: &RemoteControlClientState, now: Instant) -> bool {
|
||||
now.duration_since(client.last_activity_at) < REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT
|
||||
}
|
||||
|
||||
async fn close_expired_remote_control_clients(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
clients: &mut HashMap<ClientId, RemoteControlClientState>,
|
||||
) -> bool {
|
||||
let now = Instant::now();
|
||||
let expired_client_ids: Vec<ClientId> = clients
|
||||
.iter()
|
||||
.filter_map(|(client_id, client)| {
|
||||
(!remote_control_client_is_alive(client, now)).then_some(client_id.clone())
|
||||
})
|
||||
.collect();
|
||||
for client_id in expired_client_ids {
|
||||
if !close_remote_control_client(transport_event_tx, clients, &client_id).await {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
async fn close_remote_control_client(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
clients: &mut HashMap<ClientId, RemoteControlClientState>,
|
||||
client_id: &ClientId,
|
||||
) -> bool {
|
||||
let Some(client) = clients.remove(client_id) else {
|
||||
return true;
|
||||
};
|
||||
client.disconnect_token.cancel();
|
||||
transport_event_tx
|
||||
.send(TransportEvent::ConnectionClosed {
|
||||
connection_id: client.connection_id,
|
||||
})
|
||||
.await
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
async fn run_remote_control_client_outbound(
|
||||
client_id: ClientId,
|
||||
mut writer_rx: mpsc::Receiver<QueuedOutgoingMessage>,
|
||||
server_event_tx: mpsc::Sender<RemoteControlQueuedServerEnvelope>,
|
||||
writer_exited_tx: mpsc::Sender<ClientId>,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
let mut seq_id = 0_u64;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
queued_message = writer_rx.recv() => {
|
||||
let Some(queued_message) = queued_message else {
|
||||
break;
|
||||
};
|
||||
if server_event_tx
|
||||
.send(RemoteControlQueuedServerEnvelope {
|
||||
envelope: ServerEnvelope {
|
||||
event: ServerEvent::ServerMessage {
|
||||
message: Box::new(queued_message.message),
|
||||
},
|
||||
client_id: client_id.clone(),
|
||||
seq_id: Some(seq_id),
|
||||
},
|
||||
write_complete_tx: queued_message.write_complete_tx,
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
seq_id = seq_id.wrapping_add(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = writer_exited_tx.send(client_id).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn validate_remote_control_auth(auth_manager: &AuthManager) -> io::Result<()> {
|
||||
load_remote_control_auth(auth_manager).await.map(|_| ())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
188
codex-rs/app-server/src/transport/remote_control/protocol.rs
Normal file
188
codex-rs/app-server/src/transport/remote_control/protocol.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(super) struct RemoteControlTarget {
|
||||
pub(super) websocket_url: String,
|
||||
pub(super) enroll_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(super) struct EnrollRemoteServerRequest {
|
||||
pub(super) name: String,
|
||||
pub(super) os: &'static str,
|
||||
pub(super) arch: &'static str,
|
||||
pub(super) app_server_version: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(super) struct EnrollRemoteServerResponse {
|
||||
pub(super) server_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct ClientId(pub String);
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ClientEvent {
|
||||
ClientMessage {
|
||||
message: JSONRPCMessage,
|
||||
},
|
||||
Ack {
|
||||
#[serde(rename = "acked_seq_id", alias = "ackedSeqId")]
|
||||
acked_seq_id: u64,
|
||||
},
|
||||
Ping,
|
||||
ClientClosed,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) struct ClientEnvelope {
|
||||
#[serde(flatten)]
|
||||
pub(crate) event: ClientEvent,
|
||||
#[serde(rename = "client_id", alias = "clientId")]
|
||||
pub(crate) client_id: ClientId,
|
||||
#[serde(
|
||||
rename = "seq_id",
|
||||
alias = "seqId",
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub(crate) seq_id: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) cursor: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PongStatus {
|
||||
Active,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ServerEvent {
|
||||
ServerMessage {
|
||||
message: Box<OutgoingMessage>,
|
||||
},
|
||||
#[allow(dead_code)]
|
||||
Ack,
|
||||
Pong {
|
||||
status: PongStatus,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) struct ServerEnvelope {
|
||||
#[serde(flatten)]
|
||||
pub(crate) event: ServerEvent,
|
||||
#[serde(rename = "client_id", alias = "clientId")]
|
||||
pub(crate) client_id: ClientId,
|
||||
#[serde(
|
||||
rename = "seq_id",
|
||||
alias = "seqId",
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub(crate) seq_id: Option<u64>,
|
||||
}
|
||||
|
||||
pub(super) fn normalize_remote_control_url(
|
||||
remote_control_url: &str,
|
||||
) -> io::Result<RemoteControlTarget> {
|
||||
let map_url_parse_error = |err: url::ParseError| -> io::Error {
|
||||
io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("invalid remote control URL `{remote_control_url}`: {err}"),
|
||||
)
|
||||
};
|
||||
let map_scheme_error = |_: ()| -> io::Error {
|
||||
io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!(
|
||||
"invalid remote control URL `{remote_control_url}`; expected absolute URL with http:// or https:// scheme"
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
let mut remote_control_url = Url::parse(remote_control_url).map_err(map_url_parse_error)?;
|
||||
match remote_control_url.scheme() {
|
||||
"https" | "http" => {}
|
||||
_ => return Err(map_scheme_error(())),
|
||||
}
|
||||
if !remote_control_url.path().ends_with('/') {
|
||||
let normalized_path = format!("{}/", remote_control_url.path());
|
||||
remote_control_url.set_path(&normalized_path);
|
||||
}
|
||||
|
||||
let mut enroll_url = remote_control_url
|
||||
.join("wham/remote/control/server/enroll")
|
||||
.map_err(map_url_parse_error)?;
|
||||
let mut websocket_url = remote_control_url
|
||||
.join("wham/remote/control/server")
|
||||
.map_err(map_url_parse_error)?;
|
||||
match remote_control_url.scheme() {
|
||||
"https" => {
|
||||
enroll_url.set_scheme("https").map_err(map_scheme_error)?;
|
||||
websocket_url.set_scheme("wss").map_err(map_scheme_error)?;
|
||||
}
|
||||
"http" => {
|
||||
enroll_url.set_scheme("http").map_err(map_scheme_error)?;
|
||||
websocket_url.set_scheme("ws").map_err(map_scheme_error)?;
|
||||
}
|
||||
_ => return Err(map_scheme_error(())),
|
||||
}
|
||||
|
||||
Ok(RemoteControlTarget {
|
||||
websocket_url: websocket_url.to_string(),
|
||||
enroll_url: enroll_url.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn normalize_remote_control_url_rewrites_http_schemes() {
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("http://example.com/backend-api")
|
||||
.expect("valid http prefix"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "ws://example.com/backend-api/wham/remote/control/server"
|
||||
.to_string(),
|
||||
enroll_url: "http://example.com/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("https://example.com/backend-api/")
|
||||
.expect("valid https prefix"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "wss://example.com/backend-api/wham/remote/control/server"
|
||||
.to_string(),
|
||||
enroll_url: "https://example.com/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_remote_control_url_rejects_unsupported_schemes() {
|
||||
let err = normalize_remote_control_url("ftp://example.com/control")
|
||||
.expect_err("unsupported scheme should fail");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"invalid remote control URL `ftp://example.com/control`; expected absolute URL with http:// or https:// scheme"
|
||||
);
|
||||
}
|
||||
}
|
||||
1888
codex-rs/app-server/src/transport/remote_control/tests.rs
Normal file
1888
codex-rs/app-server/src/transport/remote_control/tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
297
codex-rs/app-server/src/transport/remote_control/websocket.rs
Normal file
297
codex-rs/app-server/src/transport/remote_control/websocket.rs
Normal file
@@ -0,0 +1,297 @@
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::RemoteControlQueuedServerEnvelope;
|
||||
use super::enroll::connect_remote_control_websocket;
|
||||
use super::protocol::ClientEnvelope;
|
||||
use super::protocol::ClientEvent;
|
||||
use super::protocol::ClientId;
|
||||
use super::protocol::RemoteControlTarget;
|
||||
use super::protocol::ServerEnvelope;
|
||||
use super::protocol::ServerEvent;
|
||||
use codex_core::AuthManager;
|
||||
use codex_state::StateRuntime;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::Duration;
|
||||
use tokio_tungstenite::tungstenite;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
const REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF: Duration = Duration::from_secs(1);
|
||||
const REMOTE_CONTROL_RECONNECT_MAX_BACKOFF: Duration = Duration::from_secs(30);
|
||||
|
||||
enum RemoteControlWriteCommand {
|
||||
ServerEnvelope(RemoteControlQueuedServerEnvelope),
|
||||
Pong(tungstenite::Bytes),
|
||||
}
|
||||
|
||||
struct BufferedServerEvent {
|
||||
event: ServerEvent,
|
||||
write_complete_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
#[allow(clippy::print_stderr)]
|
||||
pub(super) async fn run_remote_control_websocket_loop(
|
||||
remote_control_target: RemoteControlTarget,
|
||||
state_db: Option<Arc<StateRuntime>>,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
client_event_tx: mpsc::Sender<ClientEnvelope>,
|
||||
mut server_event_rx: mpsc::Receiver<RemoteControlQueuedServerEnvelope>,
|
||||
shutdown_token: CancellationToken,
|
||||
) {
|
||||
let mut reconnect_backoff = REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF;
|
||||
let mut reconnect_attempt = 0_u64;
|
||||
let mut wait_before_connect = false;
|
||||
let mut enrollment = None;
|
||||
let mut outbound_buffer = HashMap::<ClientId, BTreeMap<u64, BufferedServerEvent>>::new();
|
||||
let mut subscribe_cursor: Option<String> = None;
|
||||
|
||||
loop {
|
||||
if wait_before_connect {
|
||||
tokio::select! {
|
||||
_ = shutdown_token.cancelled() => break,
|
||||
_ = tokio::time::sleep(reconnect_backoff) => {}
|
||||
}
|
||||
reconnect_attempt = reconnect_attempt.saturating_add(1);
|
||||
warn!(
|
||||
"app-server remote control websocket reconnect attempt {reconnect_attempt} after {reconnect_backoff:?}"
|
||||
);
|
||||
reconnect_backoff = reconnect_backoff
|
||||
.saturating_mul(2)
|
||||
.min(REMOTE_CONTROL_RECONNECT_MAX_BACKOFF);
|
||||
} else {
|
||||
wait_before_connect = true;
|
||||
}
|
||||
|
||||
let websocket_connection = tokio::select! {
|
||||
_ = shutdown_token.cancelled() => break,
|
||||
connect_result = connect_remote_control_websocket(
|
||||
&remote_control_target,
|
||||
state_db.as_deref(),
|
||||
auth_manager.as_ref(),
|
||||
&mut enrollment,
|
||||
subscribe_cursor.as_deref(),
|
||||
) => {
|
||||
match connect_result {
|
||||
Ok(websocket_connection) => {
|
||||
reconnect_backoff = REMOTE_CONTROL_RECONNECT_INITIAL_BACKOFF;
|
||||
reconnect_attempt = 0;
|
||||
info!(
|
||||
"connected to app-server remote control websocket: {}",
|
||||
remote_control_target.websocket_url
|
||||
);
|
||||
websocket_connection
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("{err}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let (mut websocket_writer, mut websocket_reader) =
|
||||
websocket_connection.websocket_stream.split();
|
||||
let (write_command_tx, mut write_command_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (reader_event_tx, mut reader_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
|
||||
let mut buffered_events_to_resend = Vec::new();
|
||||
for (client_id, buffered_events) in outbound_buffer.iter_mut() {
|
||||
for (seq_id, buffered_event) in buffered_events.iter_mut() {
|
||||
buffered_events_to_resend.push(RemoteControlQueuedServerEnvelope {
|
||||
envelope: ServerEnvelope {
|
||||
event: buffered_event.event.clone(),
|
||||
client_id: client_id.clone(),
|
||||
seq_id: Some(*seq_id),
|
||||
},
|
||||
write_complete_tx: buffered_event.write_complete_tx.take(),
|
||||
});
|
||||
}
|
||||
}
|
||||
let mut write_task = tokio::spawn(async move {
|
||||
for server_envelope in buffered_events_to_resend {
|
||||
let payload = match serde_json::to_string(&server_envelope.envelope) {
|
||||
Ok(payload) => payload,
|
||||
Err(err) => {
|
||||
error!("failed to serialize remote-control server event: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
info!("YOLO sending to codex backend: {payload}");
|
||||
if websocket_writer
|
||||
.send(tungstenite::Message::Text(payload.into()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
if let Some(write_complete_tx) = server_envelope.write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(write_command) = write_command_rx.recv().await {
|
||||
match write_command {
|
||||
RemoteControlWriteCommand::ServerEnvelope(server_envelope) => {
|
||||
let payload = match serde_json::to_string(&server_envelope.envelope) {
|
||||
Ok(payload) => payload,
|
||||
Err(err) => {
|
||||
error!("failed to serialize remote-control server event: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
info!("YOLO sending to codex backend: {payload}");
|
||||
if websocket_writer
|
||||
.send(tungstenite::Message::Text(payload.into()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
if let Some(write_complete_tx) = server_envelope.write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
}
|
||||
RemoteControlWriteCommand::Pong(payload) => {
|
||||
if websocket_writer
|
||||
.send(tungstenite::Message::Pong(payload))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let write_command_tx_for_reader = write_command_tx.clone();
|
||||
let mut read_task = tokio::spawn(async move {
|
||||
while let Some(incoming_message) = websocket_reader.next().await {
|
||||
match incoming_message {
|
||||
Ok(tungstenite::Message::Text(text)) => {
|
||||
if let Ok(client_envelope) = serde_json::from_str::<ClientEnvelope>(&text) {
|
||||
if reader_event_tx.send(client_envelope).await.is_err() {
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
warn!("failed to deserialize remote-control client event");
|
||||
}
|
||||
}
|
||||
Ok(tungstenite::Message::Ping(payload)) => {
|
||||
if write_command_tx_for_reader
|
||||
.send(RemoteControlWriteCommand::Pong(payload))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
Ok(tungstenite::Message::Pong(_)) => {}
|
||||
Ok(tungstenite::Message::Binary(_)) => {
|
||||
warn!("dropping unsupported binary remote-control websocket message");
|
||||
}
|
||||
Ok(tungstenite::Message::Frame(_)) => {}
|
||||
Ok(tungstenite::Message::Close(_)) => {
|
||||
warn!("remote control websocket disconnected");
|
||||
return;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("remote control websocket receive failed: {err}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
warn!("remote control websocket disconnected");
|
||||
});
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_token.cancelled() => {
|
||||
write_task.abort();
|
||||
read_task.abort();
|
||||
return;
|
||||
}
|
||||
_ = &mut write_task => {
|
||||
read_task.abort();
|
||||
break;
|
||||
}
|
||||
_ = &mut read_task => {
|
||||
write_task.abort();
|
||||
break;
|
||||
}
|
||||
client_envelope = reader_event_rx.recv() => {
|
||||
let Some(client_envelope) = client_envelope else {
|
||||
write_task.abort();
|
||||
read_task.abort();
|
||||
break;
|
||||
};
|
||||
if let Some(cursor) = client_envelope.cursor.as_deref() {
|
||||
subscribe_cursor = Some(cursor.to_string());
|
||||
}
|
||||
if let ClientEvent::Ack { acked_seq_id } = &client_envelope.event
|
||||
&& let Some(buffered_events) = outbound_buffer.get_mut(&client_envelope.client_id)
|
||||
{
|
||||
let acknowledged_seq_ids: Vec<u64> = buffered_events
|
||||
.range(..=*acked_seq_id)
|
||||
.map(|(seq_id, _)| *seq_id)
|
||||
.collect();
|
||||
for acknowledged_seq_id in acknowledged_seq_ids {
|
||||
buffered_events.remove(&acknowledged_seq_id);
|
||||
}
|
||||
if buffered_events.is_empty() {
|
||||
outbound_buffer.remove(&client_envelope.client_id);
|
||||
}
|
||||
}
|
||||
if client_event_tx.send(client_envelope).await.is_err() {
|
||||
write_task.abort();
|
||||
read_task.abort();
|
||||
return;
|
||||
}
|
||||
}
|
||||
server_envelope = server_event_rx.recv() => {
|
||||
let Some(server_envelope) = server_envelope else {
|
||||
write_task.abort();
|
||||
read_task.abort();
|
||||
return;
|
||||
};
|
||||
if let ServerEvent::ServerMessage { .. } = &server_envelope.envelope.event
|
||||
&& let Some(seq_id) = server_envelope.envelope.seq_id
|
||||
{
|
||||
outbound_buffer
|
||||
.entry(server_envelope.envelope.client_id.clone())
|
||||
.or_default()
|
||||
.insert(seq_id, BufferedServerEvent {
|
||||
event: server_envelope.envelope.event.clone(),
|
||||
write_complete_tx: None,
|
||||
});
|
||||
}
|
||||
if let Err(err) = write_command_tx
|
||||
.send(RemoteControlWriteCommand::ServerEnvelope(server_envelope))
|
||||
.await
|
||||
{
|
||||
let RemoteControlWriteCommand::ServerEnvelope(server_envelope) = err.0 else {
|
||||
unreachable!();
|
||||
};
|
||||
if let ServerEvent::ServerMessage { .. } = &server_envelope.envelope.event
|
||||
&& let Some(seq_id) = server_envelope.envelope.seq_id
|
||||
&& let Some(buffered_events) = outbound_buffer.get_mut(&server_envelope.envelope.client_id)
|
||||
&& let Some(buffered_event) = buffered_events.get_mut(&seq_id)
|
||||
{
|
||||
buffered_event.write_complete_tx = server_envelope.write_complete_tx;
|
||||
}
|
||||
write_task.abort();
|
||||
read_task.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
88
codex-rs/app-server/src/transport/stdio.rs
Normal file
88
codex-rs/app-server/src/transport/stdio.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::forward_incoming_message;
|
||||
use super::serialize_outgoing_message;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
use tokio::io;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
|
||||
pub(crate) async fn start_stdio_connection(
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
stdio_handles: &mut Vec<JoinHandle<()>>,
|
||||
) -> IoResult<()> {
|
||||
let connection_id = ConnectionId(0);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: None,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "processor unavailable"))?;
|
||||
|
||||
let transport_event_tx_for_reader = transport_event_tx.clone();
|
||||
stdio_handles.push(tokio::spawn(async move {
|
||||
let stdin = io::stdin();
|
||||
let reader = BufReader::new(stdin);
|
||||
let mut lines = reader.lines();
|
||||
|
||||
loop {
|
||||
match lines.next_line().await {
|
||||
Ok(Some(line)) => {
|
||||
if !forward_incoming_message(
|
||||
&transport_event_tx_for_reader,
|
||||
&writer_tx_for_reader,
|
||||
connection_id,
|
||||
&line,
|
||||
)
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(err) => {
|
||||
error!("Failed reading stdin: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = transport_event_tx_for_reader
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.await;
|
||||
debug!("stdin reader finished (EOF)");
|
||||
}));
|
||||
|
||||
stdio_handles.push(tokio::spawn(async move {
|
||||
let mut stdout = io::stdout();
|
||||
while let Some(queued_message) = writer_rx.recv().await {
|
||||
let Some(mut json) = serialize_outgoing_message(queued_message.message) else {
|
||||
continue;
|
||||
};
|
||||
json.push('\n');
|
||||
if let Err(err) = stdout.write_all(json.as_bytes()).await {
|
||||
error!("Failed to write to stdout: {err}");
|
||||
break;
|
||||
}
|
||||
if let Some(write_complete_tx) = queued_message.write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
}
|
||||
info!("stdout writer exited (channel closed)");
|
||||
}));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
308
codex-rs/app-server/src/transport/websocket.rs
Normal file
308
codex-rs/app-server/src/transport/websocket.rs
Normal file
@@ -0,0 +1,308 @@
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::auth::WebsocketAuthPolicy;
|
||||
use super::auth::authorize_upgrade;
|
||||
use super::auth::should_warn_about_unauthenticated_non_loopback_listener;
|
||||
use super::forward_incoming_message;
|
||||
use super::serialize_outgoing_message;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use axum::Router;
|
||||
use axum::body::Body;
|
||||
use axum::extract::ConnectInfo;
|
||||
use axum::extract::State;
|
||||
use axum::extract::ws::Message as WebSocketMessage;
|
||||
use axum::extract::ws::WebSocket;
|
||||
use axum::extract::ws::WebSocketUpgrade;
|
||||
use axum::http::HeaderMap;
|
||||
use axum::http::Request;
|
||||
use axum::http::StatusCode;
|
||||
use axum::http::header::ORIGIN;
|
||||
use axum::middleware;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::response::Response;
|
||||
use axum::routing::any;
|
||||
use axum::routing::get;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Stream;
|
||||
use owo_colors::Style;
|
||||
use std::io::Result as IoResult;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
fn colorize(text: &str, style: Style) -> String {
|
||||
text.if_supports_color(Stream::Stderr, |value| value.style(style))
|
||||
.to_string()
|
||||
}
|
||||
|
||||
#[allow(clippy::print_stderr)]
|
||||
fn print_websocket_startup_banner(addr: SocketAddr) {
|
||||
let title = colorize("codex app-server (WebSockets)", Style::new().bold().cyan());
|
||||
let listening_label = colorize("listening on:", Style::new().dimmed());
|
||||
let listen_url = colorize(&format!("ws://{addr}"), Style::new().green());
|
||||
let ready_label = colorize("readyz:", Style::new().dimmed());
|
||||
let ready_url = colorize(&format!("http://{addr}/readyz"), Style::new().green());
|
||||
let health_label = colorize("healthz:", Style::new().dimmed());
|
||||
let health_url = colorize(&format!("http://{addr}/healthz"), Style::new().green());
|
||||
let note_label = colorize("note:", Style::new().dimmed());
|
||||
eprintln!("{title}");
|
||||
eprintln!(" {listening_label} {listen_url}");
|
||||
eprintln!(" {ready_label} {ready_url}");
|
||||
eprintln!(" {health_label} {health_url}");
|
||||
if addr.ip().is_loopback() {
|
||||
eprintln!(
|
||||
" {note_label} binds localhost only (use SSH port-forwarding for remote access)"
|
||||
);
|
||||
} else {
|
||||
eprintln!(
|
||||
" {note_label} this is a raw WS server; consider running behind TLS/auth for real remote use"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct WebSocketListenerState {
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
connection_counter: Arc<AtomicU64>,
|
||||
auth_policy: Arc<WebsocketAuthPolicy>,
|
||||
}
|
||||
|
||||
async fn health_check_handler() -> StatusCode {
|
||||
StatusCode::OK
|
||||
}
|
||||
|
||||
async fn reject_requests_with_origin_header(
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if request.headers().contains_key(ORIGIN) {
|
||||
warn!(
|
||||
method = %request.method(),
|
||||
uri = %request.uri(),
|
||||
"rejecting websocket listener request with Origin header"
|
||||
);
|
||||
Err(StatusCode::FORBIDDEN)
|
||||
} else {
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
}
|
||||
|
||||
async fn websocket_upgrade_handler(
|
||||
websocket: WebSocketUpgrade,
|
||||
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
||||
State(state): State<WebSocketListenerState>,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(err) = authorize_upgrade(&headers, state.auth_policy.as_ref()) {
|
||||
warn!(
|
||||
%peer_addr,
|
||||
message = err.message(),
|
||||
"rejecting websocket client during upgrade"
|
||||
);
|
||||
return (err.status_code(), err.message()).into_response();
|
||||
}
|
||||
let connection_id = ConnectionId(state.connection_counter.fetch_add(1, Ordering::Relaxed));
|
||||
info!(%peer_addr, "websocket client connected");
|
||||
websocket
|
||||
.on_upgrade(move |stream| async move {
|
||||
run_websocket_connection(connection_id, stream, state.transport_event_tx).await;
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
|
||||
pub(crate) async fn start_websocket_acceptor(
|
||||
bind_address: SocketAddr,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
auth_policy: WebsocketAuthPolicy,
|
||||
) -> IoResult<JoinHandle<()>> {
|
||||
if should_warn_about_unauthenticated_non_loopback_listener(bind_address, &auth_policy) {
|
||||
warn!(
|
||||
%bind_address,
|
||||
"starting non-loopback websocket listener without auth; websocket auth is opt-in for now and will become the default in a future release"
|
||||
);
|
||||
}
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let local_addr = listener.local_addr()?;
|
||||
print_websocket_startup_banner(local_addr);
|
||||
info!("app-server websocket listening on ws://{local_addr}");
|
||||
|
||||
let router = Router::new()
|
||||
.route("/readyz", get(health_check_handler))
|
||||
.route("/healthz", get(health_check_handler))
|
||||
.fallback(any(websocket_upgrade_handler))
|
||||
.layer(middleware::from_fn(reject_requests_with_origin_header))
|
||||
.with_state(WebSocketListenerState {
|
||||
transport_event_tx,
|
||||
connection_counter: Arc::new(AtomicU64::new(1)),
|
||||
auth_policy: Arc::new(auth_policy),
|
||||
});
|
||||
let server = axum::serve(
|
||||
listener,
|
||||
router.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.with_graceful_shutdown(async move {
|
||||
shutdown_token.cancelled().await;
|
||||
});
|
||||
Ok(tokio::spawn(async move {
|
||||
if let Err(err) = server.await {
|
||||
error!("websocket acceptor failed: {err}");
|
||||
}
|
||||
info!("websocket acceptor shutting down");
|
||||
}))
|
||||
}
|
||||
|
||||
async fn run_websocket_connection(
|
||||
connection_id: ConnectionId,
|
||||
websocket_stream: WebSocket,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
) {
|
||||
let (writer_tx, writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
let disconnect_token = CancellationToken::new();
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: Some(disconnect_token.clone()),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let (websocket_writer, websocket_reader) = websocket_stream.split();
|
||||
let (writer_control_tx, writer_control_rx) =
|
||||
mpsc::channel::<WebSocketMessage>(CHANNEL_CAPACITY);
|
||||
let mut outbound_task = tokio::spawn(run_websocket_outbound_loop(
|
||||
websocket_writer,
|
||||
writer_rx,
|
||||
writer_control_rx,
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
let mut inbound_task = tokio::spawn(run_websocket_inbound_loop(
|
||||
websocket_reader,
|
||||
transport_event_tx.clone(),
|
||||
writer_tx_for_reader,
|
||||
writer_control_tx,
|
||||
connection_id,
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut outbound_task => {
|
||||
disconnect_token.cancel();
|
||||
inbound_task.abort();
|
||||
}
|
||||
_ = &mut inbound_task => {
|
||||
disconnect_token.cancel();
|
||||
outbound_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
let _ = transport_event_tx
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn run_websocket_outbound_loop(
|
||||
mut websocket_writer: futures::stream::SplitSink<WebSocket, WebSocketMessage>,
|
||||
mut writer_rx: mpsc::Receiver<QueuedOutgoingMessage>,
|
||||
mut writer_control_rx: mpsc::Receiver<WebSocketMessage>,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
message = writer_control_rx.recv() => {
|
||||
let Some(message) = message else {
|
||||
break;
|
||||
};
|
||||
if websocket_writer.send(message).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
queued_message = writer_rx.recv() => {
|
||||
let Some(queued_message) = queued_message else {
|
||||
break;
|
||||
};
|
||||
let Some(json) = serialize_outgoing_message(queued_message.message) else {
|
||||
continue;
|
||||
};
|
||||
if websocket_writer.send(WebSocketMessage::Text(json.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
if let Some(write_complete_tx) = queued_message.write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_websocket_inbound_loop(
|
||||
mut websocket_reader: futures::stream::SplitStream<WebSocket>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
writer_tx_for_reader: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
writer_control_tx: mpsc::Sender<WebSocketMessage>,
|
||||
connection_id: ConnectionId,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
incoming_message = websocket_reader.next() => {
|
||||
match incoming_message {
|
||||
Some(Ok(WebSocketMessage::Text(text))) => {
|
||||
if !forward_incoming_message(
|
||||
&transport_event_tx,
|
||||
&writer_tx_for_reader,
|
||||
connection_id,
|
||||
text.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(WebSocketMessage::Ping(payload))) => {
|
||||
match writer_control_tx.try_send(WebSocketMessage::Pong(payload)) {
|
||||
Ok(()) => {}
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => break,
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
|
||||
warn!("websocket control queue full while replying to ping; closing connection");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(WebSocketMessage::Pong(_))) => {}
|
||||
Some(Ok(WebSocketMessage::Close(_))) | None => break,
|
||||
Some(Ok(WebSocketMessage::Binary(_))) => {
|
||||
warn!("dropping unsupported binary websocket message");
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
warn!("websocket receive error: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
use anyhow::Result;
|
||||
use app_test_support::ChatGptAuthFixture;
|
||||
use app_test_support::McpProcess;
|
||||
use app_test_support::test_path_buf_with_windows;
|
||||
use app_test_support::test_tmp_path_buf;
|
||||
use app_test_support::to_response;
|
||||
use app_test_support::write_chatgpt_auth;
|
||||
use codex_app_server_protocol::AppConfig;
|
||||
use codex_app_server_protocol::AppToolApproval;
|
||||
use codex_app_server_protocol::AppsConfig;
|
||||
@@ -21,6 +23,7 @@ use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::SandboxMode;
|
||||
use codex_app_server_protocol::ToolsV2;
|
||||
use codex_app_server_protocol::WriteStatus;
|
||||
use codex_core::auth::AuthCredentialsStoreMode;
|
||||
use codex_core::config::set_project_trust_level;
|
||||
use codex_protocol::config_types::TrustLevel;
|
||||
use codex_protocol::config_types::WebSearchContextSize;
|
||||
@@ -88,6 +91,69 @@ sandbox_mode = "workspace-write"
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn config_read_includes_chatgpt_base_url() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
write_config(
|
||||
&codex_home,
|
||||
r#"
|
||||
chatgpt_base_url = "https://example.com/backend-api/"
|
||||
|
||||
[features]
|
||||
remote_control = true
|
||||
"#,
|
||||
)?;
|
||||
write_chatgpt_auth(
|
||||
codex_home.path(),
|
||||
ChatGptAuthFixture::new("chatgpt-token"),
|
||||
AuthCredentialsStoreMode::File,
|
||||
)?;
|
||||
let codex_home_path = codex_home.path().canonicalize()?;
|
||||
let user_file = AbsolutePathBuf::try_from(codex_home_path.join("config.toml"))?;
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let request_id = mcp
|
||||
.send_config_read_request(ConfigReadParams {
|
||||
include_layers: true,
|
||||
cwd: None,
|
||||
})
|
||||
.await?;
|
||||
let resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||
)
|
||||
.await??;
|
||||
let ConfigReadResponse {
|
||||
config,
|
||||
origins,
|
||||
layers,
|
||||
} = to_response(resp)?;
|
||||
|
||||
assert_eq!(
|
||||
config.additional.get("chatgpt_base_url"),
|
||||
Some(&json!("https://example.com/backend-api/"))
|
||||
);
|
||||
assert_eq!(
|
||||
origins.get("chatgpt_base_url").expect("origin").name,
|
||||
ConfigLayerSource::User {
|
||||
file: user_file.clone(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
origins.get("features.remote_control").expect("origin").name,
|
||||
ConfigLayerSource::User {
|
||||
file: user_file.clone(),
|
||||
}
|
||||
);
|
||||
|
||||
let layers = layers.expect("layers present");
|
||||
assert_layers_user_then_optional_system(&layers, user_file)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn config_read_includes_tools() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
@@ -328,6 +328,7 @@ impl CloudRequirementsService {
|
||||
return Ok(None);
|
||||
};
|
||||
if !auth.is_chatgpt_auth()
|
||||
|| auth.is_external_chatgpt_tokens()
|
||||
|| !matches!(
|
||||
auth.account_plan_type(),
|
||||
Some(PlanType::Business | PlanType::Enterprise)
|
||||
|
||||
@@ -437,6 +437,9 @@
|
||||
"realtime_conversation": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_control": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_models": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -2054,6 +2057,9 @@
|
||||
"realtime_conversation": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_control": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_models": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
||||
@@ -176,6 +176,8 @@ pub enum Feature {
|
||||
VoiceTranscription,
|
||||
/// Enable experimental realtime voice conversation mode in the TUI.
|
||||
RealtimeConversation,
|
||||
/// Connect app-server to the ChatGPT remote control service.
|
||||
RemoteControl,
|
||||
/// Route interactive startup to the app-server-backed TUI implementation.
|
||||
TuiAppServer,
|
||||
/// Prevent idle system sleep while a turn is actively running.
|
||||
@@ -819,6 +821,12 @@ pub const FEATURES: &[FeatureSpec] = &[
|
||||
stage: Stage::UnderDevelopment,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::RemoteControl,
|
||||
key: "remote_control",
|
||||
stage: Stage::UnderDevelopment,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::TuiAppServer,
|
||||
key: "tui_app_server",
|
||||
|
||||
@@ -159,6 +159,12 @@ fn image_detail_original_feature_is_under_development() {
|
||||
assert_eq!(Feature::ImageDetailOriginal.default_enabled(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_control_is_under_development() {
|
||||
assert_eq!(Feature::RemoteControl.stage(), Stage::UnderDevelopment);
|
||||
assert_eq!(Feature::RemoteControl.default_enabled(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collab_is_legacy_alias_for_multi_agent() {
|
||||
assert_eq!(feature_for_key("multi_agent"), Some(Feature::Collab));
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
CREATE TABLE remote_control_enrollments (
|
||||
websocket_url TEXT NOT NULL,
|
||||
account_id TEXT NOT NULL,
|
||||
server_id TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
PRIMARY KEY (websocket_url, account_id)
|
||||
);
|
||||
@@ -53,6 +53,7 @@ mod agent_jobs;
|
||||
mod backfill;
|
||||
mod logs;
|
||||
mod memories;
|
||||
mod remote_control;
|
||||
#[cfg(test)]
|
||||
mod test_support;
|
||||
mod threads;
|
||||
|
||||
197
codex-rs/state/src/runtime/remote_control.rs
Normal file
197
codex-rs/state/src/runtime/remote_control.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
use super::*;
|
||||
|
||||
const REMOTE_CONTROL_ACCOUNT_ID_NONE: &str = "";
|
||||
|
||||
fn remote_control_account_id_key(account_id: Option<&str>) -> &str {
|
||||
account_id.unwrap_or(REMOTE_CONTROL_ACCOUNT_ID_NONE)
|
||||
}
|
||||
|
||||
impl StateRuntime {
|
||||
pub async fn get_remote_control_enrollment(
|
||||
&self,
|
||||
websocket_url: &str,
|
||||
account_id: Option<&str>,
|
||||
) -> anyhow::Result<Option<(String, String)>> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT server_id, server_name
|
||||
FROM remote_control_enrollments
|
||||
WHERE websocket_url = ? AND account_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(websocket_url)
|
||||
.bind(remote_control_account_id_key(account_id))
|
||||
.fetch_optional(self.pool.as_ref())
|
||||
.await?;
|
||||
|
||||
row.map(|row| Ok((row.try_get("server_id")?, row.try_get("server_name")?)))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
pub async fn upsert_remote_control_enrollment(
|
||||
&self,
|
||||
websocket_url: &str,
|
||||
account_id: Option<&str>,
|
||||
server_id: &str,
|
||||
server_name: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO remote_control_enrollments (
|
||||
websocket_url,
|
||||
account_id,
|
||||
server_id,
|
||||
server_name,
|
||||
updated_at
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(websocket_url, account_id) DO UPDATE SET
|
||||
server_id = excluded.server_id,
|
||||
server_name = excluded.server_name,
|
||||
updated_at = excluded.updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(websocket_url)
|
||||
.bind(remote_control_account_id_key(account_id))
|
||||
.bind(server_id)
|
||||
.bind(server_name)
|
||||
.bind(Utc::now().timestamp())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_remote_control_enrollment(
|
||||
&self,
|
||||
websocket_url: &str,
|
||||
account_id: Option<&str>,
|
||||
) -> anyhow::Result<u64> {
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
DELETE FROM remote_control_enrollments
|
||||
WHERE websocket_url = ? AND account_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(websocket_url)
|
||||
.bind(remote_control_account_id_key(account_id))
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::StateRuntime;
|
||||
use super::test_support::unique_temp_dir;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_enrollment_round_trips_by_target_and_account() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
"srv_e_first",
|
||||
"first-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert first enrollment");
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-b"),
|
||||
"srv_e_second",
|
||||
"second-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert second enrollment");
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
)
|
||||
.await
|
||||
.expect("load first enrollment"),
|
||||
Some(("srv_e_first".to_string(), "first-server".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("load missing enrollment"),
|
||||
None
|
||||
);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delete_remote_control_enrollment_removes_only_matching_entry() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
None,
|
||||
"srv_e_first",
|
||||
"first-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert first enrollment");
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
"srv_e_second",
|
||||
"second-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert second enrollment");
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.delete_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("delete first enrollment"),
|
||||
1
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("load deleted enrollment"),
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
)
|
||||
.await
|
||||
.expect("load retained enrollment"),
|
||||
Some(("srv_e_second".to_string(), "second-server".to_string()))
|
||||
);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user