mirror of
https://github.com/openai/codex.git
synced 2026-05-02 02:17:22 +00:00
refactor(exec-server): tighten client lifecycle and output model
Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -4,6 +4,7 @@ use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCError;
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
@@ -22,6 +23,7 @@ use tokio::sync::broadcast;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tracing::debug;
|
||||
use tracing::warn;
|
||||
@@ -49,12 +51,14 @@ use crate::protocol::WriteResponse;
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ExecServerClientConnectOptions {
|
||||
pub client_name: String,
|
||||
pub initialize_timeout: Duration,
|
||||
}
|
||||
|
||||
impl Default for ExecServerClientConnectOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
client_name: "codex-core".to_string(),
|
||||
initialize_timeout: INITIALIZE_TIMEOUT,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -63,19 +67,42 @@ impl Default for ExecServerClientConnectOptions {
|
||||
pub struct RemoteExecServerConnectArgs {
|
||||
pub websocket_url: String,
|
||||
pub client_name: String,
|
||||
pub connect_timeout: Duration,
|
||||
pub initialize_timeout: Duration,
|
||||
}
|
||||
|
||||
impl From<RemoteExecServerConnectArgs> for ExecServerClientConnectOptions {
|
||||
fn from(value: RemoteExecServerConnectArgs) -> Self {
|
||||
Self {
|
||||
client_name: value.client_name,
|
||||
initialize_timeout: value.initialize_timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
impl RemoteExecServerConnectArgs {
|
||||
pub fn new(websocket_url: String, client_name: String) -> Self {
|
||||
Self {
|
||||
websocket_url,
|
||||
client_name,
|
||||
connect_timeout: CONNECT_TIMEOUT,
|
||||
initialize_timeout: INITIALIZE_TIMEOUT,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ExecServerOutput {
|
||||
pub stream: crate::protocol::ExecOutputStream,
|
||||
pub chunk: Vec<u8>,
|
||||
}
|
||||
|
||||
pub struct ExecServerProcess {
|
||||
process_id: String,
|
||||
output_rx: broadcast::Receiver<Vec<u8>>,
|
||||
output_rx: broadcast::Receiver<ExecServerOutput>,
|
||||
writer_tx: mpsc::Sender<Vec<u8>>,
|
||||
status: Arc<RemoteProcessStatus>,
|
||||
client: ExecServerClient,
|
||||
@@ -96,7 +123,7 @@ impl ExecServerProcess {
|
||||
self.writer_tx.clone()
|
||||
}
|
||||
|
||||
pub fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
|
||||
pub fn output_receiver(&self) -> broadcast::Receiver<ExecServerOutput> {
|
||||
self.output_rx.resubscribe()
|
||||
}
|
||||
|
||||
@@ -109,7 +136,6 @@ impl ExecServerProcess {
|
||||
}
|
||||
|
||||
pub fn terminate(&self) {
|
||||
self.status.mark_exited(None);
|
||||
let client = self.client.clone();
|
||||
let process_id = self.process_id.clone();
|
||||
tokio::spawn(async move {
|
||||
@@ -157,7 +183,7 @@ impl RemoteProcessStatus {
|
||||
}
|
||||
|
||||
struct RegisteredProcess {
|
||||
output_tx: broadcast::Sender<Vec<u8>>,
|
||||
output_tx: broadcast::Sender<ExecServerOutput>,
|
||||
status: Arc<RemoteProcessStatus>,
|
||||
}
|
||||
|
||||
@@ -184,12 +210,16 @@ pub struct ExecServerClient {
|
||||
pub enum ExecServerError {
|
||||
#[error("failed to spawn exec-server: {0}")]
|
||||
Spawn(#[source] std::io::Error),
|
||||
#[error("timed out connecting to exec-server websocket `{url}` after {timeout:?}")]
|
||||
WebSocketConnectTimeout { url: String, timeout: Duration },
|
||||
#[error("failed to connect to exec-server websocket `{url}`: {source}")]
|
||||
WebSocketConnect {
|
||||
url: String,
|
||||
#[source]
|
||||
source: tokio_tungstenite::tungstenite::Error,
|
||||
},
|
||||
#[error("timed out waiting for exec-server initialize handshake after {timeout:?}")]
|
||||
InitializeTimedOut { timeout: Duration },
|
||||
#[error("exec-server transport closed")]
|
||||
Closed,
|
||||
#[error("failed to serialize or deserialize exec-server JSON: {0}")]
|
||||
@@ -221,8 +251,13 @@ impl ExecServerClient {
|
||||
args: RemoteExecServerConnectArgs,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let websocket_url = args.websocket_url.clone();
|
||||
let (stream, _) = connect_async(websocket_url.as_str())
|
||||
let connect_timeout = args.connect_timeout;
|
||||
let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))
|
||||
.await
|
||||
.map_err(|_| ExecServerError::WebSocketConnectTimeout {
|
||||
url: websocket_url.clone(),
|
||||
timeout: connect_timeout,
|
||||
})?
|
||||
.map_err(|source| ExecServerError::WebSocketConnect {
|
||||
url: websocket_url.clone(),
|
||||
source,
|
||||
@@ -325,32 +360,12 @@ impl ExecServerClient {
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
if !response.running {
|
||||
status.mark_exited(response.exit_code);
|
||||
}
|
||||
|
||||
if let Some(stdout) = response.stdout {
|
||||
let _ = self
|
||||
.inner
|
||||
.processes
|
||||
.lock()
|
||||
.await
|
||||
.get(&process_id)
|
||||
.map(|process| process.output_tx.send(stdout.into_inner()));
|
||||
}
|
||||
if let Some(stderr) = response.stderr {
|
||||
let _ = self
|
||||
.inner
|
||||
.processes
|
||||
.lock()
|
||||
.await
|
||||
.get(&process_id)
|
||||
.map(|process| process.output_tx.send(stderr.into_inner()));
|
||||
}
|
||||
|
||||
if let Some(exit_code) = response.exit_code {
|
||||
status.mark_exited(Some(exit_code));
|
||||
if response.process_id != process_id {
|
||||
self.inner.processes.lock().await.remove(&process_id);
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
"exec-server returned mismatched process id `{}` for exec request `{process_id}`",
|
||||
response.process_id
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(ExecServerProcess {
|
||||
@@ -366,16 +381,21 @@ impl ExecServerClient {
|
||||
&self,
|
||||
options: ExecServerClientConnectOptions,
|
||||
) -> Result<(), ExecServerError> {
|
||||
let _: InitializeResponse = self
|
||||
.request(
|
||||
INITIALIZE_METHOD,
|
||||
&InitializeParams {
|
||||
client_name: options.client_name,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
self.notify(INITIALIZED_METHOD, &serde_json::json!({}))
|
||||
.await
|
||||
let ExecServerClientConnectOptions {
|
||||
client_name,
|
||||
initialize_timeout,
|
||||
} = options;
|
||||
timeout(initialize_timeout, async {
|
||||
let _: InitializeResponse = self
|
||||
.request(INITIALIZE_METHOD, &InitializeParams { client_name })
|
||||
.await?;
|
||||
self.notify(INITIALIZED_METHOD, &serde_json::json!({}))
|
||||
.await
|
||||
})
|
||||
.await
|
||||
.map_err(|_| ExecServerError::InitializeTimedOut {
|
||||
timeout: initialize_timeout,
|
||||
})?
|
||||
}
|
||||
|
||||
async fn write_process(&self, params: WriteParams) -> Result<WriteResponse, ExecServerError> {
|
||||
@@ -412,6 +432,7 @@ impl ExecServerClient {
|
||||
P: Serialize,
|
||||
R: DeserializeOwned,
|
||||
{
|
||||
let params = serde_json::to_value(params)?;
|
||||
let request_id =
|
||||
RequestId::Integer(self.inner.next_request_id.fetch_add(1, Ordering::SeqCst));
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
@@ -421,7 +442,6 @@ impl ExecServerClient {
|
||||
.await
|
||||
.insert(request_id.clone(), response_tx);
|
||||
|
||||
let params = serde_json::to_value(params)?;
|
||||
let message = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: request_id.clone(),
|
||||
method: method.to_string(),
|
||||
@@ -482,10 +502,13 @@ async fn handle_server_notification(
|
||||
EXEC_OUTPUT_DELTA_METHOD => {
|
||||
let params: ExecOutputDeltaNotification =
|
||||
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
|
||||
let chunk = params.chunk.into_inner();
|
||||
let output = ExecServerOutput {
|
||||
stream: params.stream,
|
||||
chunk: params.chunk.into_inner(),
|
||||
};
|
||||
let processes = inner.processes.lock().await;
|
||||
if let Some(process) = processes.get(¶ms.process_id) {
|
||||
let _ = process.output_tx.send(chunk);
|
||||
let _ = process.output_tx.send(output);
|
||||
}
|
||||
}
|
||||
EXEC_EXITED_METHOD => {
|
||||
@@ -543,6 +566,9 @@ mod tests {
|
||||
use super::ExecServerClientConnectOptions;
|
||||
use super::ExecServerError;
|
||||
use crate::protocol::EXEC_METHOD;
|
||||
use crate::protocol::EXEC_OUTPUT_DELTA_METHOD;
|
||||
use crate::protocol::EXEC_TERMINATE_METHOD;
|
||||
use crate::protocol::ExecOutputStream;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::INITIALIZE_METHOD;
|
||||
use crate::protocol::INITIALIZED_METHOD;
|
||||
@@ -554,6 +580,13 @@ mod tests {
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
|
||||
fn test_options() -> ExecServerClientConnectOptions {
|
||||
ExecServerClientConnectOptions {
|
||||
client_name: "test-client".to_string(),
|
||||
initialize_timeout: Duration::from_secs(1),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_jsonrpc_line<R>(lines: &mut tokio::io::Lines<BufReader<R>>) -> JSONRPCMessage
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
@@ -625,14 +658,8 @@ mod tests {
|
||||
assert_eq!(params, Some(serde_json::json!({})));
|
||||
});
|
||||
|
||||
let client = ExecServerClient::connect_stdio(
|
||||
client_stdin,
|
||||
client_stdout,
|
||||
ExecServerClientConnectOptions {
|
||||
client_name: "test-client".to_string(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let client =
|
||||
ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await;
|
||||
if let Err(err) = client {
|
||||
panic!("failed to connect test client: {err}");
|
||||
}
|
||||
@@ -668,14 +695,8 @@ mod tests {
|
||||
.await;
|
||||
});
|
||||
|
||||
let result = ExecServerClient::connect_stdio(
|
||||
client_stdin,
|
||||
client_stdout,
|
||||
ExecServerClientConnectOptions {
|
||||
client_name: "test-client".to_string(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let result =
|
||||
ExecServerClient::connect_stdio(client_stdin, client_stdout, test_options()).await;
|
||||
|
||||
match result {
|
||||
Err(ExecServerError::Server { code, message }) => {
|
||||
@@ -736,9 +757,7 @@ mod tests {
|
||||
let client = match ExecServerClient::connect_stdio(
|
||||
client_stdin,
|
||||
client_stdout,
|
||||
ExecServerClientConnectOptions {
|
||||
client_name: "test-client".to_string(),
|
||||
},
|
||||
test_options(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -753,7 +772,6 @@ mod tests {
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: true,
|
||||
output_bytes_cap: 4096,
|
||||
arg0: None,
|
||||
})
|
||||
.await;
|
||||
@@ -772,4 +790,215 @@ mod tests {
|
||||
"failed requests should not leave registered process state behind"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_stdio_times_out_during_initialize_handshake() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (_server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
let _ = read_jsonrpc_line(&mut lines).await;
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
});
|
||||
|
||||
let result = ExecServerClient::connect_stdio(
|
||||
client_stdin,
|
||||
client_stdout,
|
||||
ExecServerClientConnectOptions {
|
||||
client_name: "test-client".to_string(),
|
||||
initialize_timeout: Duration::from_millis(25),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(ExecServerError::InitializeTimedOut { timeout }) => {
|
||||
assert_eq!(timeout, Duration::from_millis(25));
|
||||
}
|
||||
Err(err) => panic!("unexpected initialize timeout failure: {err}"),
|
||||
Ok(_) => panic!("expected initialize timeout"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_process_preserves_output_stream_metadata() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
let initialize = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(initialize_request) = initialize else {
|
||||
panic!("expected initialize request");
|
||||
};
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: initialize_request.id,
|
||||
result: serde_json::json!({ "protocolVersion": PROTOCOL_VERSION }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let initialized = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Notification(notification) = initialized else {
|
||||
panic!("expected initialized notification");
|
||||
};
|
||||
assert_eq!(notification.method, INITIALIZED_METHOD);
|
||||
|
||||
let exec_request = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = exec_request else {
|
||||
panic!("expected exec request");
|
||||
};
|
||||
assert_eq!(method, EXEC_METHOD);
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id,
|
||||
result: serde_json::json!({ "processId": "proc-1" }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: EXEC_OUTPUT_DELTA_METHOD.to_string(),
|
||||
params: Some(serde_json::json!({
|
||||
"processId": "proc-1",
|
||||
"stream": "stderr",
|
||||
"chunk": "ZXJyb3IK"
|
||||
})),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
});
|
||||
|
||||
let client = match ExecServerClient::connect_stdio(
|
||||
client_stdin,
|
||||
client_stdout,
|
||||
test_options(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect test client: {err}"),
|
||||
};
|
||||
|
||||
let process = match client
|
||||
.start_process(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(process) => process,
|
||||
Err(err) => panic!("failed to start process: {err}"),
|
||||
};
|
||||
|
||||
let mut output = process.output_receiver();
|
||||
let output = timeout(Duration::from_secs(1), output.recv())
|
||||
.await
|
||||
.unwrap_or_else(|err| panic!("timed out waiting for process output: {err}"))
|
||||
.unwrap_or_else(|err| panic!("failed to receive process output: {err}"));
|
||||
assert_eq!(output.stream, ExecOutputStream::Stderr);
|
||||
assert_eq!(output.chunk, b"error\n".to_vec());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn terminate_does_not_mark_process_exited_before_exit_notification() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
let initialize = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(initialize_request) = initialize else {
|
||||
panic!("expected initialize request");
|
||||
};
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: initialize_request.id,
|
||||
result: serde_json::json!({ "protocolVersion": PROTOCOL_VERSION }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let initialized = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Notification(notification) = initialized else {
|
||||
panic!("expected initialized notification");
|
||||
};
|
||||
assert_eq!(notification.method, INITIALIZED_METHOD);
|
||||
|
||||
let exec_request = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = exec_request else {
|
||||
panic!("expected exec request");
|
||||
};
|
||||
assert_eq!(method, EXEC_METHOD);
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id,
|
||||
result: serde_json::json!({ "processId": "proc-1" }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let terminate_request = read_jsonrpc_line(&mut lines).await;
|
||||
let JSONRPCMessage::Request(JSONRPCRequest { id, method, .. }) = terminate_request
|
||||
else {
|
||||
panic!("expected terminate request");
|
||||
};
|
||||
assert_eq!(method, EXEC_TERMINATE_METHOD);
|
||||
write_jsonrpc_line(
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id,
|
||||
result: serde_json::json!({ "running": true }),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
});
|
||||
|
||||
let client = match ExecServerClient::connect_stdio(
|
||||
client_stdin,
|
||||
client_stdout,
|
||||
test_options(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(client) => client,
|
||||
Err(err) => panic!("failed to connect test client: {err}"),
|
||||
};
|
||||
|
||||
let process = match client
|
||||
.start_process(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["bash".to_string(), "-lc".to_string(), "true".to_string()],
|
||||
cwd: std::env::current_dir().unwrap_or_else(|err| panic!("missing cwd: {err}")),
|
||||
env: HashMap::new(),
|
||||
tty: true,
|
||||
arg0: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(process) => process,
|
||||
Err(err) => panic!("failed to start process: {err}"),
|
||||
};
|
||||
|
||||
process.terminate();
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
assert!(!process.has_exited(), "terminate should not imply exit");
|
||||
assert_eq!(process.exit_code(), None);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user