refactor(exec-server): tighten client lifecycle and output model

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
starr-openai
2026-03-17 02:34:49 +00:00
parent dc5f035527
commit 695b6ab90e
6 changed files with 334 additions and 167 deletions

View File

@@ -4,6 +4,7 @@ use std::sync::Mutex as StdMutex;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
use std::time::Duration;
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCErrorError;
@@ -22,6 +23,7 @@ use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tokio_tungstenite::connect_async;
use tracing::debug;
use tracing::warn;
@@ -49,12 +51,14 @@ use crate::protocol::WriteResponse;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecServerClientConnectOptions {
pub client_name: String,
pub initialize_timeout: Duration,
}
impl Default for ExecServerClientConnectOptions {
fn default() -> Self {
Self {
client_name: "codex-core".to_string(),
initialize_timeout: INITIALIZE_TIMEOUT,
}
}
}
@@ -63,19 +67,42 @@ impl Default for ExecServerClientConnectOptions {
pub struct RemoteExecServerConnectArgs {
pub websocket_url: String,
pub client_name: String,
pub connect_timeout: Duration,
pub initialize_timeout: Duration,
}
impl From<RemoteExecServerConnectArgs> for ExecServerClientConnectOptions {
fn from(value: RemoteExecServerConnectArgs) -> Self {
Self {
client_name: value.client_name,
initialize_timeout: value.initialize_timeout,
}
}
}
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10);
impl RemoteExecServerConnectArgs {
pub fn new(websocket_url: String, client_name: String) -> Self {
Self {
websocket_url,
client_name,
connect_timeout: CONNECT_TIMEOUT,
initialize_timeout: INITIALIZE_TIMEOUT,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecServerOutput {
pub stream: crate::protocol::ExecOutputStream,
pub chunk: Vec<u8>,
}
pub struct ExecServerProcess {
process_id: String,
output_rx: broadcast::Receiver<Vec<u8>>,
output_rx: broadcast::Receiver<ExecServerOutput>,
writer_tx: mpsc::Sender<Vec<u8>>,
status: Arc<RemoteProcessStatus>,
client: ExecServerClient,
@@ -96,7 +123,7 @@ impl ExecServerProcess {
self.writer_tx.clone()
}
pub fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
pub fn output_receiver(&self) -> broadcast::Receiver<ExecServerOutput> {
self.output_rx.resubscribe()
}
@@ -109,7 +136,6 @@ impl ExecServerProcess {
}
pub fn terminate(&self) {
self.status.mark_exited(None);
let client = self.client.clone();
let process_id = self.process_id.clone();
tokio::spawn(async move {
@@ -157,7 +183,7 @@ impl RemoteProcessStatus {
}
struct RegisteredProcess {
output_tx: broadcast::Sender<Vec<u8>>,
output_tx: broadcast::Sender<ExecServerOutput>,
status: Arc<RemoteProcessStatus>,
}
@@ -184,12 +210,16 @@ pub struct ExecServerClient {
pub enum ExecServerError {
#[error("failed to spawn exec-server: {0}")]
Spawn(#[source] std::io::Error),
#[error("timed out connecting to exec-server websocket `{url}` after {timeout:?}")]
WebSocketConnectTimeout { url: String, timeout: Duration },
#[error("failed to connect to exec-server websocket `{url}`: {source}")]
WebSocketConnect {
url: String,
#[source]
source: tokio_tungstenite::tungstenite::Error,
},
#[error("timed out waiting for exec-server initialize handshake after {timeout:?}")]
InitializeTimedOut { timeout: Duration },
#[error("exec-server transport closed")]
Closed,
#[error("failed to serialize or deserialize exec-server JSON: {0}")]
@@ -221,8 +251,13 @@ impl ExecServerClient {
args: RemoteExecServerConnectArgs,
) -> Result<Self, ExecServerError> {
let websocket_url = args.websocket_url.clone();
let (stream, _) = connect_async(websocket_url.as_str())
let connect_timeout = args.connect_timeout;
let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))
.await
.map_err(|_| ExecServerError::WebSocketConnectTimeout {
url: websocket_url.clone(),
timeout: connect_timeout,
})?
.map_err(|source| ExecServerError::WebSocketConnect {
url: websocket_url.clone(),
source,
@@ -325,32 +360,12 @@ impl ExecServerClient {
return Err(err);
}
};
if !response.running {
status.mark_exited(response.exit_code);
}
if let Some(stdout) = response.stdout {
let _ = self
.inner
.processes
.lock()
.await
.get(&process_id)
.map(|process| process.output_tx.send(stdout.into_inner()));
}
if let Some(stderr) = response.stderr {
let _ = self
.inner
.processes
.lock()
.await
.get(&process_id)
.map(|process| process.output_tx.send(stderr.into_inner()));
}
if let Some(exit_code) = response.exit_code {
status.mark_exited(Some(exit_code));
if response.process_id != process_id {
self.inner.processes.lock().await.remove(&process_id);
return Err(ExecServerError::Protocol(format!(
"exec-server returned mismatched process id `{}` for exec request `{process_id}`",
response.process_id
)));
}
Ok(ExecServerProcess {
@@ -366,16 +381,21 @@ impl ExecServerClient {
&self,
options: ExecServerClientConnectOptions,
) -> Result<(), ExecServerError> {
let _: InitializeResponse = self
.request(
INITIALIZE_METHOD,
&InitializeParams {
client_name: options.client_name,
},
)
.await?;
self.notify(INITIALIZED_METHOD, &serde_json::json!({}))
.await
let ExecServerClientConnectOptions {
client_name,
initialize_timeout,
} = options;
timeout(initialize_timeout, async {
let _: InitializeResponse = self
.request(INITIALIZE_METHOD, &InitializeParams { client_name })
.await?;
self.notify(INITIALIZED_METHOD, &serde_json::json!({}))
.await
})
.await
.map_err(|_| ExecServerError::InitializeTimedOut {
timeout: initialize_timeout,
})?
}
async fn write_process(&self, params: WriteParams) -> Result<WriteResponse, ExecServerError> {
@@ -412,6 +432,7 @@ impl ExecServerClient {
P: Serialize,
R: DeserializeOwned,
{
let params = serde_json::to_value(params)?;
let request_id =
RequestId::Integer(self.inner.next_request_id.fetch_add(1, Ordering::SeqCst));
let (response_tx, response_rx) = oneshot::channel();
@@ -421,7 +442,6 @@ impl ExecServerClient {
.await
.insert(request_id.clone(), response_tx);
let params = serde_json::to_value(params)?;
let message = JSONRPCMessage::Request(JSONRPCRequest {
id: request_id.clone(),
method: method.to_string(),
@@ -482,10 +502,13 @@ async fn handle_server_notification(
EXEC_OUTPUT_DELTA_METHOD => {
let params: ExecOutputDeltaNotification =
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
let chunk = params.chunk.into_inner();
let output = ExecServerOutput {
stream: params.stream,
chunk: params.chunk.into_inner(),
};
let processes = inner.processes.lock().await;
if let Some(process) = processes.get(&params.process_id) {
let _ = process.output_tx.send(chunk);
let _ = process.output_tx.send(output);
}
}
EXEC_EXITED_METHOD => {
@@ -543,6 +566,9 @@ mod tests {
use super::ExecServerClientConnectOptions;
use super::ExecServerError;
use crate::protocol::EXEC_METHOD;
use crate::protocol::EXEC_OUTPUT_DELTA_METHOD;
use crate::protocol::EXEC_TERMINATE_METHOD;
use crate::protocol::ExecOutputStream;
use crate::protocol::ExecParams;
use crate::protocol::INITIALIZE_METHOD;
use crate::protocol::INITIALIZED_METHOD;
@@ -554,6 +580,13 @@ mod tests {
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::JSONRPCResponse;
fn test_options() -> ExecServerClientConnectOptions {
ExecServerClientConnectOptions {
client_name: "test-client".to_string(),
initialize_timeout: Duration::from_secs(1),
}
}
async fn read_jsonrpc_line<R>(lines: &mut tokio::io::Lines<BufReader<R>>) -> JSONRPCMessage
where
R: tokio::io::AsyncRead + Unpin,
@@ -625,14 +658,8 @@ mod tests {
assert_eq!(params, Some(serde_json::json!({})));
});
let client = ExecServerClient::connect_stdio(
client_stdin,
client_stdout,
ExecServerClientConnectOptions {
client_name: "test-client".to_string(),
},
)
.await;
let client =
ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await;
if let Err(err) = client {
panic!("failed to connect test client: {err}");
}
@@ -668,14 +695,8 @@ mod tests {
.await;
});
let result = ExecServerClient::connect_stdio(
client_stdin,
client_stdout,
ExecServerClientConnectOptions {
client_name: "test-client".to_string(),
},
)
.await;
let result =
ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await;
match result {
Err(ExecServerError::Server { code, message }) => {
@@ -736,9 +757,7 @@ mod tests {
let client = match ExecServerClient::connect_stdio(
client_stdin,
client_stdout,
ExecServerClientConnectOptions {
client_name: "test-client".to_string(),
},
test_options(),
)
.await
{
@@ -753,7 +772,6 @@ mod tests {
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
env: HashMap::new(),
tty: true,
output_bytes_cap: 4096,
arg0: None,
})
.await;
@@ -772,4 +790,215 @@ mod tests {
"failed requests should not leave registered process state behind"
);
}
#[tokio::test]
async fn connect_stdio_times_out_during_initialize_handshake() {
let (client_stdin, server_reader) = tokio::io::duplex(4096);
let (_server_writer, client_stdout) = tokio::io::duplex(4096);
tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();
let _ = read_jsonrpc_line(&mut lines).await;
tokio::time::sleep(Duration::from_millis(200)).await;
});
let result = ExecServerClient::connect_stdio(
client_stdin,
client_stdout,
ExecServerClientConnectOptions {
client_name: "test-client".to_string(),
initialize_timeout: Duration::from_millis(25),
},
)
.await;
match result {
Err(ExecServerError::InitializeTimedOut { timeout }) => {
assert_eq!(timeout, Duration::from_millis(25));
}
Err(err) => panic!("unexpected initialize timeout failure: {err}"),
Ok(_) => panic!("expected initialize timeout"),
}
}
#[tokio::test]
async fn start_process_preserves_output_stream_metadata() {
let (client_stdin, server_reader) = tokio::io::duplex(4096);
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();
let initialize = read_jsonrpc_line(&mut lines).await;
let JSONRPCMessage::Request(initialize_request) = initialize else {
panic!("expected initialize request");
};
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Response(JSONRPCResponse {
id: initialize_request.id,
result: serde_json::json!({ "protocolVersion": PROTOCOL_VERSION }),
}),
)
.await;
let initialized = read_jsonrpc_line(&mut lines).await;
let JSONRPCMessage::Notification(notification) = initialized else {
panic!("expected initialized notification");
};
assert_eq!(notification.method, INITIALIZED_METHOD);
let exec_request = read_jsonrpc_line(&mut lines).await;
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = exec_request else {
panic!("expected exec request");
};
assert_eq!(method, EXEC_METHOD);
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Response(JSONRPCResponse {
id,
result: serde_json::json!({ "processId": "proc-1" }),
}),
)
.await;
tokio::time::sleep(Duration::from_millis(25)).await;
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Notification(JSONRPCNotification {
method: EXEC_OUTPUT_DELTA_METHOD.to_string(),
params: Some(serde_json::json!({
"processId": "proc-1",
"stream": "stderr",
"chunk": "ZXJyb3IK"
})),
}),
)
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
});
let client = match ExecServerClient::connect_stdio(
client_stdin,
client_stdout,
test_options(),
)
.await
{
Ok(client) => client,
Err(err) => panic!("failed to connect test client: {err}"),
};
let process = match client
.start_process(ExecParams {
process_id: "proc-1".to_string(),
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
env: HashMap::new(),
tty: true,
arg0: None,
})
.await
{
Ok(process) => process,
Err(err) => panic!("failed to start process: {err}"),
};
let mut output = process.output_receiver();
let output = timeout(Duration::from_secs(1), output.recv())
.await
.unwrap_or_else(|err| panic!("timed out waiting for process output: {err}"))
.unwrap_or_else(|err| panic!("failed to receive process output: {err}"));
assert_eq!(output.stream, ExecOutputStream::Stderr);
assert_eq!(output.chunk, b"error\n".to_vec());
}
#[tokio::test]
async fn terminate_does_not_mark_process_exited_before_exit_notification() {
let (client_stdin, server_reader) = tokio::io::duplex(4096);
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();
let initialize = read_jsonrpc_line(&mut lines).await;
let JSONRPCMessage::Request(initialize_request) = initialize else {
panic!("expected initialize request");
};
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Response(JSONRPCResponse {
id: initialize_request.id,
result: serde_json::json!({ "protocolVersion": PROTOCOL_VERSION }),
}),
)
.await;
let initialized = read_jsonrpc_line(&mut lines).await;
let JSONRPCMessage::Notification(notification) = initialized else {
panic!("expected initialized notification");
};
assert_eq!(notification.method, INITIALIZED_METHOD);
let exec_request = read_jsonrpc_line(&mut lines).await;
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = exec_request else {
panic!("expected exec request");
};
assert_eq!(method, EXEC_METHOD);
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Response(JSONRPCResponse {
id,
result: serde_json::json!({ "processId": "proc-1" }),
}),
)
.await;
let terminate_request = read_jsonrpc_line(&mut lines).await;
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = terminate_request
else {
panic!("expected terminate request");
};
assert_eq!(method, EXEC_TERMINATE_METHOD);
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Response(JSONRPCResponse {
id,
result: serde_json::json!({ "running": true }),
}),
)
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
});
let client = match ExecServerClient::connect_stdio(
client_stdin,
client_stdout,
test_options(),
)
.await
{
Ok(client) => client,
Err(err) => panic!("failed to connect test client: {err}"),
};
let process = match client
.start_process(ExecParams {
process_id: "proc-1".to_string(),
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
env: HashMap::new(),
tty: true,
arg0: None,
})
.await
{
Ok(process) => process,
Err(err) => panic!("failed to start process: {err}"),
};
process.terminate();
tokio::time::sleep(Duration::from_millis(25)).await;
assert!(!process.has_exited(), "terminate should not imply exit");
assert_eq!(process.exit_code(), None);
}
}