mirror of
https://github.com/openai/codex.git
synced 2026-06-01 19:02:59 +00:00
app-server: move transport into dedicated crate (#20545)
## Why `codex-app-server` currently owns both request-processing code and transport implementation details. Splitting the transport layer into its own crate makes that boundary explicit, reduces the amount of transport-specific dependency surface carried by `codex-app-server`, and gives future transport work a narrower place to evolve. ## What changed - Added `codex-app-server-transport` and moved the existing transport tree into it, including stdio, unix socket, websocket, remote-control transport, and websocket auth. - Moved shared transport-facing message types into the new crate so both the transport implementation and `codex-app-server` use the same definitions. - Kept processor-facing connection state and outbound routing in `codex-app-server`, with the routing tests moved next to that local wrapper. - Updated workspace metadata, Bazel crate metadata, and `codex-app-server` dependencies for the new crate boundary. ## Validation - `cargo metadata --locked --no-deps` - `git diff --check` - Attempted `cargo test -p codex-app-server-transport`, `cargo test -p codex-app-server`, `just fix -p codex-app-server-transport`, and `just fix -p codex-app-server`; all were blocked before compilation by the existing `packageproxy` resolution failure for locked `rustls-webpki = 0.103.13`. - Attempted Bazel build / lockfile validation; those were blocked by external fetch failures against BuildBuddy / GitHub while resolving `v8`.
This commit is contained in:
committed by
GitHub
parent
5744b85b9a
commit
41e171fcf2
47
codex-rs/Cargo.lock
generated
47
codex-rs/Cargo.lock
generated
@@ -1857,8 +1857,8 @@ dependencies = [
|
|||||||
"chrono",
|
"chrono",
|
||||||
"clap",
|
"clap",
|
||||||
"codex-analytics",
|
"codex-analytics",
|
||||||
"codex-api",
|
|
||||||
"codex-app-server-protocol",
|
"codex-app-server-protocol",
|
||||||
|
"codex-app-server-transport",
|
||||||
"codex-arg0",
|
"codex-arg0",
|
||||||
"codex-backend-client",
|
"codex-backend-client",
|
||||||
"codex-chatgpt",
|
"codex-chatgpt",
|
||||||
@@ -1891,23 +1891,17 @@ dependencies = [
|
|||||||
"codex-state",
|
"codex-state",
|
||||||
"codex-thread-store",
|
"codex-thread-store",
|
||||||
"codex-tools",
|
"codex-tools",
|
||||||
"codex-uds",
|
|
||||||
"codex-utils-absolute-path",
|
"codex-utils-absolute-path",
|
||||||
"codex-utils-cargo-bin",
|
"codex-utils-cargo-bin",
|
||||||
"codex-utils-cli",
|
"codex-utils-cli",
|
||||||
"codex-utils-json-to-toml",
|
"codex-utils-json-to-toml",
|
||||||
"codex-utils-pty",
|
"codex-utils-pty",
|
||||||
"codex-utils-rustls-provider",
|
|
||||||
"constant_time_eq 0.3.1",
|
|
||||||
"core_test_support",
|
"core_test_support",
|
||||||
"flate2",
|
"flate2",
|
||||||
"futures",
|
"futures",
|
||||||
"gethostname",
|
|
||||||
"hmac",
|
"hmac",
|
||||||
"jsonwebtoken",
|
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"opentelemetry_sdk",
|
"opentelemetry_sdk",
|
||||||
"owo-colors",
|
|
||||||
"pretty_assertions",
|
"pretty_assertions",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"rmcp",
|
"rmcp",
|
||||||
@@ -2005,6 +1999,45 @@ dependencies = [
|
|||||||
"uuid",
|
"uuid",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "codex-app-server-transport"
|
||||||
|
version = "0.0.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"axum",
|
||||||
|
"base64 0.22.1",
|
||||||
|
"chrono",
|
||||||
|
"clap",
|
||||||
|
"codex-api",
|
||||||
|
"codex-app-server-protocol",
|
||||||
|
"codex-config",
|
||||||
|
"codex-core",
|
||||||
|
"codex-login",
|
||||||
|
"codex-model-provider",
|
||||||
|
"codex-state",
|
||||||
|
"codex-uds",
|
||||||
|
"codex-utils-absolute-path",
|
||||||
|
"codex-utils-rustls-provider",
|
||||||
|
"constant_time_eq 0.3.1",
|
||||||
|
"futures",
|
||||||
|
"gethostname",
|
||||||
|
"hmac",
|
||||||
|
"jsonwebtoken",
|
||||||
|
"owo-colors",
|
||||||
|
"pretty_assertions",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"sha2",
|
||||||
|
"tempfile",
|
||||||
|
"time",
|
||||||
|
"tokio",
|
||||||
|
"tokio-tungstenite",
|
||||||
|
"tokio-util",
|
||||||
|
"tracing",
|
||||||
|
"url",
|
||||||
|
"uuid",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "codex-apply-patch"
|
name = "codex-apply-patch"
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ members = [
|
|||||||
"ansi-escape",
|
"ansi-escape",
|
||||||
"async-utils",
|
"async-utils",
|
||||||
"app-server",
|
"app-server",
|
||||||
|
"app-server-transport",
|
||||||
"app-server-client",
|
"app-server-client",
|
||||||
"app-server-protocol",
|
"app-server-protocol",
|
||||||
"app-server-test-client",
|
"app-server-test-client",
|
||||||
@@ -127,6 +128,7 @@ codex-ansi-escape = { path = "ansi-escape" }
|
|||||||
codex-api = { path = "codex-api" }
|
codex-api = { path = "codex-api" }
|
||||||
codex-aws-auth = { path = "aws-auth" }
|
codex-aws-auth = { path = "aws-auth" }
|
||||||
codex-app-server = { path = "app-server" }
|
codex-app-server = { path = "app-server" }
|
||||||
|
codex-app-server-transport = { path = "app-server-transport" }
|
||||||
codex-app-server-client = { path = "app-server-client" }
|
codex-app-server-client = { path = "app-server-client" }
|
||||||
codex-app-server-protocol = { path = "app-server-protocol" }
|
codex-app-server-protocol = { path = "app-server-protocol" }
|
||||||
codex-app-server-test-client = { path = "app-server-test-client" }
|
codex-app-server-test-client = { path = "app-server-test-client" }
|
||||||
|
|||||||
6
codex-rs/app-server-transport/BUILD.bazel
Normal file
6
codex-rs/app-server-transport/BUILD.bazel
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
load("//:defs.bzl", "codex_rust_crate")
|
||||||
|
|
||||||
|
codex_rust_crate(
|
||||||
|
name = "app-server-transport",
|
||||||
|
crate_name = "codex_app_server_transport",
|
||||||
|
)
|
||||||
58
codex-rs/app-server-transport/Cargo.toml
Normal file
58
codex-rs/app-server-transport/Cargo.toml
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
[package]
|
||||||
|
name = "codex-app-server-transport"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "codex_app_server_transport"
|
||||||
|
path = "src/lib.rs"
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = { workspace = true }
|
||||||
|
axum = { workspace = true, default-features = false, features = [
|
||||||
|
"http1",
|
||||||
|
"json",
|
||||||
|
"tokio",
|
||||||
|
"ws",
|
||||||
|
] }
|
||||||
|
base64 = { workspace = true }
|
||||||
|
clap = { workspace = true, features = ["derive"] }
|
||||||
|
codex-api = { workspace = true }
|
||||||
|
codex-app-server-protocol = { workspace = true }
|
||||||
|
codex-core = { workspace = true }
|
||||||
|
codex-login = { workspace = true }
|
||||||
|
codex-model-provider = { workspace = true }
|
||||||
|
codex-state = { workspace = true }
|
||||||
|
codex-uds = { workspace = true }
|
||||||
|
codex-utils-absolute-path = { workspace = true }
|
||||||
|
codex-utils-rustls-provider = { workspace = true }
|
||||||
|
constant_time_eq = { workspace = true }
|
||||||
|
futures = { workspace = true }
|
||||||
|
gethostname = { workspace = true }
|
||||||
|
hmac = { workspace = true }
|
||||||
|
jsonwebtoken = { workspace = true }
|
||||||
|
owo-colors = { workspace = true, features = ["supports-colors"] }
|
||||||
|
serde = { workspace = true, features = ["derive"] }
|
||||||
|
serde_json = { workspace = true }
|
||||||
|
sha2 = { workspace = true }
|
||||||
|
time = { workspace = true }
|
||||||
|
tokio = { workspace = true, features = [
|
||||||
|
"io-std",
|
||||||
|
"macros",
|
||||||
|
"rt-multi-thread",
|
||||||
|
] }
|
||||||
|
tokio-tungstenite = { workspace = true }
|
||||||
|
tokio-util = { workspace = true }
|
||||||
|
tracing = { workspace = true, features = ["log"] }
|
||||||
|
url = { workspace = true }
|
||||||
|
uuid = { workspace = true, features = ["serde", "v7"] }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
chrono = { workspace = true }
|
||||||
|
codex-config = { workspace = true }
|
||||||
|
pretty_assertions = { workspace = true }
|
||||||
|
tempfile = { workspace = true }
|
||||||
20
codex-rs/app-server-transport/src/lib.rs
Normal file
20
codex-rs/app-server-transport/src/lib.rs
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
mod outgoing_message;
|
||||||
|
mod transport;
|
||||||
|
|
||||||
|
pub use outgoing_message::ConnectionId;
|
||||||
|
pub use outgoing_message::OutgoingError;
|
||||||
|
pub use outgoing_message::OutgoingMessage;
|
||||||
|
pub use outgoing_message::OutgoingResponse;
|
||||||
|
pub use outgoing_message::QueuedOutgoingMessage;
|
||||||
|
pub use transport::AppServerTransport;
|
||||||
|
pub use transport::AppServerTransportParseError;
|
||||||
|
pub use transport::CHANNEL_CAPACITY;
|
||||||
|
pub use transport::ConnectionOrigin;
|
||||||
|
pub use transport::RemoteControlHandle;
|
||||||
|
pub use transport::TransportEvent;
|
||||||
|
pub use transport::app_server_control_socket_path;
|
||||||
|
pub use transport::auth;
|
||||||
|
pub use transport::start_control_socket_acceptor;
|
||||||
|
pub use transport::start_remote_control;
|
||||||
|
pub use transport::start_stdio_connection;
|
||||||
|
pub use transport::start_websocket_acceptor;
|
||||||
58
codex-rs/app-server-transport/src/outgoing_message.rs
Normal file
58
codex-rs/app-server-transport/src/outgoing_message.rs
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
use codex_app_server_protocol::JSONRPCErrorError;
|
||||||
|
use codex_app_server_protocol::RequestId;
|
||||||
|
use codex_app_server_protocol::Result;
|
||||||
|
use codex_app_server_protocol::ServerNotification;
|
||||||
|
use codex_app_server_protocol::ServerRequest;
|
||||||
|
use serde::Serialize;
|
||||||
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
|
/// Stable identifier for a transport connection.
|
||||||
|
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
|
||||||
|
pub struct ConnectionId(pub u64);
|
||||||
|
|
||||||
|
impl fmt::Display for ConnectionId {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "{}", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Outgoing message from the server to the client.
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum OutgoingMessage {
|
||||||
|
Request(ServerRequest),
|
||||||
|
/// AppServerNotification is specific to the case where this is run as an
|
||||||
|
/// "app server" as opposed to an MCP server.
|
||||||
|
AppServerNotification(ServerNotification),
|
||||||
|
Response(OutgoingResponse),
|
||||||
|
Error(OutgoingError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||||
|
pub struct OutgoingResponse {
|
||||||
|
pub id: RequestId,
|
||||||
|
pub result: Result,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||||
|
pub struct OutgoingError {
|
||||||
|
pub error: JSONRPCErrorError,
|
||||||
|
pub id: RequestId,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct QueuedOutgoingMessage {
|
||||||
|
pub message: OutgoingMessage,
|
||||||
|
pub write_complete_tx: Option<oneshot::Sender<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QueuedOutgoingMessage {
|
||||||
|
pub fn new(message: OutgoingMessage) -> Self {
|
||||||
|
Self {
|
||||||
|
message,
|
||||||
|
write_complete_tx: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -86,7 +86,7 @@ pub enum AppServerWebsocketCapabilityTokenSource {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default)]
|
#[derive(Clone, Debug, Default)]
|
||||||
pub(crate) struct WebsocketAuthPolicy {
|
pub struct WebsocketAuthPolicy {
|
||||||
pub(crate) mode: Option<WebsocketAuthMode>,
|
pub(crate) mode: Option<WebsocketAuthMode>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,7 +219,7 @@ impl AppServerWebsocketAuthArgs {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn policy_from_settings(
|
pub fn policy_from_settings(
|
||||||
settings: &AppServerWebsocketAuthSettings,
|
settings: &AppServerWebsocketAuthSettings,
|
||||||
) -> io::Result<WebsocketAuthPolicy> {
|
) -> io::Result<WebsocketAuthPolicy> {
|
||||||
let mode = match settings.config.as_ref() {
|
let mode = match settings.config.as_ref() {
|
||||||
478
codex-rs/app-server-transport/src/transport/mod.rs
Normal file
478
codex-rs/app-server-transport/src/transport/mod.rs
Normal file
@@ -0,0 +1,478 @@
|
|||||||
|
pub mod auth;
|
||||||
|
|
||||||
|
use crate::outgoing_message::ConnectionId;
|
||||||
|
use crate::outgoing_message::OutgoingError;
|
||||||
|
use crate::outgoing_message::OutgoingMessage;
|
||||||
|
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||||
|
use codex_app_server_protocol::JSONRPCErrorError;
|
||||||
|
use codex_app_server_protocol::JSONRPCMessage;
|
||||||
|
use codex_core::config::find_codex_home;
|
||||||
|
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::str::FromStr;
|
||||||
|
use std::sync::atomic::AtomicU64;
|
||||||
|
use std::sync::atomic::Ordering;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
use tracing::error;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
/// Size of the bounded channels used to communicate between tasks. The value
|
||||||
|
/// is a balance between throughput and memory usage - 128 messages should be
|
||||||
|
/// plenty for an interactive CLI.
|
||||||
|
pub const CHANNEL_CAPACITY: usize = 128;
|
||||||
|
|
||||||
|
mod remote_control;
|
||||||
|
mod stdio;
|
||||||
|
mod unix_socket;
|
||||||
|
#[cfg(test)]
|
||||||
|
mod unix_socket_tests;
|
||||||
|
mod websocket;
|
||||||
|
|
||||||
|
pub use remote_control::RemoteControlHandle;
|
||||||
|
pub use remote_control::start_remote_control;
|
||||||
|
pub use stdio::start_stdio_connection;
|
||||||
|
pub use unix_socket::start_control_socket_acceptor;
|
||||||
|
pub use websocket::start_websocket_acceptor;
|
||||||
|
|
||||||
|
const OVERLOADED_ERROR_CODE: i64 = -32001;
|
||||||
|
|
||||||
|
const APP_SERVER_CONTROL_SOCKET_DIR_NAME: &str = "app-server-control";
|
||||||
|
const APP_SERVER_CONTROL_SOCKET_FILE_NAME: &str = "app-server-control.sock";
|
||||||
|
|
||||||
|
pub fn app_server_control_socket_path(codex_home: &Path) -> std::io::Result<AbsolutePathBuf> {
|
||||||
|
AbsolutePathBuf::from_absolute_path(
|
||||||
|
codex_home
|
||||||
|
.join(APP_SERVER_CONTROL_SOCKET_DIR_NAME)
|
||||||
|
.join(APP_SERVER_CONTROL_SOCKET_FILE_NAME),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||||
|
pub enum AppServerTransport {
|
||||||
|
Stdio,
|
||||||
|
UnixSocket { socket_path: AbsolutePathBuf },
|
||||||
|
WebSocket { bind_address: SocketAddr },
|
||||||
|
Off,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||||
|
pub enum AppServerTransportParseError {
|
||||||
|
UnsupportedListenUrl(String),
|
||||||
|
InvalidUnixSocketPath { listen_url: String, message: String },
|
||||||
|
InvalidWebSocketListenUrl(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for AppServerTransportParseError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!(
|
||||||
|
f,
|
||||||
|
"unsupported --listen URL `{listen_url}`; expected `stdio://`, `unix://`, `unix://PATH`, `ws://IP:PORT`, or `off`"
|
||||||
|
),
|
||||||
|
AppServerTransportParseError::InvalidUnixSocketPath {
|
||||||
|
listen_url,
|
||||||
|
message,
|
||||||
|
} => write!(
|
||||||
|
f,
|
||||||
|
"invalid unix socket --listen URL `{listen_url}`; failed to resolve socket path: {message}"
|
||||||
|
),
|
||||||
|
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!(
|
||||||
|
f,
|
||||||
|
"invalid websocket --listen URL `{listen_url}`; expected `ws://IP:PORT`"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for AppServerTransportParseError {}
|
||||||
|
|
||||||
|
impl AppServerTransport {
|
||||||
|
pub const DEFAULT_LISTEN_URL: &'static str = "stdio://";
|
||||||
|
|
||||||
|
pub fn from_listen_url(listen_url: &str) -> Result<Self, AppServerTransportParseError> {
|
||||||
|
if listen_url == Self::DEFAULT_LISTEN_URL {
|
||||||
|
return Ok(Self::Stdio);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(raw_socket_path) = listen_url.strip_prefix("unix://") {
|
||||||
|
let socket_path = if raw_socket_path.is_empty() {
|
||||||
|
let codex_home = find_codex_home().map_err(|err| {
|
||||||
|
AppServerTransportParseError::InvalidUnixSocketPath {
|
||||||
|
listen_url: listen_url.to_string(),
|
||||||
|
message: format!("failed to resolve CODEX_HOME: {err}"),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
app_server_control_socket_path(&codex_home).map_err(|err| {
|
||||||
|
AppServerTransportParseError::InvalidUnixSocketPath {
|
||||||
|
listen_url: listen_url.to_string(),
|
||||||
|
message: err.to_string(),
|
||||||
|
}
|
||||||
|
})?
|
||||||
|
} else {
|
||||||
|
AbsolutePathBuf::relative_to_current_dir(raw_socket_path).map_err(|err| {
|
||||||
|
AppServerTransportParseError::InvalidUnixSocketPath {
|
||||||
|
listen_url: listen_url.to_string(),
|
||||||
|
message: err.to_string(),
|
||||||
|
}
|
||||||
|
})?
|
||||||
|
};
|
||||||
|
return Ok(Self::UnixSocket { socket_path });
|
||||||
|
}
|
||||||
|
|
||||||
|
if listen_url == "off" {
|
||||||
|
return Ok(Self::Off);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
|
||||||
|
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
|
||||||
|
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
|
||||||
|
})?;
|
||||||
|
return Ok(Self::WebSocket { bind_address });
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(AppServerTransportParseError::UnsupportedListenUrl(
|
||||||
|
listen_url.to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for AppServerTransport {
|
||||||
|
type Err = AppServerTransportParseError;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
|
Self::from_listen_url(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum TransportEvent {
|
||||||
|
ConnectionOpened {
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
origin: ConnectionOrigin,
|
||||||
|
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||||
|
disconnect_sender: Option<CancellationToken>,
|
||||||
|
},
|
||||||
|
ConnectionClosed {
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
},
|
||||||
|
IncomingMessage {
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
message: JSONRPCMessage,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ConnectionOrigin {
|
||||||
|
Stdio,
|
||||||
|
InProcess,
|
||||||
|
WebSocket,
|
||||||
|
RemoteControl,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionOrigin {
|
||||||
|
pub fn allows_device_key_requests(self) -> bool {
|
||||||
|
// Device-key endpoints are only for local connections that own the app-server instance.
|
||||||
|
// Do not include remote transports such as SSH or remote-control websocket connections.
|
||||||
|
matches!(self, Self::Stdio | Self::InProcess)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||||
|
|
||||||
|
fn next_connection_id() -> ConnectionId {
|
||||||
|
ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn forward_incoming_message(
|
||||||
|
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||||
|
writer: &mpsc::Sender<QueuedOutgoingMessage>,
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
payload: &str,
|
||||||
|
) -> bool {
|
||||||
|
match serde_json::from_str::<JSONRPCMessage>(payload) {
|
||||||
|
Ok(message) => {
|
||||||
|
enqueue_incoming_message(transport_event_tx, writer, connection_id, message).await
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
error!("Failed to deserialize JSONRPCMessage: {err}");
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn enqueue_incoming_message(
|
||||||
|
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||||
|
writer: &mpsc::Sender<QueuedOutgoingMessage>,
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
message: JSONRPCMessage,
|
||||||
|
) -> bool {
|
||||||
|
let event = TransportEvent::IncomingMessage {
|
||||||
|
connection_id,
|
||||||
|
message,
|
||||||
|
};
|
||||||
|
match transport_event_tx.try_send(event) {
|
||||||
|
Ok(()) => true,
|
||||||
|
Err(mpsc::error::TrySendError::Closed(_)) => false,
|
||||||
|
Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage {
|
||||||
|
connection_id,
|
||||||
|
message: JSONRPCMessage::Request(request),
|
||||||
|
})) => {
|
||||||
|
let overload_error = OutgoingMessage::Error(OutgoingError {
|
||||||
|
id: request.id,
|
||||||
|
error: JSONRPCErrorError {
|
||||||
|
code: OVERLOADED_ERROR_CODE,
|
||||||
|
message: "Server overloaded; retry later.".to_string(),
|
||||||
|
data: None,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
match writer.try_send(QueuedOutgoingMessage::new(overload_error)) {
|
||||||
|
Ok(()) => true,
|
||||||
|
Err(mpsc::error::TrySendError::Closed(_)) => false,
|
||||||
|
Err(mpsc::error::TrySendError::Full(_overload_error)) => {
|
||||||
|
warn!(
|
||||||
|
"dropping overload response for connection {:?}: outbound queue is full",
|
||||||
|
connection_id
|
||||||
|
);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(mpsc::error::TrySendError::Full(event)) => transport_event_tx.send(event).await.is_ok(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option<String> {
|
||||||
|
let value = match serde_json::to_value(outgoing_message) {
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(err) => {
|
||||||
|
error!("Failed to convert OutgoingMessage to JSON value: {err}");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match serde_json::to_string(&value) {
|
||||||
|
Ok(json) => Some(json),
|
||||||
|
Err(err) => {
|
||||||
|
error!("Failed to serialize JSONRPCMessage: {err}");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use codex_app_server_protocol::ConfigWarningNotification;
|
||||||
|
use codex_app_server_protocol::JSONRPCNotification;
|
||||||
|
use codex_app_server_protocol::JSONRPCRequest;
|
||||||
|
use codex_app_server_protocol::JSONRPCResponse;
|
||||||
|
use codex_app_server_protocol::RequestId;
|
||||||
|
use codex_app_server_protocol::ServerNotification;
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
|
use serde_json::json;
|
||||||
|
use tokio::time::Duration;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn listen_off_parses_as_off_transport() {
|
||||||
|
assert_eq!(
|
||||||
|
AppServerTransport::from_listen_url("off"),
|
||||||
|
Ok(AppServerTransport::Off)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_incoming_request_returns_overload_error_when_queue_is_full() {
|
||||||
|
let connection_id = ConnectionId(42);
|
||||||
|
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
|
||||||
|
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
|
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||||
|
method: "initialized".to_string(),
|
||||||
|
params: None,
|
||||||
|
});
|
||||||
|
transport_event_tx
|
||||||
|
.send(TransportEvent::IncomingMessage {
|
||||||
|
connection_id,
|
||||||
|
message: first_message.clone(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("queue should accept first message");
|
||||||
|
|
||||||
|
let request = JSONRPCMessage::Request(JSONRPCRequest {
|
||||||
|
id: RequestId::Integer(7),
|
||||||
|
method: "config/read".to_string(),
|
||||||
|
params: Some(json!({ "includeLayers": false })),
|
||||||
|
trace: None,
|
||||||
|
});
|
||||||
|
assert!(
|
||||||
|
enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request).await
|
||||||
|
);
|
||||||
|
|
||||||
|
let queued_event = transport_event_rx
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.expect("first event should stay queued");
|
||||||
|
match queued_event {
|
||||||
|
TransportEvent::IncomingMessage {
|
||||||
|
connection_id: queued_connection_id,
|
||||||
|
message,
|
||||||
|
} => {
|
||||||
|
assert_eq!(queued_connection_id, connection_id);
|
||||||
|
assert_eq!(message, first_message);
|
||||||
|
}
|
||||||
|
_ => panic!("expected queued incoming message"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let overload = writer_rx
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.expect("request should receive overload error");
|
||||||
|
let overload_json =
|
||||||
|
serde_json::to_value(overload.message).expect("serialize overload error");
|
||||||
|
assert_eq!(
|
||||||
|
overload_json,
|
||||||
|
json!({
|
||||||
|
"id": 7,
|
||||||
|
"error": {
|
||||||
|
"code": OVERLOADED_ERROR_CODE,
|
||||||
|
"message": "Server overloaded; retry later."
|
||||||
|
}
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_incoming_response_waits_instead_of_dropping_when_queue_is_full() {
|
||||||
|
let connection_id = ConnectionId(42);
|
||||||
|
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
|
||||||
|
let (writer_tx, _writer_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
|
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||||
|
method: "initialized".to_string(),
|
||||||
|
params: None,
|
||||||
|
});
|
||||||
|
transport_event_tx
|
||||||
|
.send(TransportEvent::IncomingMessage {
|
||||||
|
connection_id,
|
||||||
|
message: first_message.clone(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("queue should accept first message");
|
||||||
|
|
||||||
|
let response = JSONRPCMessage::Response(JSONRPCResponse {
|
||||||
|
id: RequestId::Integer(7),
|
||||||
|
result: json!({"ok": true}),
|
||||||
|
});
|
||||||
|
let transport_event_tx_for_enqueue = transport_event_tx.clone();
|
||||||
|
let writer_tx_for_enqueue = writer_tx.clone();
|
||||||
|
let enqueue_handle = tokio::spawn(async move {
|
||||||
|
enqueue_incoming_message(
|
||||||
|
&transport_event_tx_for_enqueue,
|
||||||
|
&writer_tx_for_enqueue,
|
||||||
|
connection_id,
|
||||||
|
response,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
|
||||||
|
let queued_event = transport_event_rx
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.expect("first event should be dequeued");
|
||||||
|
match queued_event {
|
||||||
|
TransportEvent::IncomingMessage {
|
||||||
|
connection_id: queued_connection_id,
|
||||||
|
message,
|
||||||
|
} => {
|
||||||
|
assert_eq!(queued_connection_id, connection_id);
|
||||||
|
assert_eq!(message, first_message);
|
||||||
|
}
|
||||||
|
_ => panic!("expected queued incoming message"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let enqueue_result = enqueue_handle.await.expect("enqueue task should not panic");
|
||||||
|
assert!(enqueue_result);
|
||||||
|
|
||||||
|
let forwarded_event = transport_event_rx
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.expect("response should be forwarded instead of dropped");
|
||||||
|
match forwarded_event {
|
||||||
|
TransportEvent::IncomingMessage {
|
||||||
|
connection_id: queued_connection_id,
|
||||||
|
message: JSONRPCMessage::Response(JSONRPCResponse { id, result }),
|
||||||
|
} => {
|
||||||
|
assert_eq!(queued_connection_id, connection_id);
|
||||||
|
assert_eq!(id, RequestId::Integer(7));
|
||||||
|
assert_eq!(result, json!({"ok": true}));
|
||||||
|
}
|
||||||
|
_ => panic!("expected forwarded response message"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_incoming_request_does_not_block_when_writer_queue_is_full() {
|
||||||
|
let connection_id = ConnectionId(42);
|
||||||
|
let (transport_event_tx, _transport_event_rx) = mpsc::channel(1);
|
||||||
|
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
|
transport_event_tx
|
||||||
|
.send(TransportEvent::IncomingMessage {
|
||||||
|
connection_id,
|
||||||
|
message: JSONRPCMessage::Notification(JSONRPCNotification {
|
||||||
|
method: "initialized".to_string(),
|
||||||
|
params: None,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("transport queue should accept first message");
|
||||||
|
|
||||||
|
writer_tx
|
||||||
|
.send(QueuedOutgoingMessage::new(
|
||||||
|
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification {
|
||||||
|
summary: "queued".to_string(),
|
||||||
|
details: None,
|
||||||
|
path: None,
|
||||||
|
range: None,
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.expect("writer queue should accept first message");
|
||||||
|
|
||||||
|
let request = JSONRPCMessage::Request(JSONRPCRequest {
|
||||||
|
id: RequestId::Integer(7),
|
||||||
|
method: "config/read".to_string(),
|
||||||
|
params: Some(json!({ "includeLayers": false })),
|
||||||
|
trace: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
let enqueue_result = timeout(
|
||||||
|
Duration::from_millis(100),
|
||||||
|
enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("enqueue should not block while writer queue is full");
|
||||||
|
assert!(enqueue_result);
|
||||||
|
|
||||||
|
let queued_outgoing = writer_rx
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.expect("writer queue should still contain original message");
|
||||||
|
let queued_json =
|
||||||
|
serde_json::to_value(queued_outgoing.message).expect("serialize queued message");
|
||||||
|
assert_eq!(
|
||||||
|
queued_json,
|
||||||
|
json!({
|
||||||
|
"method": "configWarning",
|
||||||
|
"params": {
|
||||||
|
"summary": "queued",
|
||||||
|
"details": null,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -36,14 +36,14 @@ pub(super) struct QueuedServerEnvelope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct RemoteControlHandle {
|
pub struct RemoteControlHandle {
|
||||||
enabled_tx: Arc<watch::Sender<bool>>,
|
enabled_tx: Arc<watch::Sender<bool>>,
|
||||||
status_tx: Arc<watch::Sender<RemoteControlStatusChangedNotification>>,
|
status_tx: Arc<watch::Sender<RemoteControlStatusChangedNotification>>,
|
||||||
state_db_available: bool,
|
state_db_available: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RemoteControlHandle {
|
impl RemoteControlHandle {
|
||||||
pub(crate) fn set_enabled(&self, enabled: bool) {
|
pub fn set_enabled(&self, enabled: bool) {
|
||||||
let requested_enabled = enabled;
|
let requested_enabled = enabled;
|
||||||
let enabled = enabled && self.state_db_available;
|
let enabled = enabled && self.state_db_available;
|
||||||
if requested_enabled && !self.state_db_available {
|
if requested_enabled && !self.state_db_available {
|
||||||
@@ -56,14 +56,12 @@ impl RemoteControlHandle {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn status_receiver(
|
pub fn status_receiver(&self) -> watch::Receiver<RemoteControlStatusChangedNotification> {
|
||||||
&self,
|
|
||||||
) -> watch::Receiver<RemoteControlStatusChangedNotification> {
|
|
||||||
self.status_tx.subscribe()
|
self.status_tx.subscribe()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn start_remote_control(
|
pub async fn start_remote_control(
|
||||||
remote_control_url: String,
|
remote_control_url: String,
|
||||||
state_db: Option<Arc<StateRuntime>>,
|
state_db: Option<Arc<StateRuntime>>,
|
||||||
auth_manager: Arc<AuthManager>,
|
auth_manager: Arc<AuthManager>,
|
||||||
@@ -21,7 +21,7 @@ use tracing::debug;
|
|||||||
use tracing::error;
|
use tracing::error;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
pub(crate) async fn start_stdio_connection(
|
pub async fn start_stdio_connection(
|
||||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||||
stdio_handles: &mut Vec<JoinHandle<()>>,
|
stdio_handles: &mut Vec<JoinHandle<()>>,
|
||||||
initialize_client_name_tx: oneshot::Sender<String>,
|
initialize_client_name_tx: oneshot::Sender<String>,
|
||||||
@@ -20,7 +20,7 @@ use tracing::warn;
|
|||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
const CONTROL_SOCKET_MODE: u32 = 0o600;
|
const CONTROL_SOCKET_MODE: u32 = 0o600;
|
||||||
|
|
||||||
pub(crate) async fn start_control_socket_acceptor(
|
pub async fn start_control_socket_acceptor(
|
||||||
socket_path: AbsolutePathBuf,
|
socket_path: AbsolutePathBuf,
|
||||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||||
shutdown_token: CancellationToken,
|
shutdown_token: CancellationToken,
|
||||||
@@ -128,7 +128,7 @@ async fn websocket_upgrade_handler(
|
|||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn start_websocket_acceptor(
|
pub async fn start_websocket_acceptor(
|
||||||
bind_address: SocketAddr,
|
bind_address: SocketAddr,
|
||||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||||
shutdown_token: CancellationToken,
|
shutdown_token: CancellationToken,
|
||||||
@@ -30,7 +30,6 @@ axum = { workspace = true, default-features = false, features = [
|
|||||||
"ws",
|
"ws",
|
||||||
] }
|
] }
|
||||||
codex-analytics = { workspace = true }
|
codex-analytics = { workspace = true }
|
||||||
codex-api = { workspace = true }
|
|
||||||
codex-arg0 = { workspace = true }
|
codex-arg0 = { workspace = true }
|
||||||
codex-cloud-requirements = { workspace = true }
|
codex-cloud-requirements = { workspace = true }
|
||||||
codex-config = { workspace = true }
|
codex-config = { workspace = true }
|
||||||
@@ -58,6 +57,7 @@ codex-model-provider = { workspace = true }
|
|||||||
codex-models-manager = { workspace = true }
|
codex-models-manager = { workspace = true }
|
||||||
codex-protocol = { workspace = true }
|
codex-protocol = { workspace = true }
|
||||||
codex-app-server-protocol = { workspace = true }
|
codex-app-server-protocol = { workspace = true }
|
||||||
|
codex-app-server-transport = { workspace = true }
|
||||||
codex-feedback = { workspace = true }
|
codex-feedback = { workspace = true }
|
||||||
codex-rmcp-client = { workspace = true }
|
codex-rmcp-client = { workspace = true }
|
||||||
codex-rollout = { workspace = true }
|
codex-rollout = { workspace = true }
|
||||||
@@ -65,18 +65,11 @@ codex-sandboxing = { workspace = true }
|
|||||||
codex-state = { workspace = true }
|
codex-state = { workspace = true }
|
||||||
codex-thread-store = { workspace = true }
|
codex-thread-store = { workspace = true }
|
||||||
codex-tools = { workspace = true }
|
codex-tools = { workspace = true }
|
||||||
codex-uds = { workspace = true }
|
|
||||||
codex-utils-absolute-path = { workspace = true }
|
codex-utils-absolute-path = { workspace = true }
|
||||||
codex-utils-json-to-toml = { workspace = true }
|
codex-utils-json-to-toml = { workspace = true }
|
||||||
codex-utils-rustls-provider = { workspace = true }
|
|
||||||
chrono = { workspace = true }
|
chrono = { workspace = true }
|
||||||
clap = { workspace = true, features = ["derive"] }
|
clap = { workspace = true, features = ["derive"] }
|
||||||
constant_time_eq = { workspace = true }
|
|
||||||
futures = { workspace = true }
|
futures = { workspace = true }
|
||||||
gethostname = { workspace = true }
|
|
||||||
hmac = { workspace = true }
|
|
||||||
jsonwebtoken = { workspace = true }
|
|
||||||
owo-colors = { workspace = true, features = ["supports-colors"] }
|
|
||||||
serde = { workspace = true, features = ["derive"] }
|
serde = { workspace = true, features = ["derive"] }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
sha2 = { workspace = true }
|
sha2 = { workspace = true }
|
||||||
@@ -93,7 +86,6 @@ tokio = { workspace = true, features = [
|
|||||||
"signal",
|
"signal",
|
||||||
] }
|
] }
|
||||||
tokio-util = { workspace = true }
|
tokio-util = { workspace = true }
|
||||||
tokio-tungstenite = { workspace = true }
|
|
||||||
tracing = { workspace = true, features = ["log"] }
|
tracing = { workspace = true, features = ["log"] }
|
||||||
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] }
|
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] }
|
||||||
url = { workspace = true }
|
url = { workspace = true }
|
||||||
@@ -111,6 +103,7 @@ core_test_support = { workspace = true }
|
|||||||
codex-model-provider-info = { workspace = true }
|
codex-model-provider-info = { workspace = true }
|
||||||
codex-utils-cargo-bin = { workspace = true }
|
codex-utils-cargo-bin = { workspace = true }
|
||||||
flate2 = { workspace = true }
|
flate2 = { workspace = true }
|
||||||
|
hmac = { workspace = true }
|
||||||
opentelemetry = { workspace = true }
|
opentelemetry = { workspace = true }
|
||||||
opentelemetry_sdk = { workspace = true }
|
opentelemetry_sdk = { workspace = true }
|
||||||
pretty_assertions = { workspace = true }
|
pretty_assertions = { workspace = true }
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::AtomicI64;
|
use std::sync::atomic::AtomicI64;
|
||||||
use std::sync::atomic::Ordering;
|
use std::sync::atomic::Ordering;
|
||||||
@@ -15,7 +14,6 @@ use codex_app_server_protocol::ServerRequestPayload;
|
|||||||
use codex_otel::span_w3c_trace_context;
|
use codex_otel::span_w3c_trace_context;
|
||||||
use codex_protocol::ThreadId;
|
use codex_protocol::ThreadId;
|
||||||
use codex_protocol::protocol::W3cTraceContext;
|
use codex_protocol::protocol::W3cTraceContext;
|
||||||
use serde::Serialize;
|
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
@@ -26,22 +24,17 @@ use tracing::warn;
|
|||||||
use crate::error_code::INTERNAL_ERROR_CODE;
|
use crate::error_code::INTERNAL_ERROR_CODE;
|
||||||
use crate::error_code::internal_error;
|
use crate::error_code::internal_error;
|
||||||
use crate::server_request_error::TURN_TRANSITION_PENDING_REQUEST_ERROR_REASON;
|
use crate::server_request_error::TURN_TRANSITION_PENDING_REQUEST_ERROR_REASON;
|
||||||
|
pub(crate) use codex_app_server_transport::ConnectionId;
|
||||||
|
pub(crate) use codex_app_server_transport::OutgoingError;
|
||||||
|
pub(crate) use codex_app_server_transport::OutgoingMessage;
|
||||||
|
pub(crate) use codex_app_server_transport::OutgoingResponse;
|
||||||
|
pub(crate) use codex_app_server_transport::QueuedOutgoingMessage;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use codex_protocol::account::PlanType;
|
use codex_protocol::account::PlanType;
|
||||||
|
|
||||||
pub(crate) type ClientRequestResult = std::result::Result<Result, JSONRPCErrorError>;
|
pub(crate) type ClientRequestResult = std::result::Result<Result, JSONRPCErrorError>;
|
||||||
|
|
||||||
/// Stable identifier for a transport connection.
|
|
||||||
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
|
|
||||||
pub(crate) struct ConnectionId(pub(crate) u64);
|
|
||||||
|
|
||||||
impl fmt::Display for ConnectionId {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
write!(f, "{}", self.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stable identifier for a client request scoped to a transport connection.
|
/// Stable identifier for a client request scoped to a transport connection.
|
||||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||||
pub(crate) struct ConnectionRequestId {
|
pub(crate) struct ConnectionRequestId {
|
||||||
@@ -96,21 +89,6 @@ pub(crate) enum OutgoingEnvelope {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub(crate) struct QueuedOutgoingMessage {
|
|
||||||
pub(crate) message: OutgoingMessage,
|
|
||||||
pub(crate) write_complete_tx: Option<oneshot::Sender<()>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl QueuedOutgoingMessage {
|
|
||||||
pub(crate) fn new(message: OutgoingMessage) -> Self {
|
|
||||||
Self {
|
|
||||||
message,
|
|
||||||
write_complete_tx: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends messages to the client and manages request callbacks.
|
/// Sends messages to the client and manages request callbacks.
|
||||||
pub(crate) struct OutgoingMessageSender {
|
pub(crate) struct OutgoingMessageSender {
|
||||||
next_server_request_id: AtomicI64,
|
next_server_request_id: AtomicI64,
|
||||||
@@ -665,30 +643,6 @@ impl OutgoingMessageSender {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Outgoing message from the server to the client.
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub(crate) enum OutgoingMessage {
|
|
||||||
Request(ServerRequest),
|
|
||||||
/// AppServerNotification is specific to the case where this is run as an
|
|
||||||
/// "app server" as opposed to an MCP server.
|
|
||||||
AppServerNotification(ServerNotification),
|
|
||||||
Response(OutgoingResponse),
|
|
||||||
Error(OutgoingError),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
|
||||||
pub(crate) struct OutgoingResponse {
|
|
||||||
pub id: RequestId,
|
|
||||||
pub result: Result,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
|
||||||
pub(crate) struct OutgoingError {
|
|
||||||
pub error: JSONRPCErrorError,
|
|
||||||
pub id: RequestId,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|||||||
232
codex-rs/app-server/src/transport.rs
Normal file
232
codex-rs/app-server/src/transport.rs
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
use crate::message_processor::ConnectionSessionState;
|
||||||
|
use crate::outgoing_message::OutgoingEnvelope;
|
||||||
|
use codex_app_server_protocol::ExperimentalApi;
|
||||||
|
use codex_app_server_protocol::ServerRequest;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::RwLock;
|
||||||
|
use std::sync::atomic::AtomicBool;
|
||||||
|
use std::sync::atomic::Ordering;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
pub use codex_app_server_transport::AppServerTransport;
|
||||||
|
pub(crate) use codex_app_server_transport::CHANNEL_CAPACITY;
|
||||||
|
pub(crate) use codex_app_server_transport::ConnectionId;
|
||||||
|
pub(crate) use codex_app_server_transport::ConnectionOrigin;
|
||||||
|
pub(crate) use codex_app_server_transport::OutgoingMessage;
|
||||||
|
pub(crate) use codex_app_server_transport::QueuedOutgoingMessage;
|
||||||
|
pub(crate) use codex_app_server_transport::RemoteControlHandle;
|
||||||
|
pub(crate) use codex_app_server_transport::TransportEvent;
|
||||||
|
pub use codex_app_server_transport::app_server_control_socket_path;
|
||||||
|
pub use codex_app_server_transport::auth;
|
||||||
|
pub(crate) use codex_app_server_transport::start_control_socket_acceptor;
|
||||||
|
pub(crate) use codex_app_server_transport::start_remote_control;
|
||||||
|
pub(crate) use codex_app_server_transport::start_stdio_connection;
|
||||||
|
pub(crate) use codex_app_server_transport::start_websocket_acceptor;
|
||||||
|
|
||||||
|
pub(crate) struct ConnectionState {
|
||||||
|
pub(crate) outbound_initialized: Arc<AtomicBool>,
|
||||||
|
pub(crate) outbound_experimental_api_enabled: Arc<AtomicBool>,
|
||||||
|
pub(crate) outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||||
|
pub(crate) session: Arc<ConnectionSessionState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionState {
|
||||||
|
pub(crate) fn new(
|
||||||
|
origin: ConnectionOrigin,
|
||||||
|
outbound_initialized: Arc<AtomicBool>,
|
||||||
|
outbound_experimental_api_enabled: Arc<AtomicBool>,
|
||||||
|
outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
outbound_initialized,
|
||||||
|
outbound_experimental_api_enabled,
|
||||||
|
outbound_opted_out_notification_methods,
|
||||||
|
session: Arc::new(ConnectionSessionState::new(origin)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct OutboundConnectionState {
|
||||||
|
pub(crate) initialized: Arc<AtomicBool>,
|
||||||
|
pub(crate) experimental_api_enabled: Arc<AtomicBool>,
|
||||||
|
pub(crate) opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||||
|
pub(crate) writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||||
|
disconnect_sender: Option<CancellationToken>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OutboundConnectionState {
|
||||||
|
pub(crate) fn new(
|
||||||
|
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||||
|
initialized: Arc<AtomicBool>,
|
||||||
|
experimental_api_enabled: Arc<AtomicBool>,
|
||||||
|
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||||
|
disconnect_sender: Option<CancellationToken>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
initialized,
|
||||||
|
experimental_api_enabled,
|
||||||
|
opted_out_notification_methods,
|
||||||
|
writer,
|
||||||
|
disconnect_sender,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn can_disconnect(&self) -> bool {
|
||||||
|
self.disconnect_sender.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn request_disconnect(&self) {
|
||||||
|
if let Some(disconnect_sender) = &self.disconnect_sender {
|
||||||
|
disconnect_sender.cancel();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_skip_notification_for_connection(
|
||||||
|
connection_state: &OutboundConnectionState,
|
||||||
|
message: &OutgoingMessage,
|
||||||
|
) -> bool {
|
||||||
|
let Ok(opted_out_notification_methods) = connection_state.opted_out_notification_methods.read()
|
||||||
|
else {
|
||||||
|
warn!("failed to read outbound opted-out notifications");
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
match message {
|
||||||
|
OutgoingMessage::AppServerNotification(notification) => {
|
||||||
|
if notification.experimental_reason().is_some()
|
||||||
|
&& !connection_state
|
||||||
|
.experimental_api_enabled
|
||||||
|
.load(Ordering::Acquire)
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
let method = notification.to_string();
|
||||||
|
opted_out_notification_methods.contains(method.as_str())
|
||||||
|
}
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn disconnect_connection(
|
||||||
|
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
) -> bool {
|
||||||
|
if let Some(connection_state) = connections.remove(&connection_id) {
|
||||||
|
connection_state.request_disconnect();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_message_to_connection(
|
||||||
|
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
message: OutgoingMessage,
|
||||||
|
write_complete_tx: Option<tokio::sync::oneshot::Sender<()>>,
|
||||||
|
) -> bool {
|
||||||
|
let Some(connection_state) = connections.get(&connection_id) else {
|
||||||
|
warn!("dropping message for disconnected connection: {connection_id:?}");
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
let message = filter_outgoing_message_for_connection(connection_state, message);
|
||||||
|
if should_skip_notification_for_connection(connection_state, &message) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let writer = connection_state.writer.clone();
|
||||||
|
let queued_message = QueuedOutgoingMessage {
|
||||||
|
message,
|
||||||
|
write_complete_tx,
|
||||||
|
};
|
||||||
|
if connection_state.can_disconnect() {
|
||||||
|
match writer.try_send(queued_message) {
|
||||||
|
Ok(()) => false,
|
||||||
|
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||||
|
warn!(
|
||||||
|
"disconnecting slow connection after outbound queue filled: {connection_id:?}"
|
||||||
|
);
|
||||||
|
disconnect_connection(connections, connection_id)
|
||||||
|
}
|
||||||
|
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||||
|
disconnect_connection(connections, connection_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if writer.send(queued_message).await.is_err() {
|
||||||
|
disconnect_connection(connections, connection_id)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn filter_outgoing_message_for_connection(
|
||||||
|
connection_state: &OutboundConnectionState,
|
||||||
|
message: OutgoingMessage,
|
||||||
|
) -> OutgoingMessage {
|
||||||
|
let experimental_api_enabled = connection_state
|
||||||
|
.experimental_api_enabled
|
||||||
|
.load(Ordering::Acquire);
|
||||||
|
match message {
|
||||||
|
OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||||
|
request_id,
|
||||||
|
mut params,
|
||||||
|
}) => {
|
||||||
|
if !experimental_api_enabled {
|
||||||
|
params.strip_experimental_fields();
|
||||||
|
}
|
||||||
|
OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||||
|
request_id,
|
||||||
|
params,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ => message,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn route_outgoing_envelope(
|
||||||
|
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||||
|
envelope: OutgoingEnvelope,
|
||||||
|
) {
|
||||||
|
match envelope {
|
||||||
|
OutgoingEnvelope::ToConnection {
|
||||||
|
connection_id,
|
||||||
|
message,
|
||||||
|
write_complete_tx,
|
||||||
|
} => {
|
||||||
|
let _ =
|
||||||
|
send_message_to_connection(connections, connection_id, message, write_complete_tx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
OutgoingEnvelope::Broadcast { message } => {
|
||||||
|
let target_connections: Vec<ConnectionId> = connections
|
||||||
|
.iter()
|
||||||
|
.filter_map(|(connection_id, connection_state)| {
|
||||||
|
if connection_state.initialized.load(Ordering::Acquire)
|
||||||
|
&& !should_skip_notification_for_connection(connection_state, &message)
|
||||||
|
{
|
||||||
|
Some(*connection_id)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for connection_id in target_connections {
|
||||||
|
let _ = send_message_to_connection(
|
||||||
|
connections,
|
||||||
|
connection_id,
|
||||||
|
message.clone(),
|
||||||
|
/*write_complete_tx*/ None,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[path = "transport_tests.rs"]
|
||||||
|
mod tests;
|
||||||
File diff suppressed because it is too large
Load Diff
532
codex-rs/app-server/src/transport_tests.rs
Normal file
532
codex-rs/app-server/src/transport_tests.rs
Normal file
@@ -0,0 +1,532 @@
|
|||||||
|
use super::*;
|
||||||
|
use codex_app_server_protocol::ConfigWarningNotification;
|
||||||
|
use codex_app_server_protocol::RequestId;
|
||||||
|
use codex_app_server_protocol::ServerNotification;
|
||||||
|
use codex_app_server_protocol::ThreadGoal;
|
||||||
|
use codex_app_server_protocol::ThreadGoalStatus;
|
||||||
|
use codex_app_server_protocol::ThreadGoalUpdatedNotification;
|
||||||
|
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
|
use serde_json::json;
|
||||||
|
use tokio::time::Duration;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
fn absolute_path(path: &str) -> AbsolutePathBuf {
|
||||||
|
AbsolutePathBuf::from_absolute_path(path).expect("absolute path")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn thread_goal_updated_notification() -> ServerNotification {
|
||||||
|
ServerNotification::ThreadGoalUpdated(ThreadGoalUpdatedNotification {
|
||||||
|
thread_id: "thread-1".to_string(),
|
||||||
|
turn_id: None,
|
||||||
|
goal: ThreadGoal {
|
||||||
|
thread_id: "thread-1".to_string(),
|
||||||
|
objective: "ship goal mode".to_string(),
|
||||||
|
status: ThreadGoalStatus::Active,
|
||||||
|
token_budget: None,
|
||||||
|
tokens_used: 0,
|
||||||
|
time_used_seconds: 0,
|
||||||
|
created_at: 1,
|
||||||
|
updated_at: 1,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn to_connection_notification_respects_opt_out_filters() {
|
||||||
|
let connection_id = ConnectionId(7);
|
||||||
|
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||||
|
let initialized = Arc::new(AtomicBool::new(true));
|
||||||
|
let opted_out_notification_methods =
|
||||||
|
Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()])));
|
||||||
|
|
||||||
|
let mut connections = HashMap::new();
|
||||||
|
connections.insert(
|
||||||
|
connection_id,
|
||||||
|
OutboundConnectionState::new(
|
||||||
|
writer_tx,
|
||||||
|
initialized,
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
opted_out_notification_methods,
|
||||||
|
/*disconnect_sender*/ None,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
route_outgoing_envelope(
|
||||||
|
&mut connections,
|
||||||
|
OutgoingEnvelope::ToConnection {
|
||||||
|
connection_id,
|
||||||
|
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification {
|
||||||
|
summary: "task_started".to_string(),
|
||||||
|
details: None,
|
||||||
|
path: None,
|
||||||
|
range: None,
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
write_complete_tx: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
writer_rx.try_recv().is_err(),
|
||||||
|
"opted-out notification should be dropped"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn to_connection_notifications_are_dropped_for_opted_out_clients() {
|
||||||
|
let connection_id = ConnectionId(10);
|
||||||
|
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
|
let mut connections = HashMap::new();
|
||||||
|
connections.insert(
|
||||||
|
connection_id,
|
||||||
|
OutboundConnectionState::new(
|
||||||
|
writer_tx,
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()]))),
|
||||||
|
/*disconnect_sender*/ None,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
route_outgoing_envelope(
|
||||||
|
&mut connections,
|
||||||
|
OutgoingEnvelope::ToConnection {
|
||||||
|
connection_id,
|
||||||
|
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification {
|
||||||
|
summary: "task_started".to_string(),
|
||||||
|
details: None,
|
||||||
|
path: None,
|
||||||
|
range: None,
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
write_complete_tx: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
writer_rx.try_recv().is_err(),
|
||||||
|
"opted-out notifications should not reach clients"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn to_connection_notifications_are_preserved_for_non_opted_out_clients() {
|
||||||
|
let connection_id = ConnectionId(11);
|
||||||
|
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
|
let mut connections = HashMap::new();
|
||||||
|
connections.insert(
|
||||||
|
connection_id,
|
||||||
|
OutboundConnectionState::new(
|
||||||
|
writer_tx,
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(RwLock::new(HashSet::new())),
|
||||||
|
/*disconnect_sender*/ None,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
route_outgoing_envelope(
|
||||||
|
&mut connections,
|
||||||
|
OutgoingEnvelope::ToConnection {
|
||||||
|
connection_id,
|
||||||
|
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification {
|
||||||
|
summary: "task_started".to_string(),
|
||||||
|
details: None,
|
||||||
|
path: None,
|
||||||
|
range: None,
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
write_complete_tx: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let message = writer_rx
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.expect("notification should reach non-opted-out clients");
|
||||||
|
assert!(matches!(
|
||||||
|
message.message,
|
||||||
|
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification { summary, .. }
|
||||||
|
)) if summary == "task_started"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn experimental_notifications_are_dropped_without_capability() {
|
||||||
|
let connection_id = ConnectionId(12);
|
||||||
|
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
|
let mut connections = HashMap::new();
|
||||||
|
connections.insert(
|
||||||
|
connection_id,
|
||||||
|
OutboundConnectionState::new(
|
||||||
|
writer_tx,
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(AtomicBool::new(false)),
|
||||||
|
Arc::new(RwLock::new(HashSet::new())),
|
||||||
|
/*disconnect_sender*/ None,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
route_outgoing_envelope(
|
||||||
|
&mut connections,
|
||||||
|
OutgoingEnvelope::ToConnection {
|
||||||
|
connection_id,
|
||||||
|
message: OutgoingMessage::AppServerNotification(thread_goal_updated_notification()),
|
||||||
|
write_complete_tx: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
writer_rx.try_recv().is_err(),
|
||||||
|
"experimental notifications should not reach clients without capability"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn experimental_notifications_are_preserved_with_capability() {
|
||||||
|
let connection_id = ConnectionId(13);
|
||||||
|
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
|
let mut connections = HashMap::new();
|
||||||
|
connections.insert(
|
||||||
|
connection_id,
|
||||||
|
OutboundConnectionState::new(
|
||||||
|
writer_tx,
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(RwLock::new(HashSet::new())),
|
||||||
|
/*disconnect_sender*/ None,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
route_outgoing_envelope(
|
||||||
|
&mut connections,
|
||||||
|
OutgoingEnvelope::ToConnection {
|
||||||
|
connection_id,
|
||||||
|
message: OutgoingMessage::AppServerNotification(thread_goal_updated_notification()),
|
||||||
|
write_complete_tx: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let message = writer_rx
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.expect("experimental notification should reach opted-in client");
|
||||||
|
assert!(matches!(
|
||||||
|
message.message,
|
||||||
|
OutgoingMessage::AppServerNotification(ServerNotification::ThreadGoalUpdated(_))
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn command_execution_request_approval_strips_additional_permissions_without_capability() {
|
||||||
|
let connection_id = ConnectionId(8);
|
||||||
|
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
|
let mut connections = HashMap::new();
|
||||||
|
connections.insert(
|
||||||
|
connection_id,
|
||||||
|
OutboundConnectionState::new(
|
||||||
|
writer_tx,
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(AtomicBool::new(false)),
|
||||||
|
Arc::new(RwLock::new(HashSet::new())),
|
||||||
|
/*disconnect_sender*/ None,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
route_outgoing_envelope(
|
||||||
|
&mut connections,
|
||||||
|
OutgoingEnvelope::ToConnection {
|
||||||
|
connection_id,
|
||||||
|
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||||
|
request_id: RequestId::Integer(1),
|
||||||
|
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
|
||||||
|
thread_id: "thr_123".to_string(),
|
||||||
|
turn_id: "turn_123".to_string(),
|
||||||
|
item_id: "call_123".to_string(),
|
||||||
|
approval_id: None,
|
||||||
|
reason: Some("Need extra read access".to_string()),
|
||||||
|
network_approval_context: None,
|
||||||
|
command: Some("cat file".to_string()),
|
||||||
|
cwd: Some(absolute_path("/tmp")),
|
||||||
|
command_actions: None,
|
||||||
|
additional_permissions: Some(
|
||||||
|
codex_app_server_protocol::AdditionalPermissionProfile {
|
||||||
|
network: None,
|
||||||
|
file_system: Some(
|
||||||
|
codex_app_server_protocol::AdditionalFileSystemPermissions {
|
||||||
|
read: Some(vec![absolute_path("/tmp/allowed")]),
|
||||||
|
write: None,
|
||||||
|
glob_scan_max_depth: None,
|
||||||
|
entries: None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
proposed_execpolicy_amendment: None,
|
||||||
|
proposed_network_policy_amendments: None,
|
||||||
|
available_decisions: None,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
write_complete_tx: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let message = writer_rx
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.expect("request should be delivered to the connection");
|
||||||
|
let json = serde_json::to_value(message.message).expect("request should serialize");
|
||||||
|
assert_eq!(json["params"].get("additionalPermissions"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn command_execution_request_approval_keeps_additional_permissions_with_capability() {
|
||||||
|
let connection_id = ConnectionId(9);
|
||||||
|
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
|
let mut connections = HashMap::new();
|
||||||
|
connections.insert(
|
||||||
|
connection_id,
|
||||||
|
OutboundConnectionState::new(
|
||||||
|
writer_tx,
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(RwLock::new(HashSet::new())),
|
||||||
|
/*disconnect_sender*/ None,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
route_outgoing_envelope(
|
||||||
|
&mut connections,
|
||||||
|
OutgoingEnvelope::ToConnection {
|
||||||
|
connection_id,
|
||||||
|
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||||
|
request_id: RequestId::Integer(1),
|
||||||
|
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
|
||||||
|
thread_id: "thr_123".to_string(),
|
||||||
|
turn_id: "turn_123".to_string(),
|
||||||
|
item_id: "call_123".to_string(),
|
||||||
|
approval_id: None,
|
||||||
|
reason: Some("Need extra read access".to_string()),
|
||||||
|
network_approval_context: None,
|
||||||
|
command: Some("cat file".to_string()),
|
||||||
|
cwd: Some(absolute_path("/tmp")),
|
||||||
|
command_actions: None,
|
||||||
|
additional_permissions: Some(
|
||||||
|
codex_app_server_protocol::AdditionalPermissionProfile {
|
||||||
|
network: None,
|
||||||
|
file_system: Some(
|
||||||
|
codex_app_server_protocol::AdditionalFileSystemPermissions {
|
||||||
|
read: Some(vec![absolute_path("/tmp/allowed")]),
|
||||||
|
write: None,
|
||||||
|
glob_scan_max_depth: None,
|
||||||
|
entries: None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
proposed_execpolicy_amendment: None,
|
||||||
|
proposed_network_policy_amendments: None,
|
||||||
|
available_decisions: None,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
write_complete_tx: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let message = writer_rx
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.expect("request should be delivered to the connection");
|
||||||
|
let json = serde_json::to_value(message.message).expect("request should serialize");
|
||||||
|
let allowed_path = absolute_path("/tmp/allowed").to_string_lossy().into_owned();
|
||||||
|
assert_eq!(
|
||||||
|
json["params"]["additionalPermissions"],
|
||||||
|
json!({
|
||||||
|
"network": null,
|
||||||
|
"fileSystem": {
|
||||||
|
"read": [allowed_path],
|
||||||
|
"write": null,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn broadcast_does_not_block_on_slow_connection() {
|
||||||
|
let fast_connection_id = ConnectionId(1);
|
||||||
|
let slow_connection_id = ConnectionId(2);
|
||||||
|
|
||||||
|
let (fast_writer_tx, mut fast_writer_rx) = mpsc::channel(1);
|
||||||
|
let (slow_writer_tx, mut slow_writer_rx) = mpsc::channel(1);
|
||||||
|
let fast_disconnect_token = CancellationToken::new();
|
||||||
|
let slow_disconnect_token = CancellationToken::new();
|
||||||
|
|
||||||
|
let mut connections = HashMap::new();
|
||||||
|
connections.insert(
|
||||||
|
fast_connection_id,
|
||||||
|
OutboundConnectionState::new(
|
||||||
|
fast_writer_tx,
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(RwLock::new(HashSet::new())),
|
||||||
|
Some(fast_disconnect_token.clone()),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
connections.insert(
|
||||||
|
slow_connection_id,
|
||||||
|
OutboundConnectionState::new(
|
||||||
|
slow_writer_tx.clone(),
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(RwLock::new(HashSet::new())),
|
||||||
|
Some(slow_disconnect_token.clone()),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
let queued_message = OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification {
|
||||||
|
summary: "already-buffered".to_string(),
|
||||||
|
details: None,
|
||||||
|
path: None,
|
||||||
|
range: None,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
slow_writer_tx
|
||||||
|
.try_send(QueuedOutgoingMessage::new(queued_message))
|
||||||
|
.expect("channel should have room");
|
||||||
|
|
||||||
|
let broadcast_message = OutgoingMessage::AppServerNotification(
|
||||||
|
ServerNotification::ConfigWarning(ConfigWarningNotification {
|
||||||
|
summary: "test".to_string(),
|
||||||
|
details: None,
|
||||||
|
path: None,
|
||||||
|
range: None,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
timeout(
|
||||||
|
Duration::from_millis(100),
|
||||||
|
route_outgoing_envelope(
|
||||||
|
&mut connections,
|
||||||
|
OutgoingEnvelope::Broadcast {
|
||||||
|
message: broadcast_message,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("broadcast should return even when one connection is slow");
|
||||||
|
assert!(!connections.contains_key(&slow_connection_id));
|
||||||
|
assert!(slow_disconnect_token.is_cancelled());
|
||||||
|
assert!(!fast_disconnect_token.is_cancelled());
|
||||||
|
let fast_message = fast_writer_rx
|
||||||
|
.try_recv()
|
||||||
|
.expect("fast connection should receive the broadcast notification");
|
||||||
|
assert!(matches!(
|
||||||
|
fast_message.message,
|
||||||
|
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification { summary, .. }
|
||||||
|
)) if summary == "test"
|
||||||
|
));
|
||||||
|
|
||||||
|
let slow_message = slow_writer_rx
|
||||||
|
.try_recv()
|
||||||
|
.expect("slow connection should retain its original buffered message");
|
||||||
|
assert!(matches!(
|
||||||
|
slow_message.message,
|
||||||
|
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification { summary, .. }
|
||||||
|
)) if summary == "already-buffered"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn to_connection_stdio_waits_instead_of_disconnecting_when_writer_queue_is_full() {
|
||||||
|
let connection_id = ConnectionId(3);
|
||||||
|
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||||
|
writer_tx
|
||||||
|
.send(QueuedOutgoingMessage::new(
|
||||||
|
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification {
|
||||||
|
summary: "queued".to_string(),
|
||||||
|
details: None,
|
||||||
|
path: None,
|
||||||
|
range: None,
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.expect("channel should accept the first queued message");
|
||||||
|
|
||||||
|
let mut connections = HashMap::new();
|
||||||
|
connections.insert(
|
||||||
|
connection_id,
|
||||||
|
OutboundConnectionState::new(
|
||||||
|
writer_tx,
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(AtomicBool::new(true)),
|
||||||
|
Arc::new(RwLock::new(HashSet::new())),
|
||||||
|
/*disconnect_sender*/ None,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
let route_task = tokio::spawn(async move {
|
||||||
|
route_outgoing_envelope(
|
||||||
|
&mut connections,
|
||||||
|
OutgoingEnvelope::ToConnection {
|
||||||
|
connection_id,
|
||||||
|
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification {
|
||||||
|
summary: "second".to_string(),
|
||||||
|
details: None,
|
||||||
|
path: None,
|
||||||
|
range: None,
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
write_complete_tx: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
|
||||||
|
let first = timeout(Duration::from_millis(100), writer_rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("first queued message should be readable")
|
||||||
|
.expect("first queued message should exist");
|
||||||
|
timeout(Duration::from_millis(100), route_task)
|
||||||
|
.await
|
||||||
|
.expect("routing should finish after the first queued message is drained")
|
||||||
|
.expect("routing task should succeed");
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
first.message,
|
||||||
|
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification { summary, .. }
|
||||||
|
)) if summary == "queued"
|
||||||
|
));
|
||||||
|
let second = writer_rx
|
||||||
|
.try_recv()
|
||||||
|
.expect("second notification should be delivered once the queue has room");
|
||||||
|
assert!(matches!(
|
||||||
|
second.message,
|
||||||
|
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||||
|
ConfigWarningNotification { summary, .. }
|
||||||
|
)) if summary == "second"
|
||||||
|
));
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user