mirror of
https://github.com/openai/codex.git
synced 2026-05-03 02:46:39 +00:00
Compare commits
5 Commits
etraut/sid
...
starr/exec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
751ed42d78 | ||
|
|
312f454377 | ||
|
|
4f0b7d98c2 | ||
|
|
c429fcf77f | ||
|
|
e2fda326ea |
@@ -446,7 +446,7 @@ struct AppServerCommand {
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
struct ExecServerCommand {
|
||||
/// Transport endpoint URL. Supported values: `ws://IP:PORT` (default).
|
||||
/// Transport endpoint URL. Supported values: `ws://IP:PORT` (default), `stdio`, `stdio://`.
|
||||
#[arg(
|
||||
long = "listen",
|
||||
value_name = "URL",
|
||||
|
||||
@@ -17,13 +17,14 @@ use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::ProcessId;
|
||||
use crate::client_api::ExecServerClientConnectOptions;
|
||||
use crate::client_api::ExecServerTransport;
|
||||
use crate::client_api::HttpClient;
|
||||
use crate::client_api::RemoteExecServerConnectArgs;
|
||||
use crate::client_api::StdioExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::process::ExecProcessEvent;
|
||||
use crate::process::ExecProcessEventLog;
|
||||
@@ -105,6 +106,16 @@ impl From<RemoteExecServerConnectArgs> for ExecServerClientConnectOptions {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StdioExecServerConnectArgs> for ExecServerClientConnectOptions {
|
||||
fn from(value: StdioExecServerConnectArgs) -> Self {
|
||||
Self {
|
||||
client_name: value.client_name,
|
||||
initialize_timeout: value.initialize_timeout,
|
||||
resume_session_id: value.resume_session_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteExecServerConnectArgs {
|
||||
pub fn new(websocket_url: String, client_name: String) -> Self {
|
||||
Self {
|
||||
@@ -180,29 +191,23 @@ pub struct ExecServerClient {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct LazyRemoteExecServerClient {
|
||||
websocket_url: String,
|
||||
transport: ExecServerTransport,
|
||||
client: Arc<OnceCell<ExecServerClient>>,
|
||||
}
|
||||
|
||||
impl LazyRemoteExecServerClient {
|
||||
pub(crate) fn new(websocket_url: String) -> Self {
|
||||
pub(crate) fn new(transport: ExecServerTransport) -> Self {
|
||||
Self {
|
||||
websocket_url,
|
||||
transport,
|
||||
client: Arc::new(OnceCell::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn get(&self) -> Result<ExecServerClient, ExecServerError> {
|
||||
self.client
|
||||
.get_or_try_init(|| async {
|
||||
ExecServerClient::connect_websocket(RemoteExecServerConnectArgs {
|
||||
websocket_url: self.websocket_url.clone(),
|
||||
client_name: "codex-environment".to_string(),
|
||||
connect_timeout: Duration::from_secs(5),
|
||||
initialize_timeout: Duration::from_secs(5),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
.get_or_try_init(|| {
|
||||
let transport = self.transport.clone();
|
||||
async move { transport.connect_for_environment().await }
|
||||
})
|
||||
.await
|
||||
.cloned()
|
||||
@@ -257,32 +262,6 @@ pub enum ExecServerError {
|
||||
}
|
||||
|
||||
impl ExecServerClient {
|
||||
pub async fn connect_websocket(
|
||||
args: RemoteExecServerConnectArgs,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let websocket_url = args.websocket_url.clone();
|
||||
let connect_timeout = args.connect_timeout;
|
||||
let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))
|
||||
.await
|
||||
.map_err(|_| ExecServerError::WebSocketConnectTimeout {
|
||||
url: websocket_url.clone(),
|
||||
timeout: connect_timeout,
|
||||
})?
|
||||
.map_err(|source| ExecServerError::WebSocketConnect {
|
||||
url: websocket_url.clone(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
Self::connect(
|
||||
JsonRpcConnection::from_websocket(
|
||||
stream,
|
||||
format!("exec-server websocket {websocket_url}"),
|
||||
),
|
||||
args.into(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn initialize(
|
||||
&self,
|
||||
options: ExecServerClientConnectOptions,
|
||||
@@ -431,7 +410,7 @@ impl ExecServerClient {
|
||||
.clone()
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
pub(crate) async fn connect(
|
||||
connection: JsonRpcConnection,
|
||||
options: ExecServerClientConnectOptions,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
@@ -893,6 +872,8 @@ mod tests {
|
||||
use super::ExecServerClient;
|
||||
use super::ExecServerClientConnectOptions;
|
||||
use crate::ProcessId;
|
||||
#[cfg(not(windows))]
|
||||
use crate::StdioExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::process::ExecProcessEvent;
|
||||
use crate::protocol::EXEC_CLOSED_METHOD;
|
||||
@@ -930,6 +911,21 @@ mod tests {
|
||||
.expect("json-rpc line should write");
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn connect_stdio_command_initializes_json_rpc_client() {
|
||||
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
shell_command: "read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(),
|
||||
client_name: "stdio-test-client".to_string(),
|
||||
initialize_timeout: Duration::from_secs(1),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
.expect("stdio client should connect");
|
||||
|
||||
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_events_are_delivered_in_seq_order_when_notifications_are_reordered() {
|
||||
let (client_stdin, server_reader) = duplex(1 << 20);
|
||||
|
||||
@@ -25,6 +25,22 @@ pub struct RemoteExecServerConnectArgs {
|
||||
pub resume_session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Stdio connection arguments for a command-backed exec-server.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct StdioExecServerConnectArgs {
|
||||
pub shell_command: String,
|
||||
pub client_name: String,
|
||||
pub initialize_timeout: Duration,
|
||||
pub resume_session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Transport used to connect to a remote exec-server environment.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ExecServerTransport {
|
||||
WebSocketUrl(String),
|
||||
StdioShellCommand(String),
|
||||
}
|
||||
|
||||
/// Sends HTTP requests through a runtime-selected transport.
|
||||
///
|
||||
/// This is the HTTP capability counterpart to [`crate::ExecBackend`]. Callers
|
||||
|
||||
176
codex-rs/exec-server/src/client_transport.rs
Normal file
176
codex-rs/exec-server/src/client_transport.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::Command;
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tracing::debug;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::ExecServerClient;
|
||||
use crate::ExecServerError;
|
||||
use crate::client_api::ExecServerTransport;
|
||||
use crate::client_api::RemoteExecServerConnectArgs;
|
||||
use crate::client_api::StdioExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
|
||||
const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
|
||||
const ENVIRONMENT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const ENVIRONMENT_INITIALIZE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
impl ExecServerTransport {
|
||||
pub(crate) async fn connect_for_environment(self) -> Result<ExecServerClient, ExecServerError> {
|
||||
match self {
|
||||
ExecServerTransport::WebSocketUrl(websocket_url) => {
|
||||
ExecServerClient::connect_websocket(RemoteExecServerConnectArgs {
|
||||
websocket_url,
|
||||
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
|
||||
connect_timeout: ENVIRONMENT_CONNECT_TIMEOUT,
|
||||
initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
}
|
||||
ExecServerTransport::StdioShellCommand(shell_command) => {
|
||||
ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
shell_command,
|
||||
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
|
||||
initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecServerClient {
|
||||
pub async fn connect_websocket(
|
||||
args: RemoteExecServerConnectArgs,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let websocket_url = args.websocket_url.clone();
|
||||
let connect_timeout = args.connect_timeout;
|
||||
let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))
|
||||
.await
|
||||
.map_err(|_| ExecServerError::WebSocketConnectTimeout {
|
||||
url: websocket_url.clone(),
|
||||
timeout: connect_timeout,
|
||||
})?
|
||||
.map_err(|source| ExecServerError::WebSocketConnect {
|
||||
url: websocket_url.clone(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
Self::connect(
|
||||
JsonRpcConnection::from_websocket(
|
||||
stream,
|
||||
format!("exec-server websocket {websocket_url}"),
|
||||
),
|
||||
args.into(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn connect_stdio_command(
|
||||
args: StdioExecServerConnectArgs,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let shell_command = args.shell_command.clone();
|
||||
let mut child = shell_command_process(&shell_command)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(ExecServerError::Spawn)?;
|
||||
|
||||
let stdin = child.stdin.take().ok_or_else(|| {
|
||||
ExecServerError::Protocol("spawned exec-server command has no stdin".to_string())
|
||||
})?;
|
||||
let stdout = child.stdout.take().ok_or_else(|| {
|
||||
ExecServerError::Protocol("spawned exec-server command has no stdout".to_string())
|
||||
})?;
|
||||
if let Some(stderr) = child.stderr.take() {
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(stderr).lines();
|
||||
loop {
|
||||
match lines.next_line().await {
|
||||
Ok(Some(line)) => debug!("exec-server stdio stderr: {line}"),
|
||||
Ok(None) => break,
|
||||
Err(err) => {
|
||||
warn!("failed to read exec-server stdio stderr: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Self::connect(
|
||||
JsonRpcConnection::from_stdio(
|
||||
stdout,
|
||||
stdin,
|
||||
format!("exec-server stdio command `{shell_command}`"),
|
||||
)
|
||||
.with_lifetime_guard(Box::new(StdioChildGuard { child: Some(child) })),
|
||||
args.into(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
struct StdioChildGuard {
|
||||
child: Option<Child>,
|
||||
}
|
||||
|
||||
impl Drop for StdioChildGuard {
|
||||
fn drop(&mut self) {
|
||||
let Some(child) = self.child.take() else {
|
||||
return;
|
||||
};
|
||||
|
||||
match Handle::try_current() {
|
||||
Ok(handle) => {
|
||||
let _terminate_task = handle.spawn(terminate_stdio_child(child));
|
||||
}
|
||||
Err(_) => {
|
||||
terminate_stdio_child_now(child);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn terminate_stdio_child(mut child: Child) {
|
||||
kill_stdio_child(&mut child);
|
||||
if let Err(err) = child.wait().await {
|
||||
debug!("failed to wait for exec-server stdio child: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
fn terminate_stdio_child_now(mut child: Child) {
|
||||
kill_stdio_child(&mut child);
|
||||
}
|
||||
|
||||
fn kill_stdio_child(child: &mut Child) {
|
||||
if let Err(err) = child.start_kill() {
|
||||
debug!("failed to terminate exec-server stdio child: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
fn shell_command_process(shell_command: &str) -> Command {
|
||||
#[cfg(windows)]
|
||||
{
|
||||
let mut command = Command::new("cmd");
|
||||
command.arg("/C").arg(shell_command);
|
||||
command
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
{
|
||||
let mut command = Command::new("sh");
|
||||
command.arg("-lc").arg(shell_command);
|
||||
command
|
||||
}
|
||||
}
|
||||
@@ -8,17 +8,22 @@ use tokio::sync::watch;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
#[cfg(test)]
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
#[cfg(test)]
|
||||
use tokio::io::AsyncWriteExt;
|
||||
#[cfg(test)]
|
||||
use tokio::io::BufReader;
|
||||
#[cfg(test)]
|
||||
use tokio::io::BufWriter;
|
||||
|
||||
pub(crate) const CHANNEL_CAPACITY: usize = 128;
|
||||
|
||||
pub(crate) type JsonRpcConnectionLifetimeGuard = Box<dyn Send>;
|
||||
pub(crate) type JsonRpcConnectionParts = (
|
||||
mpsc::Sender<JSONRPCMessage>,
|
||||
mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
watch::Receiver<bool>,
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
Option<JsonRpcConnectionLifetimeGuard>,
|
||||
);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum JsonRpcConnectionEvent {
|
||||
Message(JSONRPCMessage),
|
||||
@@ -31,10 +36,10 @@ pub(crate) struct JsonRpcConnection {
|
||||
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
task_handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
lifetime_guard: Option<JsonRpcConnectionLifetimeGuard>,
|
||||
}
|
||||
|
||||
impl JsonRpcConnection {
|
||||
#[cfg(test)]
|
||||
pub(crate) fn from_stdio<R, W>(reader: R, writer: W, connection_label: String) -> Self
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
@@ -122,6 +127,7 @@ impl JsonRpcConnection {
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
lifetime_guard: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,22 +262,22 @@ impl JsonRpcConnection {
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
lifetime_guard: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn into_parts(
|
||||
self,
|
||||
) -> (
|
||||
mpsc::Sender<JSONRPCMessage>,
|
||||
mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
watch::Receiver<bool>,
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
) {
|
||||
pub(crate) fn with_lifetime_guard(mut self, guard: JsonRpcConnectionLifetimeGuard) -> Self {
|
||||
self.lifetime_guard = Some(guard);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn into_parts(self) -> JsonRpcConnectionParts {
|
||||
(
|
||||
self.outgoing_tx,
|
||||
self.incoming_rx,
|
||||
self.disconnected_rx,
|
||||
self.task_handles,
|
||||
self.lifetime_guard,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -298,7 +304,6 @@ async fn send_malformed_message(
|
||||
.await;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn write_jsonrpc_line_message<W>(
|
||||
writer: &mut BufWriter<W>,
|
||||
message: &JSONRPCMessage,
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::ExecutorFileSystem;
|
||||
use crate::HttpClient;
|
||||
use crate::client::LazyRemoteExecServerClient;
|
||||
use crate::client::http_client::ReqwestHttpClient;
|
||||
use crate::client_api::ExecServerTransport;
|
||||
use crate::environment_provider::DefaultEnvironmentProvider;
|
||||
use crate::environment_provider::EnvironmentProvider;
|
||||
use crate::environment_provider::normalize_exec_server_url;
|
||||
@@ -274,7 +275,9 @@ impl Environment {
|
||||
exec_server_url: String,
|
||||
local_runtime_paths: Option<ExecServerRuntimePaths>,
|
||||
) -> Self {
|
||||
let client = LazyRemoteExecServerClient::new(exec_server_url.clone());
|
||||
let client = LazyRemoteExecServerClient::new(ExecServerTransport::WebSocketUrl(
|
||||
exec_server_url.clone(),
|
||||
));
|
||||
let exec_backend: Arc<dyn ExecBackend> = Arc::new(RemoteProcess::new(client.clone()));
|
||||
let filesystem: Arc<dyn ExecutorFileSystem> =
|
||||
Arc::new(RemoteFileSystem::new(client.clone()));
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod client;
|
||||
mod client_api;
|
||||
mod client_transport;
|
||||
mod connection;
|
||||
mod environment;
|
||||
mod environment_provider;
|
||||
@@ -23,8 +24,10 @@ pub use client::ExecServerError;
|
||||
pub use client::http_client::HttpResponseBodyStream;
|
||||
pub use client::http_client::ReqwestHttpClient;
|
||||
pub use client_api::ExecServerClientConnectOptions;
|
||||
pub use client_api::ExecServerTransport;
|
||||
pub use client_api::HttpClient;
|
||||
pub use client_api::RemoteExecServerConnectArgs;
|
||||
pub use client_api::StdioExecServerConnectArgs;
|
||||
pub use codex_file_system::CopyOptions;
|
||||
pub use codex_file_system::CreateDirectoryOptions;
|
||||
pub use codex_file_system::ExecutorFileSystem;
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
@@ -23,6 +24,7 @@ use tokio::task::JoinHandle;
|
||||
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::connection::JsonRpcConnectionEvent;
|
||||
use crate::connection::JsonRpcConnectionLifetimeGuard;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum RpcCallError {
|
||||
@@ -229,12 +231,14 @@ pub(crate) struct RpcClient {
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
next_request_id: AtomicI64,
|
||||
transport_tasks: Vec<JoinHandle<()>>,
|
||||
_transport_lifetime_guard: Option<StdMutex<JsonRpcConnectionLifetimeGuard>>,
|
||||
reader_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl RpcClient {
|
||||
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
|
||||
let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) = connection.into_parts();
|
||||
let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks, lifetime_guard) =
|
||||
connection.into_parts();
|
||||
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
|
||||
let (event_tx, event_rx) = mpsc::channel(128);
|
||||
|
||||
@@ -275,6 +279,7 @@ impl RpcClient {
|
||||
disconnected_rx,
|
||||
next_request_id: AtomicI64::new(1),
|
||||
transport_tasks,
|
||||
_transport_lifetime_guard: lifetime_guard.map(StdMutex::new),
|
||||
reader_task,
|
||||
},
|
||||
event_rx,
|
||||
|
||||
@@ -47,7 +47,7 @@ async fn run_connection(
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) {
|
||||
let router = Arc::new(build_router());
|
||||
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) =
|
||||
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks, _lifetime_guard) =
|
||||
connection.into_parts();
|
||||
let (outgoing_tx, mut outgoing_rx) =
|
||||
mpsc::channel::<RpcServerOutboundMessage>(CHANNEL_CAPACITY);
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use std::io::Write as _;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::io;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_tungstenite::accept_async;
|
||||
use tracing::warn;
|
||||
@@ -10,6 +13,12 @@ use crate::server::processor::ConnectionProcessor;
|
||||
|
||||
pub const DEFAULT_LISTEN_URL: &str = "ws://127.0.0.1:0";
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub(crate) enum ExecServerListenTransport {
|
||||
WebSocket(SocketAddr),
|
||||
Stdio,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub enum ExecServerListenUrlParseError {
|
||||
UnsupportedListenUrl(String),
|
||||
@@ -21,7 +30,7 @@ impl std::fmt::Display for ExecServerListenUrlParseError {
|
||||
match self {
|
||||
ExecServerListenUrlParseError::UnsupportedListenUrl(listen_url) => write!(
|
||||
f,
|
||||
"unsupported --listen URL `{listen_url}`; expected `ws://IP:PORT`"
|
||||
"unsupported --listen URL `{listen_url}`; expected `ws://IP:PORT` or `stdio`"
|
||||
),
|
||||
ExecServerListenUrlParseError::InvalidWebSocketListenUrl(listen_url) => write!(
|
||||
f,
|
||||
@@ -35,11 +44,18 @@ impl std::error::Error for ExecServerListenUrlParseError {}
|
||||
|
||||
pub(crate) fn parse_listen_url(
|
||||
listen_url: &str,
|
||||
) -> Result<SocketAddr, ExecServerListenUrlParseError> {
|
||||
) -> Result<ExecServerListenTransport, ExecServerListenUrlParseError> {
|
||||
if matches!(listen_url, "stdio" | "stdio://") {
|
||||
return Ok(ExecServerListenTransport::Stdio);
|
||||
}
|
||||
|
||||
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
|
||||
return socket_addr.parse::<SocketAddr>().map_err(|_| {
|
||||
ExecServerListenUrlParseError::InvalidWebSocketListenUrl(listen_url.to_string())
|
||||
});
|
||||
return socket_addr
|
||||
.parse::<SocketAddr>()
|
||||
.map(ExecServerListenTransport::WebSocket)
|
||||
.map_err(|_| {
|
||||
ExecServerListenUrlParseError::InvalidWebSocketListenUrl(listen_url.to_string())
|
||||
});
|
||||
}
|
||||
|
||||
Err(ExecServerListenUrlParseError::UnsupportedListenUrl(
|
||||
@@ -51,8 +67,39 @@ pub(crate) async fn run_transport(
|
||||
listen_url: &str,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let bind_address = parse_listen_url(listen_url)?;
|
||||
run_websocket_listener(bind_address, runtime_paths).await
|
||||
match parse_listen_url(listen_url)? {
|
||||
ExecServerListenTransport::WebSocket(bind_address) => {
|
||||
run_websocket_listener(bind_address, runtime_paths).await
|
||||
}
|
||||
ExecServerListenTransport::Stdio => run_stdio_connection(runtime_paths).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_stdio_connection(
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
run_stdio_connection_with_io(io::stdin(), io::stdout(), runtime_paths).await
|
||||
}
|
||||
|
||||
async fn run_stdio_connection_with_io<R, W>(
|
||||
reader: R,
|
||||
writer: W,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let processor = ConnectionProcessor::new(runtime_paths);
|
||||
tracing::info!("codex-exec-server listening on stdio");
|
||||
processor
|
||||
.run_connection(JsonRpcConnection::from_stdio(
|
||||
reader,
|
||||
writer,
|
||||
"exec-server stdio".to_string(),
|
||||
))
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_websocket_listener(
|
||||
|
||||
@@ -1,31 +1,127 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::io::duplex;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::DEFAULT_LISTEN_URL;
|
||||
use super::ExecServerListenTransport;
|
||||
use super::parse_listen_url;
|
||||
use super::run_stdio_connection_with_io;
|
||||
use crate::ExecServerRuntimePaths;
|
||||
use crate::protocol::INITIALIZE_METHOD;
|
||||
use crate::protocol::INITIALIZED_METHOD;
|
||||
use crate::protocol::InitializeParams;
|
||||
use crate::protocol::InitializeResponse;
|
||||
|
||||
#[test]
|
||||
fn parse_listen_url_accepts_default_websocket_url() {
|
||||
let bind_address =
|
||||
parse_listen_url(DEFAULT_LISTEN_URL).expect("default listen URL should parse");
|
||||
let transport = parse_listen_url(DEFAULT_LISTEN_URL).expect("default listen URL should parse");
|
||||
assert_eq!(
|
||||
bind_address,
|
||||
"127.0.0.1:0"
|
||||
.parse::<SocketAddr>()
|
||||
.expect("valid socket address")
|
||||
transport,
|
||||
ExecServerListenTransport::WebSocket(
|
||||
"127.0.0.1:0"
|
||||
.parse::<SocketAddr>()
|
||||
.expect("valid socket address")
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_listen_url_accepts_stdio() {
|
||||
let transport = parse_listen_url("stdio").expect("stdio listen URL should parse");
|
||||
assert_eq!(transport, ExecServerListenTransport::Stdio);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_listen_url_accepts_stdio_url() {
|
||||
let transport = parse_listen_url("stdio://").expect("stdio listen URL should parse");
|
||||
assert_eq!(transport, ExecServerListenTransport::Stdio);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stdio_listen_transport_serves_initialize() {
|
||||
let transport = parse_listen_url("stdio").expect("stdio listen URL should parse");
|
||||
let ExecServerListenTransport::Stdio = transport else {
|
||||
panic!("expected stdio listen transport, got {transport:?}");
|
||||
};
|
||||
|
||||
let (mut client_writer, server_reader) = duplex(1 << 20);
|
||||
let (server_writer, client_reader) = duplex(1 << 20);
|
||||
let server_task = tokio::spawn(run_stdio_connection_with_io(
|
||||
server_reader,
|
||||
server_writer,
|
||||
test_runtime_paths(),
|
||||
));
|
||||
let mut client_lines = BufReader::new(client_reader).lines();
|
||||
|
||||
let initialize = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(1),
|
||||
method: INITIALIZE_METHOD.to_string(),
|
||||
params: Some(
|
||||
serde_json::to_value(InitializeParams {
|
||||
client_name: "exec-server-transport-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.expect("initialize params should serialize"),
|
||||
),
|
||||
trace: None,
|
||||
});
|
||||
write_jsonrpc_line(&mut client_writer, &initialize).await;
|
||||
|
||||
let response = timeout(Duration::from_secs(1), client_lines.next_line())
|
||||
.await
|
||||
.expect("initialize response should arrive")
|
||||
.expect("initialize response read should succeed")
|
||||
.expect("initialize response should be present");
|
||||
let response: JSONRPCMessage =
|
||||
serde_json::from_str(&response).expect("initialize response should parse");
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { id, result }) = response else {
|
||||
panic!("expected initialize response, got {response:?}");
|
||||
};
|
||||
assert_eq!(id, RequestId::Integer(1));
|
||||
let initialize_response: InitializeResponse =
|
||||
serde_json::from_value(result).expect("initialize response should decode");
|
||||
assert!(
|
||||
!initialize_response.session_id.is_empty(),
|
||||
"initialize should return a session id"
|
||||
);
|
||||
|
||||
let initialized = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: INITIALIZED_METHOD.to_string(),
|
||||
params: Some(serde_json::to_value(()).expect("initialized params should serialize")),
|
||||
});
|
||||
write_jsonrpc_line(&mut client_writer, &initialized).await;
|
||||
|
||||
drop(client_writer);
|
||||
drop(client_lines);
|
||||
timeout(Duration::from_secs(1), server_task)
|
||||
.await
|
||||
.expect("stdio transport should finish after client disconnect")
|
||||
.expect("stdio transport task should join")
|
||||
.expect("stdio transport should not fail");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_listen_url_accepts_websocket_url() {
|
||||
let bind_address =
|
||||
let transport =
|
||||
parse_listen_url("ws://127.0.0.1:1234").expect("websocket listen URL should parse");
|
||||
assert_eq!(
|
||||
bind_address,
|
||||
"127.0.0.1:1234"
|
||||
.parse::<SocketAddr>()
|
||||
.expect("valid socket address")
|
||||
transport,
|
||||
ExecServerListenTransport::WebSocket(
|
||||
"127.0.0.1:1234"
|
||||
.parse::<SocketAddr>()
|
||||
.expect("valid socket address")
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -45,6 +141,26 @@ fn parse_listen_url_rejects_unsupported_url() {
|
||||
parse_listen_url("http://127.0.0.1:1234").expect_err("unsupported scheme should fail");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"unsupported --listen URL `http://127.0.0.1:1234`; expected `ws://IP:PORT`"
|
||||
"unsupported --listen URL `http://127.0.0.1:1234`; expected `ws://IP:PORT` or `stdio`"
|
||||
);
|
||||
}
|
||||
|
||||
async fn write_jsonrpc_line(writer: &mut tokio::io::DuplexStream, message: &JSONRPCMessage) {
|
||||
let encoded = serde_json::to_vec(message).expect("JSON-RPC message should serialize");
|
||||
writer
|
||||
.write_all(&encoded)
|
||||
.await
|
||||
.expect("JSON-RPC message should write");
|
||||
writer
|
||||
.write_all(b"\n")
|
||||
.await
|
||||
.expect("JSON-RPC newline should write");
|
||||
}
|
||||
|
||||
fn test_runtime_paths() -> ExecServerRuntimePaths {
|
||||
ExecServerRuntimePaths::new(
|
||||
std::env::current_exe().expect("current exe"),
|
||||
/*codex_linux_sandbox_exe*/ None,
|
||||
)
|
||||
.expect("runtime paths")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user