mirror of
https://github.com/openai/codex.git
synced 2026-05-20 19:23:21 +00:00
Simplify stdio exec-server transport ownership
Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -12,6 +12,7 @@ use futures::FutureExt;
|
||||
use futures::future::BoxFuture;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::OnceCell;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
|
||||
@@ -152,6 +153,9 @@ pub(crate) struct Session {
|
||||
|
||||
struct Inner {
|
||||
client: RpcClient,
|
||||
// Keep the connection alive for any transport-specific owned state such as
|
||||
// the stdio child process. RpcClient only takes the runtime channels/tasks.
|
||||
_connection: JsonRpcConnection,
|
||||
// The remote transport delivers one shared notification stream for every
|
||||
// process on the connection. Keep a local process_id -> session registry so
|
||||
// we can turn those connection-global notifications into process wakeups
|
||||
@@ -191,28 +195,25 @@ pub struct ExecServerClient {
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct LazyRemoteExecServerClient {
|
||||
transport: ExecServerTransport,
|
||||
client: Arc<Mutex<Option<ExecServerClient>>>,
|
||||
client: Arc<OnceCell<ExecServerClient>>,
|
||||
}
|
||||
|
||||
impl LazyRemoteExecServerClient {
|
||||
pub(crate) fn new(transport: ExecServerTransport) -> Self {
|
||||
Self {
|
||||
transport,
|
||||
client: Arc::new(Mutex::new(None)),
|
||||
client: Arc::new(OnceCell::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn get(&self) -> Result<ExecServerClient, ExecServerError> {
|
||||
let mut client = self.client.lock().await;
|
||||
if let Some(client) = client.as_ref()
|
||||
&& !client.is_disconnected()
|
||||
{
|
||||
return Ok(client.clone());
|
||||
}
|
||||
|
||||
let connected = ExecServerClient::connect_for_environment(self.transport.clone()).await?;
|
||||
*client = Some(connected.clone());
|
||||
Ok(connected)
|
||||
self.client
|
||||
.get_or_try_init(|| {
|
||||
let transport = self.transport.clone();
|
||||
async move { ExecServerClient::connect_for_environment(transport).await }
|
||||
})
|
||||
.await
|
||||
.cloned()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -276,10 +277,6 @@ pub enum ExecServerError {
|
||||
}
|
||||
|
||||
impl ExecServerClient {
|
||||
fn is_disconnected(&self) -> bool {
|
||||
self.inner.disconnected_error().is_some() || self.inner.client.is_disconnected()
|
||||
}
|
||||
|
||||
pub async fn initialize(
|
||||
&self,
|
||||
options: ExecServerClientConnectOptions,
|
||||
@@ -429,10 +426,10 @@ impl ExecServerClient {
|
||||
}
|
||||
|
||||
pub(crate) async fn connect(
|
||||
connection: JsonRpcConnection,
|
||||
mut connection: JsonRpcConnection,
|
||||
options: ExecServerClientConnectOptions,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let (rpc_client, mut events_rx) = RpcClient::new(connection);
|
||||
let (rpc_client, mut events_rx) = RpcClient::new(&mut connection);
|
||||
let inner = Arc::new_cyclic(|weak| {
|
||||
let weak = weak.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
@@ -467,6 +464,7 @@ impl ExecServerClient {
|
||||
|
||||
Inner {
|
||||
client: rpc_client,
|
||||
_connection: connection,
|
||||
sessions: ArcSwap::from_pointee(HashMap::new()),
|
||||
sessions_write_lock: Mutex::new(()),
|
||||
disconnected: OnceLock::new(),
|
||||
|
||||
@@ -3,9 +3,7 @@ use std::time::Duration;
|
||||
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tracing::debug;
|
||||
@@ -79,7 +77,6 @@ impl ExecServerClient {
|
||||
args: StdioExecServerConnectArgs,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let mut child = stdio_command_process(&args.command)
|
||||
.kill_on_drop(true)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
@@ -110,60 +107,13 @@ impl ExecServerClient {
|
||||
|
||||
Self::connect(
|
||||
JsonRpcConnection::from_stdio(stdout, stdin, "exec-server stdio command".to_string())
|
||||
.with_transport_lifetime(Box::new(StdioChildGuard::spawn(child))),
|
||||
.with_stdio_child(child),
|
||||
args.into(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
struct StdioChildGuard {
|
||||
shutdown_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
impl StdioChildGuard {
|
||||
fn spawn(child: Child) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = oneshot::channel();
|
||||
tokio::spawn(supervise_stdio_child(child, shutdown_rx));
|
||||
Self {
|
||||
shutdown_tx: Some(shutdown_tx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StdioChildGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Some(shutdown_tx) = self.shutdown_tx.take() {
|
||||
let _ = shutdown_tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn supervise_stdio_child(mut child: Child, shutdown_rx: oneshot::Receiver<()>) {
|
||||
let shutdown_requested = tokio::select! {
|
||||
result = child.wait() => {
|
||||
if let Err(err) = result {
|
||||
debug!("failed to wait for exec-server stdio child: {err}");
|
||||
}
|
||||
false
|
||||
}
|
||||
_ = shutdown_rx => true,
|
||||
};
|
||||
|
||||
if shutdown_requested {
|
||||
kill_stdio_child(&mut child);
|
||||
if let Err(err) = child.wait().await {
|
||||
debug!("failed to wait for exec-server stdio child after shutdown: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn kill_stdio_child(child: &mut Child) {
|
||||
if let Err(err) = child.start_kill() {
|
||||
debug!("failed to terminate exec-server stdio child: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
fn stdio_command_process(stdio_command: &StdioExecServerCommand) -> Command {
|
||||
let mut command = Command::new(&stdio_command.program);
|
||||
command.args(&stdio_command.args);
|
||||
|
||||
@@ -3,10 +3,12 @@ use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::process::Child;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::debug;
|
||||
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
@@ -15,16 +17,6 @@ use tokio::io::BufWriter;
|
||||
|
||||
pub(crate) const CHANNEL_CAPACITY: usize = 128;
|
||||
|
||||
pub(crate) type JsonRpcTransportLifetime = Box<dyn Send>;
|
||||
|
||||
pub(crate) struct JsonRpcConnectionParts {
|
||||
pub(crate) outgoing_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
pub(crate) incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
pub(crate) disconnected_rx: watch::Receiver<bool>,
|
||||
pub(crate) task_handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
pub(crate) transport_lifetime: Option<JsonRpcTransportLifetime>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum JsonRpcConnectionEvent {
|
||||
Message(JSONRPCMessage),
|
||||
@@ -32,12 +24,24 @@ pub(crate) enum JsonRpcConnectionEvent {
|
||||
Disconnected { reason: Option<String> },
|
||||
}
|
||||
|
||||
struct StdioTransport {
|
||||
child: Child,
|
||||
}
|
||||
|
||||
impl Drop for StdioTransport {
|
||||
fn drop(&mut self) {
|
||||
if let Err(err) = self.child.start_kill() {
|
||||
debug!("failed to terminate exec-server stdio child: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct JsonRpcConnection {
|
||||
outgoing_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
outgoing_tx: Option<mpsc::Sender<JSONRPCMessage>>,
|
||||
incoming_rx: Option<mpsc::Receiver<JsonRpcConnectionEvent>>,
|
||||
disconnected_rx: Option<watch::Receiver<bool>>,
|
||||
task_handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
transport_lifetime: Option<JsonRpcTransportLifetime>,
|
||||
_stdio_transport: Option<StdioTransport>,
|
||||
}
|
||||
|
||||
impl JsonRpcConnection {
|
||||
@@ -124,11 +128,11 @@ impl JsonRpcConnection {
|
||||
});
|
||||
|
||||
Self {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
outgoing_tx: Some(outgoing_tx),
|
||||
incoming_rx: Some(incoming_rx),
|
||||
disconnected_rx: Some(disconnected_rx),
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
transport_lifetime: None,
|
||||
_stdio_transport: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,16 +263,38 @@ impl JsonRpcConnection {
|
||||
});
|
||||
|
||||
Self {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
outgoing_tx: Some(outgoing_tx),
|
||||
incoming_rx: Some(incoming_rx),
|
||||
disconnected_rx: Some(disconnected_rx),
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
transport_lifetime: None,
|
||||
_stdio_transport: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_transport_lifetime(mut self, lifetime: JsonRpcTransportLifetime) -> Self {
|
||||
self.transport_lifetime = Some(lifetime);
|
||||
pub(crate) fn take_client_runtime(
|
||||
&mut self,
|
||||
) -> (
|
||||
mpsc::Sender<JSONRPCMessage>,
|
||||
mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
watch::Receiver<bool>,
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
) {
|
||||
(
|
||||
self.outgoing_tx
|
||||
.take()
|
||||
.expect("JSON-RPC client runtime already taken"),
|
||||
self.incoming_rx
|
||||
.take()
|
||||
.expect("JSON-RPC client runtime already taken"),
|
||||
self.disconnected_rx
|
||||
.take()
|
||||
.expect("JSON-RPC client runtime already taken"),
|
||||
std::mem::take(&mut self.task_handles),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn with_stdio_child(mut self, child: Child) -> Self {
|
||||
self._stdio_transport = Some(StdioTransport { child });
|
||||
self
|
||||
}
|
||||
|
||||
@@ -281,22 +307,15 @@ impl JsonRpcConnection {
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
) {
|
||||
(
|
||||
self.outgoing_tx,
|
||||
self.incoming_rx,
|
||||
self.disconnected_rx,
|
||||
self.outgoing_tx
|
||||
.expect("JSON-RPC connection parts already taken"),
|
||||
self.incoming_rx
|
||||
.expect("JSON-RPC connection parts already taken"),
|
||||
self.disconnected_rx
|
||||
.expect("JSON-RPC connection parts already taken"),
|
||||
self.task_handles,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn into_parts_with_lifetime(self) -> JsonRpcConnectionParts {
|
||||
JsonRpcConnectionParts {
|
||||
outgoing_tx: self.outgoing_tx,
|
||||
incoming_rx: self.incoming_rx,
|
||||
disconnected_rx: self.disconnected_rx,
|
||||
task_handles: self.task_handles,
|
||||
transport_lifetime: self.transport_lifetime,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_disconnected(
|
||||
|
||||
@@ -233,19 +233,16 @@ pub(crate) struct RpcClient {
|
||||
}
|
||||
|
||||
impl RpcClient {
|
||||
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
|
||||
let connection_parts = connection.into_parts_with_lifetime();
|
||||
let write_tx = connection_parts.outgoing_tx;
|
||||
let mut incoming_rx = connection_parts.incoming_rx;
|
||||
let disconnected_rx = connection_parts.disconnected_rx;
|
||||
let transport_tasks = connection_parts.task_handles;
|
||||
let transport_lifetime = connection_parts.transport_lifetime;
|
||||
pub(crate) fn new(
|
||||
connection: &mut JsonRpcConnection,
|
||||
) -> (Self, mpsc::Receiver<RpcClientEvent>) {
|
||||
let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) =
|
||||
connection.take_client_runtime();
|
||||
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
|
||||
let (event_tx, event_rx) = mpsc::channel(128);
|
||||
|
||||
let pending_for_reader = Arc::clone(&pending);
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let _transport_lifetime = transport_lifetime;
|
||||
while let Some(event) = incoming_rx.recv().await {
|
||||
match event {
|
||||
JsonRpcConnectionEvent::Message(message) => {
|
||||
@@ -307,10 +304,6 @@ impl RpcClient {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn is_disconnected(&self) -> bool {
|
||||
*self.disconnected_rx.borrow()
|
||||
}
|
||||
|
||||
pub(crate) async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, RpcCallError>
|
||||
where
|
||||
P: Serialize,
|
||||
@@ -575,11 +568,9 @@ mod tests {
|
||||
async fn rpc_client_matches_out_of_order_responses_by_request_id() {
|
||||
let (client_stdin, server_reader) = tokio::io::duplex(4096);
|
||||
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
|
||||
let (client, _events_rx) = RpcClient::new(JsonRpcConnection::from_stdio(
|
||||
client_stdout,
|
||||
client_stdin,
|
||||
"test-rpc".to_string(),
|
||||
));
|
||||
let mut connection =
|
||||
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
|
||||
let (client, _events_rx) = RpcClient::new(&mut connection);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(server_reader).lines();
|
||||
|
||||
Reference in New Issue
Block a user