mirror of
https://github.com/openai/codex.git
synced 2026-05-28 15:00:16 +00:00
Simplify exec-server transport ownership
Remove the Option wrapper used only to force connection drop order and call transport shutdown explicitly instead. Also drop dead-code allowances that are no longer needed. Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -152,9 +152,8 @@ pub(crate) struct Session {
|
||||
}
|
||||
|
||||
struct Inner {
|
||||
// Keep the underlying transport connection alive and drop it before the RPC
|
||||
// client starts tearing down its channel/task handles.
|
||||
connection: Option<JsonRpcConnection>,
|
||||
// Keep the underlying transport connection alive for the client lifetime.
|
||||
connection: JsonRpcConnection,
|
||||
client: RpcClient,
|
||||
// The remote transport delivers one shared notification stream for every
|
||||
// process on the connection. Keep a local process_id -> session registry so
|
||||
@@ -184,7 +183,7 @@ struct Inner {
|
||||
impl Drop for Inner {
|
||||
fn drop(&mut self) {
|
||||
self.reader_task.abort();
|
||||
drop(self.connection.take());
|
||||
self.connection.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -466,7 +465,7 @@ impl ExecServerClient {
|
||||
});
|
||||
|
||||
Inner {
|
||||
connection: Some(connection),
|
||||
connection,
|
||||
client: rpc_client,
|
||||
sessions: ArcSwap::from_pointee(HashMap::new()),
|
||||
sessions_write_lock: Mutex::new(()),
|
||||
|
||||
@@ -49,6 +49,7 @@ pub(crate) struct StdioExecServerCommand {
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) enum ExecServerTransportParams {
|
||||
WebSocketUrl(String),
|
||||
StdioCommand(StdioExecServerCommand),
|
||||
}
|
||||
|
||||
/// Sends HTTP requests through a runtime-selected transport.
|
||||
|
||||
@@ -24,16 +24,27 @@ impl ExecServerClient {
|
||||
pub(crate) async fn connect_for_transport(
|
||||
transport_params: crate::client_api::ExecServerTransportParams,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
let crate::client_api::ExecServerTransportParams::WebSocketUrl(websocket_url) =
|
||||
transport_params;
|
||||
Self::connect_websocket(RemoteExecServerConnectArgs {
|
||||
websocket_url,
|
||||
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
|
||||
connect_timeout: ENVIRONMENT_CONNECT_TIMEOUT,
|
||||
initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
match transport_params {
|
||||
crate::client_api::ExecServerTransportParams::WebSocketUrl(websocket_url) => {
|
||||
Self::connect_websocket(RemoteExecServerConnectArgs {
|
||||
websocket_url,
|
||||
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
|
||||
connect_timeout: ENVIRONMENT_CONNECT_TIMEOUT,
|
||||
initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
}
|
||||
crate::client_api::ExecServerTransportParams::StdioCommand(command) => {
|
||||
Self::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
command,
|
||||
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
|
||||
initialize_timeout: ENVIRONMENT_INITIALIZE_TIMEOUT,
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect_websocket(
|
||||
|
||||
@@ -84,11 +84,15 @@ pub(crate) struct JsonRpcConnection {
|
||||
|
||||
impl Drop for JsonRpcConnection {
|
||||
fn drop(&mut self) {
|
||||
self.transport.shutdown();
|
||||
self.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonRpcConnection {
|
||||
pub(crate) fn shutdown(&mut self) {
|
||||
self.transport.shutdown();
|
||||
}
|
||||
|
||||
pub(crate) fn from_stdio<R, W>(reader: R, writer: W, connection_label: String) -> Self
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
@@ -293,7 +297,7 @@ impl JsonRpcConnection {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn take_client_runtime(
|
||||
pub(crate) fn take_runtime(
|
||||
&mut self,
|
||||
) -> (
|
||||
mpsc::Sender<JSONRPCMessage>,
|
||||
@@ -304,7 +308,7 @@ impl JsonRpcConnection {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
task_handles,
|
||||
} = self.take_runtime("JSON-RPC client runtime already taken");
|
||||
} = self.take_runtime_or_panic("JSON-RPC connection runtime already taken");
|
||||
(outgoing_tx, incoming_rx, task_handles)
|
||||
}
|
||||
|
||||
@@ -313,22 +317,7 @@ impl JsonRpcConnection {
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn into_parts(
|
||||
mut self,
|
||||
) -> (
|
||||
mpsc::Sender<JSONRPCMessage>,
|
||||
mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
) {
|
||||
let JsonRpcConnectionRuntime {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
task_handles,
|
||||
} = self.take_runtime("JSON-RPC connection parts already taken");
|
||||
(outgoing_tx, incoming_rx, task_handles)
|
||||
}
|
||||
|
||||
fn take_runtime(&mut self, message: &'static str) -> JsonRpcConnectionRuntime {
|
||||
fn take_runtime_or_panic(&mut self, message: &'static str) -> JsonRpcConnectionRuntime {
|
||||
match self.runtime.take() {
|
||||
Some(runtime) => runtime,
|
||||
None => panic!("{message}"),
|
||||
|
||||
@@ -58,11 +58,9 @@ pub(crate) enum RpcServerOutboundMessage {
|
||||
request_id: RequestId,
|
||||
error: JSONRPCErrorError,
|
||||
},
|
||||
#[allow(dead_code)]
|
||||
Notification(JSONRPCNotification),
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RpcNotificationSender {
|
||||
outgoing_tx: mpsc::Sender<RpcServerOutboundMessage>,
|
||||
@@ -84,7 +82,6 @@ impl RpcNotificationSender {
|
||||
.map_err(|_| internal_error("RPC connection closed while sending response".into()))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn notify<P: Serialize>(
|
||||
&self,
|
||||
method: &str,
|
||||
@@ -235,7 +232,7 @@ impl RpcClient {
|
||||
pub(crate) fn new(
|
||||
connection: &mut JsonRpcConnection,
|
||||
) -> (Self, mpsc::Receiver<RpcClientEvent>) {
|
||||
let (write_tx, mut incoming_rx, transport_tasks) = connection.take_client_runtime();
|
||||
let (write_tx, mut incoming_rx, transport_tasks) = connection.take_runtime();
|
||||
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
|
||||
let (event_tx, event_rx) = mpsc::channel(128);
|
||||
let closed = Arc::new(AtomicBool::new(false));
|
||||
@@ -363,7 +360,6 @@ impl RpcClient {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn pending_request_count(&self) -> usize {
|
||||
self.pending.lock().await.len()
|
||||
}
|
||||
|
||||
@@ -42,12 +42,12 @@ impl ConnectionProcessor {
|
||||
}
|
||||
|
||||
async fn run_connection(
|
||||
connection: JsonRpcConnection,
|
||||
mut connection: JsonRpcConnection,
|
||||
session_registry: Arc<SessionRegistry>,
|
||||
runtime_paths: ExecServerRuntimePaths,
|
||||
) {
|
||||
let router = Arc::new(build_router());
|
||||
let (json_outgoing_tx, mut incoming_rx, connection_tasks) = connection.into_parts();
|
||||
let (json_outgoing_tx, mut incoming_rx, connection_tasks) = connection.take_runtime();
|
||||
let (outgoing_tx, mut outgoing_rx) =
|
||||
mpsc::channel::<RpcServerOutboundMessage>(CHANNEL_CAPACITY);
|
||||
let notifications = RpcNotificationSender::new(outgoing_tx.clone());
|
||||
|
||||
Reference in New Issue
Block a user