mirror of
https://github.com/openai/codex.git
synced 2026-03-27 09:03:51 +00:00
Compare commits
1 Commits
pakrym/fix
...
codex/webs
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2350059789 |
3
codex-rs/Cargo.lock
generated
3
codex-rs/Cargo.lock
generated
@@ -1423,9 +1423,11 @@ dependencies = [
|
||||
"codex-utils-cli",
|
||||
"codex-utils-json-to-toml",
|
||||
"codex-utils-pty",
|
||||
"codex-utils-rustls-provider",
|
||||
"constant_time_eq",
|
||||
"core_test_support",
|
||||
"futures",
|
||||
"gethostname",
|
||||
"hmac",
|
||||
"jsonwebtoken",
|
||||
"opentelemetry",
|
||||
@@ -1448,6 +1450,7 @@ dependencies = [
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
"uuid",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
@@ -52,10 +52,12 @@ codex-state = { workspace = true }
|
||||
codex-tools = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
codex-utils-json-to-toml = { workspace = true }
|
||||
codex-utils-rustls-provider = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
constant_time_eq = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
gethostname = { workspace = true }
|
||||
hmac = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
owo-colors = { workspace = true, features = ["supports-colors"] }
|
||||
@@ -76,6 +78,7 @@ tokio-util = { workspace = true }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] }
|
||||
url = { workspace = true }
|
||||
uuid = { workspace = true, features = ["serde", "v7"] }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -25,6 +25,7 @@ Supported transports:
|
||||
|
||||
- stdio (`--listen stdio://`, default): newline-delimited JSON (JSONL)
|
||||
- websocket (`--listen ws://IP:PORT`): one JSON-RPC message per websocket text frame (**experimental / unsupported**)
|
||||
- off (`--listen off`): do not expose a local transport
|
||||
|
||||
When running with `--listen ws://IP:PORT`, the same listener also serves basic HTTP health probes:
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ fn transport_name(transport: AppServerTransport) -> &'static str {
|
||||
match transport {
|
||||
AppServerTransport::Stdio => "stdio",
|
||||
AppServerTransport::WebSocket { .. } => "websocket",
|
||||
AppServerTransport::Off => "off",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -74,6 +74,7 @@ use codex_app_server_protocol::Result;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_app_server_protocol::ServerRequest;
|
||||
use codex_arg0::Arg0DispatchPaths;
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config_loader::CloudRequirementsLoader;
|
||||
use codex_core::config_loader::LoaderOverrides;
|
||||
@@ -377,6 +378,14 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
}
|
||||
});
|
||||
|
||||
let auth_manager = AuthManager::shared(
|
||||
args.config.codex_home.clone(),
|
||||
args.enable_codex_api_key_env,
|
||||
args.config.cli_auth_credentials_store_mode,
|
||||
);
|
||||
auth_manager
|
||||
.set_forced_chatgpt_workspace_id(args.config.forced_chatgpt_workspace_id.clone());
|
||||
|
||||
let processor_outgoing = Arc::clone(&outgoing_message_sender);
|
||||
let (processor_tx, mut processor_rx) = mpsc::channel::<ProcessorCommand>(channel_capacity);
|
||||
let mut processor_handle = tokio::spawn(async move {
|
||||
@@ -392,7 +401,7 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
log_db: None,
|
||||
config_warnings: args.config_warnings,
|
||||
session_source: args.session_source,
|
||||
enable_codex_api_key_env: args.enable_codex_api_key_env,
|
||||
auth_manager,
|
||||
});
|
||||
let mut thread_created_rx = processor.thread_created_receiver();
|
||||
let mut session = ConnectionSessionState::default();
|
||||
|
||||
@@ -8,6 +8,7 @@ use codex_core::config::ConfigBuilder;
|
||||
use codex_core::config_loader::CloudRequirementsLoader;
|
||||
use codex_core::config_loader::ConfigLayerStackOrdering;
|
||||
use codex_core::config_loader::LoaderOverrides;
|
||||
use codex_features::Feature;
|
||||
use codex_utils_cli::CliConfigOverrides;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
@@ -29,6 +30,7 @@ use crate::transport::OutboundConnectionState;
|
||||
use crate::transport::TransportEvent;
|
||||
use crate::transport::auth::policy_from_settings;
|
||||
use crate::transport::route_outgoing_envelope;
|
||||
use crate::transport::start_remote_control;
|
||||
use crate::transport::start_stdio_connection;
|
||||
use crate::transport::start_websocket_acceptor;
|
||||
use codex_app_server_protocol::ConfigLayerSource;
|
||||
@@ -499,13 +501,13 @@ pub async fn run_main_with_transport(
|
||||
|
||||
let feedback_layer = feedback.logger_layer();
|
||||
let feedback_metadata_layer = feedback.metadata_layer();
|
||||
let log_db = codex_state::StateRuntime::init(
|
||||
let state_db = codex_state::StateRuntime::init(
|
||||
config.sqlite_home.clone(),
|
||||
config.model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.ok()
|
||||
.map(log_db::start);
|
||||
.ok();
|
||||
let log_db = state_db.clone().map(log_db::start);
|
||||
let log_db_layer = log_db
|
||||
.clone()
|
||||
.map(|layer| layer.with_filter(Targets::new().with_default(Level::TRACE)));
|
||||
@@ -548,6 +550,32 @@ pub async fn run_main_with_transport(
|
||||
.await?;
|
||||
transport_accept_handles.push(accept_handle);
|
||||
}
|
||||
AppServerTransport::Off => {}
|
||||
}
|
||||
|
||||
let auth_manager = AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
/*enable_codex_api_key_env*/ false,
|
||||
config.cli_auth_credentials_store_mode,
|
||||
);
|
||||
auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone());
|
||||
|
||||
if config.features.enabled(Feature::RemoteControl) {
|
||||
let accept_handle = start_remote_control(
|
||||
config.chatgpt_base_url.clone(),
|
||||
state_db.clone(),
|
||||
auth_manager.clone(),
|
||||
transport_event_tx.clone(),
|
||||
transport_shutdown_token.clone(),
|
||||
)
|
||||
.await?;
|
||||
transport_accept_handles.push(accept_handle);
|
||||
}
|
||||
if transport_accept_handles.is_empty() {
|
||||
return Err(std::io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
"no transport configured; use --listen or enable remote control",
|
||||
));
|
||||
}
|
||||
|
||||
let outbound_handle = tokio::spawn(async move {
|
||||
@@ -622,7 +650,7 @@ pub async fn run_main_with_transport(
|
||||
log_db,
|
||||
config_warnings,
|
||||
session_source,
|
||||
enable_codex_api_key_env: false,
|
||||
auth_manager,
|
||||
});
|
||||
let mut thread_created_rx = processor.thread_created_receiver();
|
||||
let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count();
|
||||
|
||||
@@ -16,7 +16,7 @@ const MANAGED_CONFIG_PATH_ENV_VAR: &str = "CODEX_APP_SERVER_MANAGED_CONFIG_PATH"
|
||||
#[derive(Debug, Parser)]
|
||||
struct AppServerArgs {
|
||||
/// Transport endpoint URL. Supported values: `stdio://` (default),
|
||||
/// `ws://IP:PORT`.
|
||||
/// `ws://IP:PORT`, `off`.
|
||||
#[arg(
|
||||
long = "listen",
|
||||
value_name = "URL",
|
||||
|
||||
@@ -187,7 +187,7 @@ pub(crate) struct MessageProcessorArgs {
|
||||
pub(crate) log_db: Option<LogDbLayer>,
|
||||
pub(crate) config_warnings: Vec<ConfigWarningNotification>,
|
||||
pub(crate) session_source: SessionSource,
|
||||
pub(crate) enable_codex_api_key_env: bool,
|
||||
pub(crate) auth_manager: Arc<AuthManager>,
|
||||
}
|
||||
|
||||
impl MessageProcessor {
|
||||
@@ -206,13 +206,8 @@ impl MessageProcessor {
|
||||
log_db,
|
||||
config_warnings,
|
||||
session_source,
|
||||
enable_codex_api_key_env,
|
||||
auth_manager,
|
||||
} = args;
|
||||
let auth_manager = AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
enable_codex_api_key_env,
|
||||
config.cli_auth_credentials_store_mode,
|
||||
);
|
||||
let thread_manager = Arc::new(ThreadManager::new(
|
||||
config.as_ref(),
|
||||
auth_manager.clone(),
|
||||
@@ -224,7 +219,6 @@ impl MessageProcessor {
|
||||
},
|
||||
environment_manager,
|
||||
));
|
||||
auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone());
|
||||
auth_manager.set_external_auth_refresher(Arc::new(ExternalAuthRefreshBridge {
|
||||
outgoing: outgoing.clone(),
|
||||
}));
|
||||
|
||||
@@ -20,6 +20,7 @@ use codex_app_server_protocol::TurnStartParams;
|
||||
use codex_app_server_protocol::TurnStartResponse;
|
||||
use codex_app_server_protocol::UserInput;
|
||||
use codex_arg0::Arg0DispatchPaths;
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::config::ConfigBuilder;
|
||||
use codex_core::config_loader::CloudRequirementsLoader;
|
||||
@@ -231,6 +232,13 @@ fn build_test_processor(
|
||||
MessageProcessor,
|
||||
mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
|
||||
) {
|
||||
let auth_manager = AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
/*enable_codex_api_key_env*/ false,
|
||||
config.cli_auth_credentials_store_mode,
|
||||
);
|
||||
auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone());
|
||||
|
||||
let (outgoing_tx, outgoing_rx) = mpsc::channel(16);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(outgoing_tx));
|
||||
let processor = MessageProcessor::new(MessageProcessorArgs {
|
||||
@@ -245,7 +253,7 @@ fn build_test_processor(
|
||||
log_db: None,
|
||||
config_warnings: Vec::new(),
|
||||
session_source: SessionSource::VSCode,
|
||||
enable_codex_api_key_env: false,
|
||||
auth_manager,
|
||||
});
|
||||
(processor, outgoing_rx)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -28,9 +29,11 @@ use tracing::warn;
|
||||
/// plenty for an interactive CLI.
|
||||
pub(crate) const CHANNEL_CAPACITY: usize = 128;
|
||||
|
||||
mod remote_control;
|
||||
mod stdio;
|
||||
mod websocket;
|
||||
|
||||
pub(crate) use remote_control::start_remote_control;
|
||||
pub(crate) use stdio::start_stdio_connection;
|
||||
pub(crate) use websocket::start_websocket_acceptor;
|
||||
|
||||
@@ -38,6 +41,7 @@ pub(crate) use websocket::start_websocket_acceptor;
|
||||
pub enum AppServerTransport {
|
||||
Stdio,
|
||||
WebSocket { bind_address: SocketAddr },
|
||||
Off,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
@@ -51,7 +55,7 @@ impl std::fmt::Display for AppServerTransportParseError {
|
||||
match self {
|
||||
AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!(
|
||||
f,
|
||||
"unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`"
|
||||
"unsupported --listen URL `{listen_url}`; expected `stdio://`, `ws://IP:PORT`, or `off`"
|
||||
),
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!(
|
||||
f,
|
||||
@@ -71,6 +75,10 @@ impl AppServerTransport {
|
||||
return Ok(Self::Stdio);
|
||||
}
|
||||
|
||||
if listen_url == "off" {
|
||||
return Ok(Self::Off);
|
||||
}
|
||||
|
||||
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
|
||||
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
|
||||
@@ -166,6 +174,12 @@ impl OutboundConnectionState {
|
||||
}
|
||||
}
|
||||
|
||||
static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
fn next_connection_id() -> ConnectionId {
|
||||
ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed))
|
||||
}
|
||||
|
||||
async fn forward_incoming_message(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
writer: &mpsc::Sender<QueuedOutgoingMessage>,
|
||||
@@ -378,8 +392,11 @@ pub(crate) async fn route_outgoing_envelope(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::error_code::OVERLOADED_ERROR_CODE;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use pretty_assertions::assert_eq;
|
||||
@@ -393,41 +410,10 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_parses_stdio_listen_url() {
|
||||
let transport = AppServerTransport::from_listen_url(AppServerTransport::DEFAULT_LISTEN_URL)
|
||||
.expect("stdio listen URL should parse");
|
||||
assert_eq!(transport, AppServerTransport::Stdio);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_parses_websocket_listen_url() {
|
||||
let transport = AppServerTransport::from_listen_url("ws://127.0.0.1:1234")
|
||||
.expect("websocket listen URL should parse");
|
||||
fn listen_off_parses_as_off_transport() {
|
||||
assert_eq!(
|
||||
transport,
|
||||
AppServerTransport::WebSocket {
|
||||
bind_address: "127.0.0.1:1234".parse().expect("valid socket address"),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_rejects_invalid_websocket_listen_url() {
|
||||
let err = AppServerTransport::from_listen_url("ws://localhost:1234")
|
||||
.expect_err("hostname bind address should be rejected");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"invalid websocket --listen URL `ws://localhost:1234`; expected `ws://IP:PORT`"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_rejects_unsupported_listen_url() {
|
||||
let err = AppServerTransport::from_listen_url("http://127.0.0.1:1234")
|
||||
.expect_err("unsupported scheme should fail");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`"
|
||||
AppServerTransport::from_listen_url("off"),
|
||||
Ok(AppServerTransport::Off)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -437,11 +423,10 @@ mod tests {
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let first_message =
|
||||
JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
@@ -450,8 +435,8 @@ mod tests {
|
||||
.await
|
||||
.expect("queue should accept first message");
|
||||
|
||||
let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest {
|
||||
id: codex_app_server_protocol::RequestId::Integer(7),
|
||||
let request = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(7),
|
||||
method: "config/read".to_string(),
|
||||
params: Some(json!({ "includeLayers": false })),
|
||||
trace: None,
|
||||
@@ -499,11 +484,10 @@ mod tests {
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
|
||||
let (writer_tx, _writer_rx) = mpsc::channel(1);
|
||||
|
||||
let first_message =
|
||||
JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
@@ -512,8 +496,8 @@ mod tests {
|
||||
.await
|
||||
.expect("queue should accept first message");
|
||||
|
||||
let response = JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse {
|
||||
id: codex_app_server_protocol::RequestId::Integer(7),
|
||||
let response = JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: RequestId::Integer(7),
|
||||
result: json!({"ok": true}),
|
||||
});
|
||||
let transport_event_tx_for_enqueue = transport_event_tx.clone();
|
||||
@@ -553,11 +537,10 @@ mod tests {
|
||||
match forwarded_event {
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: queued_connection_id,
|
||||
message:
|
||||
JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { id, result }),
|
||||
message: JSONRPCMessage::Response(JSONRPCResponse { id, result }),
|
||||
} => {
|
||||
assert_eq!(queued_connection_id, connection_id);
|
||||
assert_eq!(id, codex_app_server_protocol::RequestId::Integer(7));
|
||||
assert_eq!(id, RequestId::Integer(7));
|
||||
assert_eq!(result, json!({"ok": true}));
|
||||
}
|
||||
_ => panic!("expected forwarded response message"),
|
||||
@@ -573,12 +556,10 @@ mod tests {
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message: JSONRPCMessage::Notification(
|
||||
codex_app_server_protocol::JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
},
|
||||
),
|
||||
message: JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
}),
|
||||
})
|
||||
.await
|
||||
.expect("transport queue should accept first message");
|
||||
@@ -597,15 +578,15 @@ mod tests {
|
||||
.await
|
||||
.expect("writer queue should accept first message");
|
||||
|
||||
let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest {
|
||||
id: codex_app_server_protocol::RequestId::Integer(7),
|
||||
let request = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(7),
|
||||
method: "config/read".to_string(),
|
||||
params: Some(json!({ "includeLayers": false })),
|
||||
trace: None,
|
||||
});
|
||||
|
||||
let enqueue_result = tokio::time::timeout(
|
||||
std::time::Duration::from_millis(100),
|
||||
let enqueue_result = timeout(
|
||||
Duration::from_millis(100),
|
||||
enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request),
|
||||
)
|
||||
.await
|
||||
@@ -781,7 +762,7 @@ mod tests {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||
request_id: codex_app_server_protocol::RequestId::Integer(1),
|
||||
request_id: RequestId::Integer(1),
|
||||
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
|
||||
thread_id: "thr_123".to_string(),
|
||||
turn_id: "turn_123".to_string(),
|
||||
@@ -843,7 +824,7 @@ mod tests {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||
request_id: codex_app_server_protocol::RequestId::Integer(1),
|
||||
request_id: RequestId::Integer(1),
|
||||
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
|
||||
thread_id: "thr_123".to_string(),
|
||||
turn_id: "turn_123".to_string(),
|
||||
|
||||
@@ -0,0 +1,422 @@
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::next_connection_id;
|
||||
use super::protocol::ClientEnvelope;
|
||||
pub use super::protocol::ClientEvent;
|
||||
pub use super::protocol::ClientId;
|
||||
use super::protocol::PongStatus;
|
||||
use super::protocol::ServerEvent;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use crate::transport::remote_control::QueuedServerEnvelope;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
const REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60);
|
||||
pub(crate) const REMOTE_CONTROL_IDLE_SWEEP_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Stopped;
|
||||
|
||||
struct ClientState {
|
||||
connection_id: ConnectionId,
|
||||
disconnect_token: CancellationToken,
|
||||
last_activity_at: Instant,
|
||||
last_inbound_seq_id: Option<u64>,
|
||||
status_tx: watch::Sender<PongStatus>,
|
||||
}
|
||||
|
||||
pub(crate) struct ClientTracker {
|
||||
clients: HashMap<ClientId, ClientState>,
|
||||
join_set: JoinSet<()>,
|
||||
server_event_tx: mpsc::Sender<QueuedServerEnvelope>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl ClientTracker {
|
||||
pub(crate) fn new(
|
||||
server_event_tx: mpsc::Sender<QueuedServerEnvelope>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: &CancellationToken,
|
||||
) -> Self {
|
||||
Self {
|
||||
clients: HashMap::new(),
|
||||
join_set: JoinSet::new(),
|
||||
server_event_tx,
|
||||
transport_event_tx,
|
||||
shutdown_token: shutdown_token.child_token(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn bookkeep_join_set(&mut self) {
|
||||
while self.join_set.join_next().await.is_some() {}
|
||||
futures::future::pending().await
|
||||
}
|
||||
|
||||
pub(crate) async fn shutdown(&mut self) {
|
||||
self.shutdown_token.cancel();
|
||||
|
||||
while let Some(client_id) = self.clients.keys().next().cloned() {
|
||||
let _ = self.close_client(&client_id).await;
|
||||
}
|
||||
|
||||
self.drain_join_set().await;
|
||||
}
|
||||
|
||||
async fn drain_join_set(&mut self) {
|
||||
while self.join_set.join_next().await.is_some() {}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_message(
|
||||
&mut self,
|
||||
client_envelope: ClientEnvelope,
|
||||
) -> Result<(), Stopped> {
|
||||
let ClientEnvelope {
|
||||
client_id,
|
||||
event,
|
||||
seq_id,
|
||||
cursor: _,
|
||||
} = client_envelope;
|
||||
match event {
|
||||
ClientEvent::ClientMessage { message } => {
|
||||
let is_initialize = remote_control_message_starts_connection(&message);
|
||||
if let Some(seq_id) = seq_id
|
||||
&& let Some(client) = self.clients.get(&client_id)
|
||||
&& client
|
||||
.last_inbound_seq_id
|
||||
.is_some_and(|last_seq_id| last_seq_id >= seq_id)
|
||||
&& !is_initialize
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if is_initialize && self.clients.contains_key(&client_id) {
|
||||
self.close_client(&client_id).await?;
|
||||
}
|
||||
|
||||
if let Some(connection_id) = self.clients.get_mut(&client_id).map(|client| {
|
||||
client.last_activity_at = Instant::now();
|
||||
if let Some(seq_id) = seq_id {
|
||||
client.last_inbound_seq_id = Some(seq_id);
|
||||
}
|
||||
client.connection_id
|
||||
}) {
|
||||
self.transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| Stopped)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !is_initialize {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let connection_id = next_connection_id();
|
||||
let (writer_tx, writer_rx) =
|
||||
mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let disconnect_token = self.shutdown_token.child_token();
|
||||
self.transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: Some(disconnect_token.clone()),
|
||||
})
|
||||
.await
|
||||
.map_err(|_| Stopped)?;
|
||||
|
||||
let (status_tx, status_rx) = watch::channel(PongStatus::Active);
|
||||
self.join_set.spawn(Self::run_client_outbound(
|
||||
client_id.clone(),
|
||||
self.server_event_tx.clone(),
|
||||
writer_rx,
|
||||
status_rx,
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
self.clients.insert(
|
||||
client_id,
|
||||
ClientState {
|
||||
connection_id,
|
||||
disconnect_token,
|
||||
last_activity_at: Instant::now(),
|
||||
last_inbound_seq_id: seq_id,
|
||||
status_tx,
|
||||
},
|
||||
);
|
||||
self.send_transport_event(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
})
|
||||
.await
|
||||
}
|
||||
ClientEvent::Ack => Ok(()),
|
||||
ClientEvent::Ping => {
|
||||
if let Some(client) = self.clients.get_mut(&client_id) {
|
||||
client.last_activity_at = Instant::now();
|
||||
let _ = client.status_tx.send(PongStatus::Active);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let server_event_tx = self.server_event_tx.clone();
|
||||
self.join_set.spawn(async move {
|
||||
let server_envelope = QueuedServerEnvelope {
|
||||
event: ServerEvent::Pong {
|
||||
status: PongStatus::Unknown,
|
||||
},
|
||||
client_id,
|
||||
write_complete_tx: None,
|
||||
};
|
||||
let _ = server_event_tx.send(server_envelope).await;
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
ClientEvent::ClientClosed => self.close_client(&client_id).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_client_outbound(
|
||||
client_id: ClientId,
|
||||
server_event_tx: mpsc::Sender<QueuedServerEnvelope>,
|
||||
mut writer_rx: mpsc::Receiver<QueuedOutgoingMessage>,
|
||||
mut status_rx: watch::Receiver<PongStatus>,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
let (event, write_complete_tx) = tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
queued_message = writer_rx.recv() => {
|
||||
let Some(queued_message) = queued_message else {
|
||||
break;
|
||||
};
|
||||
let event = ServerEvent::ServerMessage {
|
||||
message: Box::new(queued_message.message),
|
||||
};
|
||||
(event, queued_message.write_complete_tx)
|
||||
}
|
||||
changed = status_rx.changed() => {
|
||||
if changed.is_err() {
|
||||
break;
|
||||
}
|
||||
let event = ServerEvent::Pong { status: status_rx.borrow().clone() };
|
||||
(event, None)
|
||||
}
|
||||
};
|
||||
let send_result = tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
send_result = server_event_tx.send(QueuedServerEnvelope {
|
||||
event,
|
||||
client_id: client_id.clone(),
|
||||
write_complete_tx,
|
||||
}) => send_result,
|
||||
};
|
||||
if send_result.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn close_expired_clients(&mut self) -> Result<Vec<ClientId>, Stopped> {
|
||||
let now = Instant::now();
|
||||
let expired_client_ids: Vec<ClientId> = self
|
||||
.clients
|
||||
.iter()
|
||||
.filter_map(|(client_id, client)| {
|
||||
(!remote_control_client_is_alive(client, now)).then_some(client_id.clone())
|
||||
})
|
||||
.collect();
|
||||
for client_id in &expired_client_ids {
|
||||
self.close_client(client_id).await?;
|
||||
}
|
||||
Ok(expired_client_ids)
|
||||
}
|
||||
|
||||
async fn close_client(&mut self, client_id: &ClientId) -> Result<(), Stopped> {
|
||||
let Some(client) = self.clients.remove(client_id) else {
|
||||
return Ok(());
|
||||
};
|
||||
client.disconnect_token.cancel();
|
||||
self.send_transport_event(TransportEvent::ConnectionClosed {
|
||||
connection_id: client.connection_id,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_transport_event(&self, event: TransportEvent) -> Result<(), Stopped> {
|
||||
self.transport_event_tx
|
||||
.send(event)
|
||||
.await
|
||||
.map_err(|_| Stopped)
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_control_message_starts_connection(message: &JSONRPCMessage) -> bool {
|
||||
matches!(
|
||||
message,
|
||||
JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { method, .. })
|
||||
if method == "initialize"
|
||||
)
|
||||
}
|
||||
|
||||
fn remote_control_client_is_alive(client: &ClientState, now: Instant) -> bool {
|
||||
now.duration_since(client.last_activity_at) < REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::transport::remote_control::protocol::ClientEnvelope;
|
||||
use crate::transport::remote_control::protocol::ClientEvent;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tokio::time::timeout;
|
||||
|
||||
fn initialize_envelope(client_id: &str) -> ClientEnvelope {
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::ClientMessage {
|
||||
message: JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(1),
|
||||
method: "initialize".to_string(),
|
||||
params: Some(json!({
|
||||
"clientInfo": {
|
||||
"name": "remote-test-client",
|
||||
"version": "0.1.0"
|
||||
}
|
||||
})),
|
||||
trace: None,
|
||||
}),
|
||||
},
|
||||
client_id: ClientId(client_id.to_string()),
|
||||
seq_id: Some(0),
|
||||
cursor: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancelled_outbound_task_emits_connection_closed() {
|
||||
let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let mut client_tracker =
|
||||
ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token);
|
||||
|
||||
client_tracker
|
||||
.handle_message(initialize_envelope("client-1"))
|
||||
.await
|
||||
.expect("initialize should open client");
|
||||
|
||||
let (connection_id, disconnect_sender) = match transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("connection opened should be sent")
|
||||
{
|
||||
TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
disconnect_sender: Some(disconnect_sender),
|
||||
..
|
||||
} => (connection_id, disconnect_sender),
|
||||
other => panic!("expected connection opened, got {other:?}"),
|
||||
};
|
||||
match transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("initialize should be forwarded")
|
||||
{
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: incoming_connection_id,
|
||||
..
|
||||
} => assert_eq!(incoming_connection_id, connection_id),
|
||||
other => panic!("expected incoming initialize, got {other:?}"),
|
||||
}
|
||||
|
||||
disconnect_sender.cancel();
|
||||
timeout(Duration::from_secs(1), client_tracker.bookkeep_join_set())
|
||||
.await
|
||||
.expect_err("bookkeeping should process the closed task and stay pending");
|
||||
|
||||
match transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("connection closed should be sent")
|
||||
{
|
||||
TransportEvent::ConnectionClosed {
|
||||
connection_id: closed_connection_id,
|
||||
} => assert_eq!(closed_connection_id, connection_id),
|
||||
other => panic!("expected connection closed, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_cancels_blocked_outbound_forwarding() {
|
||||
let (server_event_tx, _server_event_rx) = mpsc::channel(1);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let mut client_tracker =
|
||||
ClientTracker::new(server_event_tx.clone(), transport_event_tx, &shutdown_token);
|
||||
|
||||
server_event_tx
|
||||
.send(QueuedServerEnvelope {
|
||||
event: ServerEvent::Pong {
|
||||
status: PongStatus::Unknown,
|
||||
},
|
||||
client_id: ClientId("queued-client".to_string()),
|
||||
write_complete_tx: None,
|
||||
})
|
||||
.await
|
||||
.expect("server event queue should accept prefill");
|
||||
|
||||
client_tracker
|
||||
.handle_message(initialize_envelope("client-1"))
|
||||
.await
|
||||
.expect("initialize should open client");
|
||||
|
||||
let writer = match transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("connection opened should be sent")
|
||||
{
|
||||
TransportEvent::ConnectionOpened { writer, .. } => writer,
|
||||
other => panic!("expected connection opened, got {other:?}"),
|
||||
};
|
||||
let _ = transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("initialize should be forwarded");
|
||||
|
||||
writer
|
||||
.send(QueuedOutgoingMessage::new(
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "test".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.expect("writer should accept queued message");
|
||||
|
||||
timeout(Duration::from_secs(1), client_tracker.shutdown())
|
||||
.await
|
||||
.expect("shutdown should not hang on blocked server forwarding");
|
||||
}
|
||||
}
|
||||
399
codex-rs/app-server/src/transport/remote_control/enroll.rs
Normal file
399
codex-rs/app-server/src/transport/remote_control/enroll.rs
Normal file
@@ -0,0 +1,399 @@
|
||||
use super::protocol::EnrollRemoteServerRequest;
|
||||
use super::protocol::EnrollRemoteServerResponse;
|
||||
use super::protocol::RemoteControlTarget;
|
||||
use axum::http::HeaderMap;
|
||||
use codex_core::default_client::build_reqwest_client;
|
||||
use codex_state::StateRuntime;
|
||||
use gethostname::gethostname;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use tracing::warn;
|
||||
|
||||
const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
|
||||
const REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES: usize = 4096;
|
||||
|
||||
const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id";
|
||||
const CF_RAY_HEADER: &str = "cf-ray";
|
||||
pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(super) struct RemoteControlEnrollment {
|
||||
pub(super) account_id: Option<String>,
|
||||
pub(super) server_id: String,
|
||||
pub(super) server_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(super) struct RemoteControlConnectionAuth {
|
||||
pub(super) bearer_token: String,
|
||||
pub(super) account_id: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) async fn load_persisted_remote_control_enrollment(
|
||||
state_db: Option<&StateRuntime>,
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
account_id: Option<&str>,
|
||||
) -> Option<RemoteControlEnrollment> {
|
||||
let state_db = state_db?;
|
||||
let enrollment = match state_db
|
||||
.get_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
|
||||
.await
|
||||
{
|
||||
Ok(enrollment) => enrollment,
|
||||
Err(err) => {
|
||||
warn!("{err}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
enrollment.map(|(server_id, server_name)| RemoteControlEnrollment {
|
||||
account_id: account_id.map(&str::to_string),
|
||||
server_id,
|
||||
server_name,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) async fn update_persisted_remote_control_enrollment(
|
||||
state_db: Option<&StateRuntime>,
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
account_id: Option<&str>,
|
||||
enrollment: Option<&RemoteControlEnrollment>,
|
||||
) -> io::Result<()> {
|
||||
let Some(state_db) = state_db else {
|
||||
return Ok(());
|
||||
};
|
||||
if let &Some(enrollment) = &enrollment
|
||||
&& enrollment.account_id.as_deref() != account_id
|
||||
{
|
||||
return Err(io::Error::other(format!(
|
||||
"enrollment account_id does not match expected account_id `{account_id:?}`"
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(enrollment) = enrollment {
|
||||
state_db
|
||||
.upsert_remote_control_enrollment(
|
||||
&remote_control_target.websocket_url,
|
||||
account_id,
|
||||
&enrollment.server_id,
|
||||
&enrollment.server_name,
|
||||
)
|
||||
.await
|
||||
.map_err(io::Error::other)
|
||||
} else {
|
||||
state_db
|
||||
.delete_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(io::Error::other)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn preview_remote_control_response_body(body: &[u8]) -> String {
|
||||
let body = String::from_utf8_lossy(body);
|
||||
let trimmed = body.trim();
|
||||
if trimmed.is_empty() {
|
||||
return "<empty>".to_string();
|
||||
}
|
||||
if trimmed.len() <= REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
|
||||
let mut cut = REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES;
|
||||
while !trimmed.is_char_boundary(cut) {
|
||||
cut = cut.saturating_sub(1);
|
||||
}
|
||||
let mut truncated = trimmed[..cut].to_string();
|
||||
truncated.push_str("...");
|
||||
truncated
|
||||
}
|
||||
|
||||
pub(crate) fn format_headers(headers: &HeaderMap) -> String {
|
||||
let request_id_str = headers
|
||||
.get(REQUEST_ID_HEADER)
|
||||
.or_else(|| headers.get(OAI_REQUEST_ID_HEADER))
|
||||
.map(|value| value.to_str().unwrap_or("<invalid utf-8>").to_owned())
|
||||
.unwrap_or_else(|| "<none>".to_owned());
|
||||
let cf_ray_str = headers
|
||||
.get(CF_RAY_HEADER)
|
||||
.map(|value| value.to_str().unwrap_or("<invalid utf-8>").to_owned())
|
||||
.unwrap_or_else(|| "<none>".to_owned());
|
||||
format!("request-id: {request_id_str}, cf-ray: {cf_ray_str}")
|
||||
}
|
||||
|
||||
pub(super) async fn enroll_remote_control_server(
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
auth: &RemoteControlConnectionAuth,
|
||||
) -> io::Result<RemoteControlEnrollment> {
|
||||
let enroll_url = &remote_control_target.enroll_url;
|
||||
let server_name = gethostname().to_string_lossy().trim().to_string();
|
||||
let request = EnrollRemoteServerRequest {
|
||||
name: server_name.clone(),
|
||||
os: std::env::consts::OS,
|
||||
arch: std::env::consts::ARCH,
|
||||
app_server_version: env!("CARGO_PKG_VERSION"),
|
||||
};
|
||||
let client = build_reqwest_client();
|
||||
let mut http_request = client
|
||||
.post(enroll_url)
|
||||
.timeout(REMOTE_CONTROL_ENROLL_TIMEOUT)
|
||||
.bearer_auth(&auth.bearer_token)
|
||||
.json(&request);
|
||||
let account_id = auth.account_id.as_deref();
|
||||
if let Some(account_id) = account_id {
|
||||
http_request = http_request.header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id);
|
||||
}
|
||||
|
||||
let response = http_request.send().await.map_err(|err| {
|
||||
io::Error::other(format!(
|
||||
"failed to enroll remote control server at `{enroll_url}`: {err}"
|
||||
))
|
||||
})?;
|
||||
let headers = response.headers().clone();
|
||||
let status = response.status();
|
||||
let body = response.bytes().await.map_err(|err| {
|
||||
io::Error::other(format!(
|
||||
"failed to read remote control enrollment response from `{enroll_url}`: {err}"
|
||||
))
|
||||
})?;
|
||||
let body_preview = preview_remote_control_response_body(&body);
|
||||
if !status.is_success() {
|
||||
let headers_str = format_headers(&headers);
|
||||
let error_kind = if matches!(status.as_u16(), 401 | 403) {
|
||||
ErrorKind::PermissionDenied
|
||||
} else {
|
||||
ErrorKind::Other
|
||||
};
|
||||
return Err(io::Error::new(
|
||||
error_kind,
|
||||
format!(
|
||||
"remote control server enrollment failed at `{enroll_url}`: HTTP {status}, {headers_str}, body: {body_preview}"
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let enrollment = serde_json::from_slice::<EnrollRemoteServerResponse>(&body).map_err(|err| {
|
||||
let headers_str = format_headers(&headers);
|
||||
io::Error::other(format!(
|
||||
"failed to parse remote control enrollment response from `{enroll_url}`: HTTP {status}, {headers_str}, body: {body_preview}, decode error: {err}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(RemoteControlEnrollment {
|
||||
account_id: account_id.map(&str::to_string),
|
||||
server_id: enrollment.server_id,
|
||||
server_name,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::transport::remote_control::protocol::normalize_remote_control_url;
|
||||
use codex_state::StateRuntime;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc<StateRuntime> {
|
||||
StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("state runtime should initialize")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persisted_remote_control_enrollment_round_trips_by_target_and_account() {
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
let state_db = remote_control_state_runtime(&codex_home).await;
|
||||
let first_target = normalize_remote_control_url("http://example.com/remote/control")
|
||||
.expect("first target should parse");
|
||||
let second_target = normalize_remote_control_url("http://example.com/other/control")
|
||||
.expect("second target should parse");
|
||||
let first_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
server_id: "srv_e_first".to_string(),
|
||||
server_name: "first-server".to_string(),
|
||||
};
|
||||
let second_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
server_id: "srv_e_second".to_string(),
|
||||
server_name: "second-server".to_string(),
|
||||
};
|
||||
|
||||
update_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
Some(&first_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("first enrollment should persist");
|
||||
update_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&second_target,
|
||||
Some("account-a"),
|
||||
Some(&second_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("second enrollment should persist");
|
||||
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
)
|
||||
.await,
|
||||
Some(first_enrollment.clone())
|
||||
);
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-b"),
|
||||
)
|
||||
.await,
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&second_target,
|
||||
Some("account-a"),
|
||||
)
|
||||
.await,
|
||||
Some(second_enrollment)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn clearing_persisted_remote_control_enrollment_removes_only_matching_entry() {
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
let state_db = remote_control_state_runtime(&codex_home).await;
|
||||
let first_target = normalize_remote_control_url("http://example.com/remote/control")
|
||||
.expect("first target should parse");
|
||||
let second_target = normalize_remote_control_url("http://example.com/other/control")
|
||||
.expect("second target should parse");
|
||||
let first_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
server_id: "srv_e_first".to_string(),
|
||||
server_name: "first-server".to_string(),
|
||||
};
|
||||
let second_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
server_id: "srv_e_second".to_string(),
|
||||
server_name: "second-server".to_string(),
|
||||
};
|
||||
|
||||
update_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
Some(&first_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("first enrollment should persist");
|
||||
update_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&second_target,
|
||||
Some("account-a"),
|
||||
Some(&second_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("second enrollment should persist");
|
||||
|
||||
update_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("matching enrollment should clear");
|
||||
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
)
|
||||
.await,
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&second_target,
|
||||
Some("account-a"),
|
||||
)
|
||||
.await,
|
||||
Some(second_enrollment)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enroll_remote_control_server_parse_failure_includes_response_body() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = format!(
|
||||
"http://{}/backend-api/",
|
||||
listener
|
||||
.local_addr()
|
||||
.expect("listener should have a local addr")
|
||||
);
|
||||
let remote_control_target =
|
||||
normalize_remote_control_url(&remote_control_url).expect("target should parse");
|
||||
let enroll_url = remote_control_target.enroll_url.clone();
|
||||
let response_body = json!({
|
||||
"error": "not enrolled",
|
||||
});
|
||||
let expected_body = response_body.to_string();
|
||||
let server_task = tokio::spawn(async move {
|
||||
let (stream, _) = timeout(Duration::from_secs(5), listener.accept())
|
||||
.await
|
||||
.expect("HTTP request should arrive in time")
|
||||
.expect("listener accept should succeed");
|
||||
respond_with_json(stream, response_body).await;
|
||||
});
|
||||
|
||||
let err = enroll_remote_control_server(
|
||||
&remote_control_target,
|
||||
&RemoteControlConnectionAuth {
|
||||
bearer_token: "Access Token".to_string(),
|
||||
account_id: Some("account_id".to_string()),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect_err("invalid response should fail to parse");
|
||||
|
||||
server_task.await.expect("server task should succeed");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!(
|
||||
"failed to parse remote control enrollment response from `{enroll_url}`: HTTP 200 OK, request-id: <none>, cf-ray: <none>, body: {expected_body}, decode error: missing field `server_id` at line 1 column {}",
|
||||
expected_body.len()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) {
|
||||
let body = body.to_string();
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||
body.len()
|
||||
);
|
||||
stream
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.expect("response should write");
|
||||
stream.flush().await.expect("response should flush");
|
||||
}
|
||||
}
|
||||
59
codex-rs/app-server/src/transport/remote_control/mod.rs
Normal file
59
codex-rs/app-server/src/transport/remote_control/mod.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
mod client_tracker;
|
||||
mod enroll;
|
||||
mod protocol;
|
||||
mod websocket;
|
||||
|
||||
use crate::transport::remote_control::websocket::load_remote_control_auth;
|
||||
|
||||
pub use self::protocol::ClientId;
|
||||
use self::protocol::ServerEvent;
|
||||
use self::protocol::normalize_remote_control_url;
|
||||
use self::websocket::run_remote_control_websocket_loop;
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::next_connection_id;
|
||||
use codex_core::AuthManager;
|
||||
use codex_state::StateRuntime;
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
pub(super) struct QueuedServerEnvelope {
|
||||
pub(super) event: ServerEvent,
|
||||
pub(super) client_id: ClientId,
|
||||
pub(super) write_complete_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
pub(crate) async fn start_remote_control(
|
||||
remote_control_url: String,
|
||||
state_db: Option<Arc<StateRuntime>>,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
) -> io::Result<JoinHandle<()>> {
|
||||
let remote_control_target = normalize_remote_control_url(&remote_control_url)?;
|
||||
validate_remote_control_auth(&auth_manager).await?;
|
||||
|
||||
Ok(tokio::spawn(async move {
|
||||
run_remote_control_websocket_loop(
|
||||
remote_control_target,
|
||||
state_db,
|
||||
auth_manager,
|
||||
transport_event_tx,
|
||||
shutdown_token.child_token(),
|
||||
)
|
||||
.await;
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) async fn validate_remote_control_auth(
|
||||
auth_manager: &Arc<AuthManager>,
|
||||
) -> io::Result<()> {
|
||||
load_remote_control_auth(auth_manager).await.map(|_| ())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
139
codex-rs/app-server/src/transport/remote_control/protocol.rs
Normal file
139
codex-rs/app-server/src/transport/remote_control/protocol.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(super) struct RemoteControlTarget {
|
||||
pub(super) websocket_url: String,
|
||||
pub(super) enroll_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(super) struct EnrollRemoteServerRequest {
|
||||
pub(super) name: String,
|
||||
pub(super) os: &'static str,
|
||||
pub(super) arch: &'static str,
|
||||
pub(super) app_server_version: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(super) struct EnrollRemoteServerResponse {
|
||||
pub(super) server_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct ClientId(pub String);
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ClientEvent {
|
||||
ClientMessage { message: JSONRPCMessage },
|
||||
Ack,
|
||||
Ping,
|
||||
ClientClosed,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) struct ClientEnvelope {
|
||||
#[serde(flatten)]
|
||||
pub(crate) event: ClientEvent,
|
||||
#[serde(rename = "client_id", alias = "clientId")]
|
||||
pub(crate) client_id: ClientId,
|
||||
#[serde(
|
||||
rename = "seq_id",
|
||||
alias = "seqId",
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
pub(crate) seq_id: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) cursor: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PongStatus {
|
||||
Active,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ServerEvent {
|
||||
ServerMessage {
|
||||
message: Box<OutgoingMessage>,
|
||||
},
|
||||
#[allow(dead_code)]
|
||||
Ack,
|
||||
Pong {
|
||||
status: PongStatus,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) struct ServerEnvelope {
|
||||
#[serde(flatten)]
|
||||
pub(crate) event: ServerEvent,
|
||||
#[serde(rename = "client_id", alias = "clientId")]
|
||||
pub(crate) client_id: ClientId,
|
||||
#[serde(rename = "seq_id", alias = "seqId")]
|
||||
pub(crate) seq_id: u64,
|
||||
}
|
||||
|
||||
pub(super) fn normalize_remote_control_url(
|
||||
remote_control_url: &str,
|
||||
) -> io::Result<RemoteControlTarget> {
|
||||
let map_url_parse_error = |err: url::ParseError| -> io::Error {
|
||||
io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("invalid remote control URL `{remote_control_url}`: {err}"),
|
||||
)
|
||||
};
|
||||
let map_scheme_error = |_: ()| -> io::Error {
|
||||
io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!(
|
||||
"invalid remote control URL `{remote_control_url}`; expected absolute URL with http:// or https:// scheme"
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
let mut remote_control_url = Url::parse(remote_control_url).map_err(map_url_parse_error)?;
|
||||
match remote_control_url.scheme() {
|
||||
"https" | "http" => {}
|
||||
_ => return Err(map_scheme_error(())),
|
||||
}
|
||||
if !remote_control_url.path().ends_with('/') {
|
||||
let normalized_path = format!("{}/", remote_control_url.path());
|
||||
remote_control_url.set_path(&normalized_path);
|
||||
}
|
||||
|
||||
let mut enroll_url = remote_control_url
|
||||
.join("wham/remote/control/server/enroll")
|
||||
.map_err(map_url_parse_error)?;
|
||||
let mut websocket_url = remote_control_url
|
||||
.join("wham/remote/control/server")
|
||||
.map_err(map_url_parse_error)?;
|
||||
match remote_control_url.scheme() {
|
||||
"https" => {
|
||||
enroll_url.set_scheme("https").map_err(map_scheme_error)?;
|
||||
websocket_url.set_scheme("wss").map_err(map_scheme_error)?;
|
||||
}
|
||||
"http" => {
|
||||
enroll_url.set_scheme("http").map_err(map_scheme_error)?;
|
||||
websocket_url.set_scheme("ws").map_err(map_scheme_error)?;
|
||||
}
|
||||
_ => return Err(map_scheme_error(())),
|
||||
}
|
||||
|
||||
Ok(RemoteControlTarget {
|
||||
websocket_url: websocket_url.to_string(),
|
||||
enroll_url: enroll_url.to_string(),
|
||||
})
|
||||
}
|
||||
1090
codex-rs/app-server/src/transport/remote_control/tests.rs
Normal file
1090
codex-rs/app-server/src/transport/remote_control/tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
1094
codex-rs/app-server/src/transport/remote_control/websocket.rs
Normal file
1094
codex-rs/app-server/src/transport/remote_control/websocket.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::forward_incoming_message;
|
||||
use super::next_connection_id;
|
||||
use super::serialize_outgoing_message;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
@@ -20,7 +20,7 @@ pub(crate) async fn start_stdio_connection(
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
stdio_handles: &mut Vec<JoinHandle<()>>,
|
||||
) -> IoResult<()> {
|
||||
let connection_id = ConnectionId(0);
|
||||
let connection_id = next_connection_id();
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
transport_event_tx
|
||||
|
||||
@@ -4,6 +4,7 @@ use super::auth::WebsocketAuthPolicy;
|
||||
use super::auth::authorize_upgrade;
|
||||
use super::auth::should_warn_about_unauthenticated_non_loopback_listener;
|
||||
use super::forward_incoming_message;
|
||||
use super::next_connection_id;
|
||||
use super::serialize_outgoing_message;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
@@ -32,8 +33,6 @@ use owo_colors::Style;
|
||||
use std::io::Result as IoResult;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
@@ -75,7 +74,6 @@ fn print_websocket_startup_banner(addr: SocketAddr) {
|
||||
#[derive(Clone)]
|
||||
struct WebSocketListenerState {
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
connection_counter: Arc<AtomicU64>,
|
||||
auth_policy: Arc<WebsocketAuthPolicy>,
|
||||
}
|
||||
|
||||
@@ -113,7 +111,7 @@ async fn websocket_upgrade_handler(
|
||||
);
|
||||
return (err.status_code(), err.message()).into_response();
|
||||
}
|
||||
let connection_id = ConnectionId(state.connection_counter.fetch_add(1, Ordering::Relaxed));
|
||||
let connection_id = next_connection_id();
|
||||
info!(%peer_addr, "websocket client connected");
|
||||
websocket
|
||||
.on_upgrade(move |stream| async move {
|
||||
@@ -146,7 +144,6 @@ pub(crate) async fn start_websocket_acceptor(
|
||||
.layer(middleware::from_fn(reject_requests_with_origin_header))
|
||||
.with_state(WebSocketListenerState {
|
||||
transport_event_tx,
|
||||
connection_counter: Arc::new(AtomicU64::new(1)),
|
||||
auth_policy: Arc::new(auth_policy),
|
||||
});
|
||||
let server = axum::serve(
|
||||
|
||||
@@ -328,7 +328,7 @@ struct AppServerCommand {
|
||||
subcommand: Option<AppServerSubcommand>,
|
||||
|
||||
/// Transport endpoint URL. Supported values: `stdio://` (default),
|
||||
/// `ws://IP:PORT`.
|
||||
/// `ws://IP:PORT`, `off`.
|
||||
#[arg(
|
||||
long = "listen",
|
||||
value_name = "URL",
|
||||
@@ -1930,6 +1930,12 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_listen_off_parses() {
|
||||
let app_server = app_server_from_args(["codex", "app-server", "--listen", "off"].as_ref());
|
||||
assert_eq!(app_server.listen, codex_app_server::AppServerTransport::Off);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_listen_invalid_url_fails_to_parse() {
|
||||
let parse_result =
|
||||
|
||||
@@ -434,6 +434,9 @@
|
||||
"realtime_conversation": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_control": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_models": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -2071,6 +2074,9 @@
|
||||
"realtime_conversation": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_control": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_models": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
||||
@@ -176,6 +176,8 @@ pub enum Feature {
|
||||
VoiceTranscription,
|
||||
/// Enable experimental realtime voice conversation mode in the TUI.
|
||||
RealtimeConversation,
|
||||
/// Connect app-server to the ChatGPT remote control service.
|
||||
RemoteControl,
|
||||
/// Route interactive startup to the app-server-backed TUI implementation.
|
||||
TuiAppServer,
|
||||
/// Prevent idle system sleep while a turn is actively running.
|
||||
@@ -819,6 +821,12 @@ pub const FEATURES: &[FeatureSpec] = &[
|
||||
stage: Stage::UnderDevelopment,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::RemoteControl,
|
||||
key: "remote_control",
|
||||
stage: Stage::UnderDevelopment,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::TuiAppServer,
|
||||
key: "tui_app_server",
|
||||
|
||||
@@ -159,6 +159,12 @@ fn image_detail_original_feature_is_under_development() {
|
||||
assert_eq!(Feature::ImageDetailOriginal.default_enabled(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_control_is_under_development() {
|
||||
assert_eq!(Feature::RemoteControl.stage(), Stage::UnderDevelopment);
|
||||
assert_eq!(Feature::RemoteControl.default_enabled(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collab_is_legacy_alias_for_multi_agent() {
|
||||
assert_eq!(feature_for_key("multi_agent"), Some(Feature::Collab));
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
CREATE TABLE remote_control_enrollments (
|
||||
websocket_url TEXT NOT NULL,
|
||||
account_id TEXT NOT NULL,
|
||||
server_id TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
PRIMARY KEY (websocket_url, account_id)
|
||||
);
|
||||
@@ -53,6 +53,7 @@ mod agent_jobs;
|
||||
mod backfill;
|
||||
mod logs;
|
||||
mod memories;
|
||||
mod remote_control;
|
||||
#[cfg(test)]
|
||||
mod test_support;
|
||||
mod threads;
|
||||
|
||||
197
codex-rs/state/src/runtime/remote_control.rs
Normal file
197
codex-rs/state/src/runtime/remote_control.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
use super::*;
|
||||
|
||||
const REMOTE_CONTROL_ACCOUNT_ID_NONE: &str = "";
|
||||
|
||||
fn remote_control_account_id_key(account_id: Option<&str>) -> &str {
|
||||
account_id.unwrap_or(REMOTE_CONTROL_ACCOUNT_ID_NONE)
|
||||
}
|
||||
|
||||
impl StateRuntime {
|
||||
pub async fn get_remote_control_enrollment(
|
||||
&self,
|
||||
websocket_url: &str,
|
||||
account_id: Option<&str>,
|
||||
) -> anyhow::Result<Option<(String, String)>> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT server_id, server_name
|
||||
FROM remote_control_enrollments
|
||||
WHERE websocket_url = ? AND account_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(websocket_url)
|
||||
.bind(remote_control_account_id_key(account_id))
|
||||
.fetch_optional(self.pool.as_ref())
|
||||
.await?;
|
||||
|
||||
row.map(|row| Ok((row.try_get("server_id")?, row.try_get("server_name")?)))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
pub async fn upsert_remote_control_enrollment(
|
||||
&self,
|
||||
websocket_url: &str,
|
||||
account_id: Option<&str>,
|
||||
server_id: &str,
|
||||
server_name: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO remote_control_enrollments (
|
||||
websocket_url,
|
||||
account_id,
|
||||
server_id,
|
||||
server_name,
|
||||
updated_at
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(websocket_url, account_id) DO UPDATE SET
|
||||
server_id = excluded.server_id,
|
||||
server_name = excluded.server_name,
|
||||
updated_at = excluded.updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(websocket_url)
|
||||
.bind(remote_control_account_id_key(account_id))
|
||||
.bind(server_id)
|
||||
.bind(server_name)
|
||||
.bind(Utc::now().timestamp())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_remote_control_enrollment(
|
||||
&self,
|
||||
websocket_url: &str,
|
||||
account_id: Option<&str>,
|
||||
) -> anyhow::Result<u64> {
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
DELETE FROM remote_control_enrollments
|
||||
WHERE websocket_url = ? AND account_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(websocket_url)
|
||||
.bind(remote_control_account_id_key(account_id))
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::StateRuntime;
|
||||
use super::test_support::unique_temp_dir;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_enrollment_round_trips_by_target_and_account() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
"srv_e_first",
|
||||
"first-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert first enrollment");
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-b"),
|
||||
"srv_e_second",
|
||||
"second-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert second enrollment");
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
)
|
||||
.await
|
||||
.expect("load first enrollment"),
|
||||
Some(("srv_e_first".to_string(), "first-server".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("load missing enrollment"),
|
||||
None
|
||||
);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delete_remote_control_enrollment_removes_only_matching_entry() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
None,
|
||||
"srv_e_first",
|
||||
"first-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert first enrollment");
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
"srv_e_second",
|
||||
"second-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert second enrollment");
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.delete_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("delete first enrollment"),
|
||||
1
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("load deleted enrollment"),
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
)
|
||||
.await
|
||||
.expect("load retained enrollment"),
|
||||
Some(("srv_e_second".to_string(), "second-server".to_string()))
|
||||
);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user