mirror of
https://github.com/openai/codex.git
synced 2026-05-15 16:53:05 +00:00
Add stdio exec-server client transport
Allow exec-server clients to connect through a shell command over stdio. The connection can now retain a drop resource so the spawned child is terminated when the JSON-RPC client is dropped. Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -17,13 +17,14 @@ use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::ProcessId;
|
||||
use crate::client_api::ExecServerClientConnectOptions;
|
||||
use crate::client_api::ExecServerTransport;
|
||||
use crate::client_api::HttpClient;
|
||||
use crate::client_api::RemoteExecServerConnectArgs;
|
||||
use crate::client_api::StdioExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::process::ExecProcessEvent;
|
||||
use crate::process::ExecProcessEventLog;
|
||||
@@ -105,6 +106,16 @@ impl From<RemoteExecServerConnectArgs> for ExecServerClientConnectOptions {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StdioExecServerConnectArgs> for ExecServerClientConnectOptions {
|
||||
fn from(value: StdioExecServerConnectArgs) -> Self {
|
||||
Self {
|
||||
client_name: value.client_name,
|
||||
initialize_timeout: value.initialize_timeout,
|
||||
resume_session_id: value.resume_session_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteExecServerConnectArgs {
|
||||
pub fn new(websocket_url: String, client_name: String) -> Self {
|
||||
Self {
|
||||
@@ -180,29 +191,23 @@ pub struct ExecServerClient {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct LazyRemoteExecServerClient {
|
||||
websocket_url: String,
|
||||
transport: ExecServerTransport,
|
||||
client: Arc<OnceCell<ExecServerClient>>,
|
||||
}
|
||||
|
||||
impl LazyRemoteExecServerClient {
|
||||
pub(crate) fn new(websocket_url: String) -> Self {
|
||||
pub(crate) fn new(transport: ExecServerTransport) -> Self {
|
||||
Self {
|
||||
websocket_url,
|
||||
transport,
|
||||
client: Arc::new(OnceCell::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn get(&self) -> Result<ExecServerClient, ExecServerError> {
|
||||
self.client
|
||||
.get_or_try_init(|| async {
|
||||
ExecServerClient::connect_websocket(RemoteExecServerConnectArgs {
|
||||
websocket_url: self.websocket_url.clone(),
|
||||
client_name: "codex-environment".to_string(),
|
||||
connect_timeout: Duration::from_secs(5),
|
||||
initialize_timeout: Duration::from_secs(5),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
.get_or_try_init(|| {
|
||||
let transport = self.transport.clone();
|
||||
async move { transport.connect_for_environment().await }
|
||||
})
|
||||
.await
|
||||
.cloned()
|
||||
@@ -269,32 +274,6 @@ pub enum ExecServerError {
|
||||
}
|
||||
|
||||
impl ExecServerClient {
|
||||
pub async fn connect_websocket(
|
||||
args: RemoteExecServerConnectArgs,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let websocket_url = args.websocket_url.clone();
|
||||
let connect_timeout = args.connect_timeout;
|
||||
let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))
|
||||
.await
|
||||
.map_err(|_| ExecServerError::WebSocketConnectTimeout {
|
||||
url: websocket_url.clone(),
|
||||
timeout: connect_timeout,
|
||||
})?
|
||||
.map_err(|source| ExecServerError::WebSocketConnect {
|
||||
url: websocket_url.clone(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
Self::connect(
|
||||
JsonRpcConnection::from_websocket(
|
||||
stream,
|
||||
format!("exec-server websocket {websocket_url}"),
|
||||
),
|
||||
args.into(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn initialize(
|
||||
&self,
|
||||
options: ExecServerClientConnectOptions,
|
||||
@@ -443,7 +422,7 @@ impl ExecServerClient {
|
||||
.clone()
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
pub(crate) async fn connect(
|
||||
connection: JsonRpcConnection,
|
||||
options: ExecServerClientConnectOptions,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
@@ -905,6 +884,7 @@ mod tests {
|
||||
use super::ExecServerClient;
|
||||
use super::ExecServerClientConnectOptions;
|
||||
use crate::ProcessId;
|
||||
use crate::client_api::StdioExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::process::ExecProcessEvent;
|
||||
use crate::protocol::EXEC_CLOSED_METHOD;
|
||||
@@ -942,6 +922,21 @@ mod tests {
|
||||
.expect("json-rpc line should write");
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn connect_stdio_command_initializes_json_rpc_client() {
|
||||
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
shell_command: "read _line; printf '%s\\n' '{\"id\":1,\"result\":{\"sessionId\":\"stdio-test\"}}'; read _line; sleep 60".to_string(),
|
||||
client_name: "stdio-test-client".to_string(),
|
||||
initialize_timeout: Duration::from_secs(1),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
.expect("stdio client should connect");
|
||||
|
||||
assert_eq!(client.session_id().as_deref(), Some("stdio-test"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_events_are_delivered_in_seq_order_when_notifications_are_reordered() {
|
||||
let (client_stdin, server_reader) = duplex(1 << 20);
|
||||
|
||||
@@ -25,6 +25,22 @@ pub struct RemoteExecServerConnectArgs {
|
||||
pub resume_session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Stdio connection arguments for a command-backed exec-server.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct StdioExecServerConnectArgs {
|
||||
pub shell_command: String,
|
||||
pub client_name: String,
|
||||
pub initialize_timeout: Duration,
|
||||
pub resume_session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Transport used to connect to a remote exec-server environment.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ExecServerTransport {
|
||||
WebSocketUrl(String),
|
||||
StdioShellCommand(String),
|
||||
}
|
||||
|
||||
/// Sends HTTP requests through a runtime-selected transport.
|
||||
///
|
||||
/// This is the HTTP capability counterpart to [`crate::ExecBackend`]. Callers
|
||||
|
||||
176
codex-rs/exec-server/src/client_transport.rs
Normal file
176
codex-rs/exec-server/src/client_transport.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::Command;
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tracing::debug;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::ExecServerClient;
|
||||
use crate::ExecServerError;
|
||||
use crate::client_api::ExecServerTransport;
|
||||
use crate::client_api::RemoteExecServerConnectArgs;
|
||||
use crate::client_api::StdioExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
|
||||
const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
|
||||
const ENVIRONMENT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const ENVIRONMENT_INITIALIZE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
impl ExecServerTransport {
|
||||
pub(crate) async fn connect_for_environment(self) -> Result<ExecServerClient, ExecServerError> {
|
||||
match self {
|
||||
ExecServerTransport::WebSocketUrl(websocket_url) => {
|
||||
ExecServerClient::connect_websocket(RemoteExecServerConnectArgs {
|
||||
websocket_url,
|
||||
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
|
||||
connect_timeout: ENVIRONMENT_CONNECT_TIMEOUT,
|
||||
initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
}
|
||||
ExecServerTransport::StdioShellCommand(shell_command) => {
|
||||
ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
shell_command,
|
||||
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
|
||||
initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecServerClient {
|
||||
pub async fn connect_websocket(
|
||||
args: RemoteExecServerConnectArgs,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let websocket_url = args.websocket_url.clone();
|
||||
let connect_timeout = args.connect_timeout;
|
||||
let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))
|
||||
.await
|
||||
.map_err(|_| ExecServerError::WebSocketConnectTimeout {
|
||||
url: websocket_url.clone(),
|
||||
timeout: connect_timeout,
|
||||
})?
|
||||
.map_err(|source| ExecServerError::WebSocketConnect {
|
||||
url: websocket_url.clone(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
Self::connect(
|
||||
JsonRpcConnection::from_websocket(
|
||||
stream,
|
||||
format!("exec-server websocket {websocket_url}"),
|
||||
),
|
||||
args.into(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn connect_stdio_command(
|
||||
args: StdioExecServerConnectArgs,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let shell_command = args.shell_command.clone();
|
||||
let mut child = shell_command_process(&shell_command)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(ExecServerError::Spawn)?;
|
||||
|
||||
let stdin = child.stdin.take().ok_or_else(|| {
|
||||
ExecServerError::Protocol("spawned exec-server command has no stdin".to_string())
|
||||
})?;
|
||||
let stdout = child.stdout.take().ok_or_else(|| {
|
||||
ExecServerError::Protocol("spawned exec-server command has no stdout".to_string())
|
||||
})?;
|
||||
if let Some(stderr) = child.stderr.take() {
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(stderr).lines();
|
||||
loop {
|
||||
match lines.next_line().await {
|
||||
Ok(Some(line)) => debug!("exec-server stdio stderr: {line}"),
|
||||
Ok(None) => break,
|
||||
Err(err) => {
|
||||
warn!("failed to read exec-server stdio stderr: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Self::connect(
|
||||
JsonRpcConnection::from_stdio(
|
||||
stdout,
|
||||
stdin,
|
||||
format!("exec-server stdio command `{shell_command}`"),
|
||||
)
|
||||
.with_lifetime_guard(Box::new(StdioChildGuard { child: Some(child) })),
|
||||
args.into(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
struct StdioChildGuard {
|
||||
child: Option<Child>,
|
||||
}
|
||||
|
||||
impl Drop for StdioChildGuard {
|
||||
fn drop(&mut self) {
|
||||
let Some(child) = self.child.take() else {
|
||||
return;
|
||||
};
|
||||
|
||||
match Handle::try_current() {
|
||||
Ok(handle) => {
|
||||
let _terminate_task = handle.spawn(terminate_stdio_child(child));
|
||||
}
|
||||
Err(_) => {
|
||||
terminate_stdio_child_now(child);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn terminate_stdio_child(mut child: Child) {
|
||||
kill_stdio_child(&mut child);
|
||||
if let Err(err) = child.wait().await {
|
||||
debug!("failed to wait for exec-server stdio child: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
fn terminate_stdio_child_now(mut child: Child) {
|
||||
kill_stdio_child(&mut child);
|
||||
}
|
||||
|
||||
fn kill_stdio_child(child: &mut Child) {
|
||||
if let Err(err) = child.start_kill() {
|
||||
debug!("failed to terminate exec-server stdio child: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
fn shell_command_process(shell_command: &str) -> Command {
|
||||
#[cfg(windows)]
|
||||
{
|
||||
let mut command = Command::new("cmd");
|
||||
command.arg("/C").arg(shell_command);
|
||||
command
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
{
|
||||
let mut command = Command::new("sh");
|
||||
command.arg("-lc").arg(shell_command);
|
||||
command
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,8 @@ use tokio::io::BufWriter;
|
||||
|
||||
pub(crate) const CHANNEL_CAPACITY: usize = 128;
|
||||
|
||||
pub(crate) type JsonRpcConnectionLifetimeGuard = Box<dyn Send>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum JsonRpcConnectionEvent {
|
||||
Message(JSONRPCMessage),
|
||||
@@ -27,6 +29,7 @@ pub(crate) struct JsonRpcConnection {
|
||||
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
task_handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
lifetime_guard: Option<JsonRpcConnectionLifetimeGuard>,
|
||||
}
|
||||
|
||||
impl JsonRpcConnection {
|
||||
@@ -117,6 +120,7 @@ impl JsonRpcConnection {
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
lifetime_guard: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,9 +255,15 @@ impl JsonRpcConnection {
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
lifetime_guard: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_lifetime_guard(mut self, guard: JsonRpcConnectionLifetimeGuard) -> Self {
|
||||
self.lifetime_guard = Some(guard);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn into_parts(
|
||||
self,
|
||||
) -> (
|
||||
@@ -261,12 +271,14 @@ impl JsonRpcConnection {
|
||||
mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
watch::Receiver<bool>,
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
Option<JsonRpcConnectionLifetimeGuard>,
|
||||
) {
|
||||
(
|
||||
self.outgoing_tx,
|
||||
self.incoming_rx,
|
||||
self.disconnected_rx,
|
||||
self.task_handles,
|
||||
self.lifetime_guard,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::ExecutorFileSystem;
|
||||
use crate::HttpClient;
|
||||
use crate::client::LazyRemoteExecServerClient;
|
||||
use crate::client::http_client::ReqwestHttpClient;
|
||||
use crate::client_api::ExecServerTransport;
|
||||
use crate::environment_provider::DefaultEnvironmentProvider;
|
||||
use crate::environment_provider::EnvironmentProvider;
|
||||
use crate::environment_provider::normalize_exec_server_url;
|
||||
@@ -274,7 +275,9 @@ impl Environment {
|
||||
exec_server_url: String,
|
||||
local_runtime_paths: Option<ExecServerRuntimePaths>,
|
||||
) -> Self {
|
||||
let client = LazyRemoteExecServerClient::new(exec_server_url.clone());
|
||||
let client = LazyRemoteExecServerClient::new(ExecServerTransport::WebSocketUrl(
|
||||
exec_server_url.clone(),
|
||||
));
|
||||
let exec_backend: Arc<dyn ExecBackend> = Arc::new(RemoteProcess::new(client.clone()));
|
||||
let filesystem: Arc<dyn ExecutorFileSystem> =
|
||||
Arc::new(RemoteFileSystem::new(client.clone()));
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod client;
|
||||
mod client_api;
|
||||
mod client_transport;
|
||||
mod connection;
|
||||
mod environment;
|
||||
mod environment_provider;
|
||||
@@ -24,8 +25,10 @@ pub use client::ExecServerError;
|
||||
pub use client::http_client::HttpResponseBodyStream;
|
||||
pub use client::http_client::ReqwestHttpClient;
|
||||
pub use client_api::ExecServerClientConnectOptions;
|
||||
pub use client_api::ExecServerTransport;
|
||||
pub use client_api::HttpClient;
|
||||
pub use client_api::RemoteExecServerConnectArgs;
|
||||
pub use client_api::StdioExecServerConnectArgs;
|
||||
pub use codex_file_system::CopyOptions;
|
||||
pub use codex_file_system::CreateDirectoryOptions;
|
||||
pub use codex_file_system::ExecutorFileSystem;
|
||||
|
||||
@@ -23,6 +23,7 @@ use tokio::task::JoinHandle;
|
||||
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::connection::JsonRpcConnectionEvent;
|
||||
use crate::connection::JsonRpcConnectionLifetimeGuard;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum RpcCallError {
|
||||
@@ -229,12 +230,14 @@ pub(crate) struct RpcClient {
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
next_request_id: AtomicI64,
|
||||
transport_tasks: Vec<JoinHandle<()>>,
|
||||
_transport_lifetime_guard: Option<JsonRpcConnectionLifetimeGuard>,
|
||||
reader_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl RpcClient {
|
||||
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
|
||||
let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) = connection.into_parts();
|
||||
let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks, lifetime_guard) =
|
||||
connection.into_parts();
|
||||
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
|
||||
let (event_tx, event_rx) = mpsc::channel(128);
|
||||
|
||||
@@ -275,6 +278,7 @@ impl RpcClient {
|
||||
disconnected_rx,
|
||||
next_request_id: AtomicI64::new(1),
|
||||
transport_tasks,
|
||||
_transport_lifetime_guard: lifetime_guard,
|
||||
reader_task,
|
||||
},
|
||||
event_rx,
|
||||
|
||||
@@ -47,7 +47,7 @@ async fn run_connection(
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) {
|
||||
let router = Arc::new(build_router());
|
||||
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) =
|
||||
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks, _lifetime_guard) =
|
||||
connection.into_parts();
|
||||
let (outgoing_tx, mut outgoing_rx) =
|
||||
mpsc::channel::<RpcServerOutboundMessage>(CHANNEL_CAPACITY);
|
||||
|
||||
Reference in New Issue
Block a user