mirror of
https://github.com/openai/codex.git
synced 2026-06-01 19:02:59 +00:00
Restore exec-server processor ownership boundary
Keep the server-side connection processor on the original by-value parts API, and move the compatibility needed for that shape into JsonRpcConnection. The client still borrows the connection mutably so it can keep transport ownership with ExecServerClient. Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -5,6 +5,7 @@ use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::process::Child;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::debug;
|
||||
@@ -74,6 +75,7 @@ impl StdioTransport {
|
||||
struct JsonRpcConnectionRuntime {
|
||||
outgoing_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
task_handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
@@ -100,9 +102,11 @@ impl JsonRpcConnection {
|
||||
{
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
let disconnected_tx_for_reader = disconnected_tx.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(reader).lines();
|
||||
loop {
|
||||
@@ -133,12 +137,18 @@ impl JsonRpcConnection {
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to read JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
@@ -156,6 +166,7 @@ impl JsonRpcConnection {
|
||||
if let Err(err) = write_jsonrpc_line_message(&mut writer, &message).await {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write JSON-RPC message to {connection_label}: {err}"
|
||||
)),
|
||||
@@ -170,6 +181,7 @@ impl JsonRpcConnection {
|
||||
runtime: Some(JsonRpcConnectionRuntime {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
}),
|
||||
transport: JsonRpcTransport::Plain,
|
||||
@@ -182,10 +194,12 @@ impl JsonRpcConnection {
|
||||
{
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
let (mut websocket_writer, mut websocket_reader) = stream.split();
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
let disconnected_tx_for_reader = disconnected_tx.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match websocket_reader.next().await {
|
||||
@@ -234,7 +248,12 @@ impl JsonRpcConnection {
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => {
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {}
|
||||
@@ -242,6 +261,7 @@ impl JsonRpcConnection {
|
||||
Some(Err(err)) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to read websocket JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
@@ -250,7 +270,12 @@ impl JsonRpcConnection {
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -265,6 +290,7 @@ impl JsonRpcConnection {
|
||||
{
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write websocket JSON-RPC message to {connection_label}: {err}"
|
||||
)),
|
||||
@@ -276,6 +302,7 @@ impl JsonRpcConnection {
|
||||
Err(err) => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to serialize JSON-RPC message for {connection_label}: {err}"
|
||||
)),
|
||||
@@ -291,6 +318,7 @@ impl JsonRpcConnection {
|
||||
runtime: Some(JsonRpcConnectionRuntime {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
}),
|
||||
transport: JsonRpcTransport::Plain,
|
||||
@@ -307,11 +335,29 @@ impl JsonRpcConnection {
|
||||
let JsonRpcConnectionRuntime {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx: _,
|
||||
task_handles,
|
||||
} = self.take_runtime_or_panic("JSON-RPC connection runtime already taken");
|
||||
(outgoing_tx, incoming_rx, task_handles)
|
||||
}
|
||||
|
||||
pub(crate) fn into_parts(
|
||||
mut self,
|
||||
) -> (
|
||||
mpsc::Sender<JSONRPCMessage>,
|
||||
mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
watch::Receiver<bool>,
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
) {
|
||||
let JsonRpcConnectionRuntime {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles,
|
||||
} = self.take_runtime_or_panic("JSON-RPC connection runtime already taken");
|
||||
(outgoing_tx, incoming_rx, disconnected_rx, task_handles)
|
||||
}
|
||||
|
||||
pub(crate) fn with_child_process(mut self, child_process: Child) -> Self {
|
||||
self.transport = JsonRpcTransport::from_child_process(child_process);
|
||||
self
|
||||
@@ -327,8 +373,10 @@ impl JsonRpcConnection {
|
||||
|
||||
async fn send_disconnected(
|
||||
incoming_tx: &mpsc::Sender<JsonRpcConnectionEvent>,
|
||||
disconnected_tx: &watch::Sender<bool>,
|
||||
reason: Option<String>,
|
||||
) {
|
||||
let _ = disconnected_tx.send(true);
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Disconnected { reason })
|
||||
.await;
|
||||
|
||||
@@ -42,12 +42,13 @@ impl ConnectionProcessor {
|
||||
}
|
||||
|
||||
async fn run_connection(
|
||||
mut connection: JsonRpcConnection,
|
||||
connection: JsonRpcConnection,
|
||||
session_registry: Arc<SessionRegistry>,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) {
|
||||
let router = Arc::new(build_router());
|
||||
let (json_outgoing_tx, mut incoming_rx, connection_tasks) = connection.take_runtime();
|
||||
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) =
|
||||
connection.into_parts();
|
||||
let (outgoing_tx, mut outgoing_rx) =
|
||||
mpsc::channel::<RpcServerOutboundMessage>(CHANNEL_CAPACITY);
|
||||
let notifications = RpcNotificationSender::new(outgoing_tx.clone());
|
||||
@@ -95,7 +96,13 @@ async fn run_connection(
|
||||
JsonRpcConnectionEvent::Message(message) => match message {
|
||||
codex_app_server_protocol::JSONRPCMessage::Request(request) => {
|
||||
if let Some(route) = router.request_route(request.method.as_str()) {
|
||||
let message = route(Arc::clone(&handler), request).await;
|
||||
let message = tokio::select! {
|
||||
message = route(Arc::clone(&handler), request) => message,
|
||||
_ = disconnected_rx.changed() => {
|
||||
debug!("exec-server transport disconnected while handling request");
|
||||
break;
|
||||
}
|
||||
};
|
||||
if let Some(message) = message
|
||||
&& outgoing_tx.send(message).await.is_err()
|
||||
{
|
||||
@@ -124,7 +131,15 @@ async fn run_connection(
|
||||
);
|
||||
break;
|
||||
};
|
||||
let result = route(Arc::clone(&handler), notification).await;
|
||||
let result = tokio::select! {
|
||||
result = route(Arc::clone(&handler), notification) => result,
|
||||
_ = disconnected_rx.changed() => {
|
||||
debug!(
|
||||
"exec-server transport disconnected while handling notification"
|
||||
);
|
||||
break;
|
||||
}
|
||||
};
|
||||
if let Err(err) = result {
|
||||
warn!("closing exec-server connection after protocol error: {err}");
|
||||
break;
|
||||
@@ -163,3 +178,241 @@ async fn run_connection(
|
||||
}
|
||||
let _ = outbound_task.await;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio::io::Lines;
|
||||
use tokio::io::duplex;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::run_connection;
|
||||
use crate::ExecServerRuntimePaths;
|
||||
use crate::ProcessId;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::protocol::EXEC_METHOD;
|
||||
use crate::protocol::EXEC_READ_METHOD;
|
||||
use crate::protocol::EXEC_TERMINATE_METHOD;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::ExecResponse;
|
||||
use crate::protocol::INITIALIZE_METHOD;
|
||||
use crate::protocol::INITIALIZED_METHOD;
|
||||
use crate::protocol::InitializeParams;
|
||||
use crate::protocol::InitializeResponse;
|
||||
use crate::protocol::ReadParams;
|
||||
use crate::protocol::TerminateParams;
|
||||
use crate::protocol::TerminateResponse;
|
||||
use crate::server::session_registry::SessionRegistry;
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_disconnect_detaches_session_during_in_flight_read() {
|
||||
let registry = SessionRegistry::new();
|
||||
let (mut first_writer, mut first_lines, first_task) =
|
||||
spawn_test_connection(Arc::clone(®istry), "first");
|
||||
|
||||
send_request(
|
||||
&mut first_writer,
|
||||
/*id*/ 1,
|
||||
INITIALIZE_METHOD,
|
||||
&InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let initialize_response: InitializeResponse =
|
||||
read_response(&mut first_lines, /*expected_id*/ 1).await;
|
||||
send_notification(&mut first_writer, INITIALIZED_METHOD, &()).await;
|
||||
|
||||
let process_id = ProcessId::from("proc-long-poll");
|
||||
send_request(
|
||||
&mut first_writer,
|
||||
/*id*/ 2,
|
||||
EXEC_METHOD,
|
||||
&exec_params(process_id.clone()),
|
||||
)
|
||||
.await;
|
||||
let _: ExecResponse = read_response(&mut first_lines, /*expected_id*/ 2).await;
|
||||
|
||||
send_request(
|
||||
&mut first_writer,
|
||||
/*id*/ 3,
|
||||
EXEC_READ_METHOD,
|
||||
&ReadParams {
|
||||
process_id: process_id.clone(),
|
||||
after_seq: None,
|
||||
max_bytes: None,
|
||||
wait_ms: Some(5_000),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
drop(first_writer);
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
|
||||
let (mut second_writer, mut second_lines, second_task) =
|
||||
spawn_test_connection(Arc::clone(®istry), "second");
|
||||
send_request(
|
||||
&mut second_writer,
|
||||
/*id*/ 1,
|
||||
INITIALIZE_METHOD,
|
||||
&InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: Some(initialize_response.session_id.clone()),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let second_initialize_response = timeout(
|
||||
Duration::from_secs(1),
|
||||
read_response::<InitializeResponse>(&mut second_lines, /*expected_id*/ 1),
|
||||
)
|
||||
.await
|
||||
.expect("resume initialize should not wait for the old read to finish");
|
||||
assert_eq!(
|
||||
second_initialize_response.session_id,
|
||||
initialize_response.session_id
|
||||
);
|
||||
timeout(Duration::from_secs(1), first_task)
|
||||
.await
|
||||
.expect("first processor should exit")
|
||||
.expect("first processor should join");
|
||||
send_notification(&mut second_writer, INITIALIZED_METHOD, &()).await;
|
||||
|
||||
send_request(
|
||||
&mut second_writer,
|
||||
/*id*/ 2,
|
||||
EXEC_TERMINATE_METHOD,
|
||||
&TerminateParams { process_id },
|
||||
)
|
||||
.await;
|
||||
let _: TerminateResponse = read_response(&mut second_lines, /*expected_id*/ 2).await;
|
||||
|
||||
drop(second_writer);
|
||||
drop(second_lines);
|
||||
timeout(Duration::from_secs(1), second_task)
|
||||
.await
|
||||
.expect("second processor should exit")
|
||||
.expect("second processor should join");
|
||||
}
|
||||
|
||||
fn spawn_test_connection(
|
||||
registry: Arc<SessionRegistry>,
|
||||
label: &str,
|
||||
) -> (DuplexStream, Lines<BufReader<DuplexStream>>, JoinHandle<()>) {
|
||||
let (client_writer, server_reader) = duplex(1 << 20);
|
||||
let (server_writer, client_reader) = duplex(1 << 20);
|
||||
let connection =
|
||||
JsonRpcConnection::from_stdio(server_reader, server_writer, label.to_string());
|
||||
let task = tokio::spawn(run_connection(connection, registry, test_runtime_paths()));
|
||||
(client_writer, BufReader::new(client_reader).lines(), task)
|
||||
}
|
||||
|
||||
fn test_runtime_paths() -> ExecServerRuntimePaths {
|
||||
ExecServerRuntimePaths::new(
|
||||
std::env::current_exe().expect("current exe"),
|
||||
/*codex_linux_sandbox_exe*/ None,
|
||||
)
|
||||
.expect("runtime paths")
|
||||
}
|
||||
|
||||
async fn send_request<P: Serialize>(
|
||||
writer: &mut DuplexStream,
|
||||
id: i64,
|
||||
method: &str,
|
||||
params: &P,
|
||||
) {
|
||||
write_message(
|
||||
writer,
|
||||
&JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(id),
|
||||
method: method.to_string(),
|
||||
params: Some(serde_json::to_value(params).expect("serialize params")),
|
||||
trace: None,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn send_notification<P: Serialize>(writer: &mut DuplexStream, method: &str, params: &P) {
|
||||
write_message(
|
||||
writer,
|
||||
&JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: method.to_string(),
|
||||
params: Some(serde_json::to_value(params).expect("serialize params")),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn write_message(writer: &mut DuplexStream, message: &JSONRPCMessage) {
|
||||
let encoded = serde_json::to_vec(message).expect("serialize JSON-RPC message");
|
||||
writer.write_all(&encoded).await.expect("write request");
|
||||
writer.write_all(b"\n").await.expect("write newline");
|
||||
}
|
||||
|
||||
async fn read_response<T: DeserializeOwned>(
|
||||
lines: &mut Lines<BufReader<DuplexStream>>,
|
||||
expected_id: i64,
|
||||
) -> T {
|
||||
let line = lines
|
||||
.next_line()
|
||||
.await
|
||||
.expect("read response")
|
||||
.expect("response line");
|
||||
match serde_json::from_str::<JSONRPCMessage>(&line).expect("decode JSON-RPC response") {
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, result }) => {
|
||||
assert_eq!(id, RequestId::Integer(expected_id));
|
||||
serde_json::from_value(result).expect("decode response result")
|
||||
}
|
||||
JSONRPCMessage::Error(error) => panic!("unexpected JSON-RPC error: {error:?}"),
|
||||
other => panic!("expected JSON-RPC response, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn exec_params(process_id: ProcessId) -> ExecParams {
|
||||
let mut env = HashMap::new();
|
||||
if let Some(path) = std::env::var_os("PATH") {
|
||||
env.insert("PATH".to_string(), path.to_string_lossy().into_owned());
|
||||
}
|
||||
ExecParams {
|
||||
process_id,
|
||||
argv: sleep_then_print_argv(),
|
||||
cwd: std::env::current_dir().expect("cwd"),
|
||||
env_policy: None,
|
||||
env,
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
arg0: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn sleep_then_print_argv() -> Vec<String> {
|
||||
if cfg!(windows) {
|
||||
vec![
|
||||
std::env::var("COMSPEC").unwrap_or_else(|_| "cmd.exe".to_string()),
|
||||
"/C".to_string(),
|
||||
"ping -n 3 127.0.0.1 >NUL && echo late".to_string(),
|
||||
]
|
||||
} else {
|
||||
vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
"sleep 1; printf late".to_string(),
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user