exec-server: make in-process client call handler directly

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
starr-openai
2026-03-17 21:14:17 +00:00
parent 9c53cce1d1
commit 3d40df7939
8 changed files with 307 additions and 334 deletions

View File

@@ -29,6 +29,9 @@ use tokio_tungstenite::connect_async;
use tracing::debug;
use tracing::warn;
use crate::client_api::ExecServerClientConnectOptions;
use crate::client_api::ExecServerEvent;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
use crate::connection::JsonRpcConnectionEvent;
use crate::protocol::EXEC_EXITED_METHOD;
@@ -51,20 +54,10 @@ use crate::protocol::TerminateParams;
use crate::protocol::TerminateResponse;
use crate::protocol::WriteParams;
use crate::protocol::WriteResponse;
use crate::server::ExecServerClientNotification;
use crate::server::ExecServerHandler;
use crate::server::ExecServerInboundMessage;
use crate::server::ExecServerOutboundMessage;
use crate::server::ExecServerRequest;
use crate::server::ExecServerResponseMessage;
use crate::server::ExecServerServerNotification;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecServerClientConnectOptions {
pub client_name: String,
pub initialize_timeout: Duration,
}
impl Default for ExecServerClientConnectOptions {
fn default() -> Self {
Self {
@@ -74,14 +67,6 @@ impl Default for ExecServerClientConnectOptions {
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RemoteExecServerConnectArgs {
pub websocket_url: String,
pub client_name: String,
pub connect_timeout: Duration,
pub initialize_timeout: Duration,
}
impl From<RemoteExecServerConnectArgs> for ExecServerClientConnectOptions {
fn from(value: RemoteExecServerConnectArgs) -> Self {
Self {
@@ -105,16 +90,11 @@ impl RemoteExecServerConnectArgs {
}
}
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecServerOutput {
pub stream: crate::protocol::ExecOutputStream,
pub chunk: Vec<u8>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecServerEvent {
OutputDelta(ExecOutputDeltaNotification),
Exited(ExecExitedNotification),
struct ExecServerOutput {
stream: crate::protocol::ExecOutputStream,
chunk: Vec<u8>,
}
#[cfg(test)]
@@ -209,32 +189,6 @@ impl PendingRequest {
Ok(())
}
fn resolve_typed(self, response: ExecServerResponseMessage) -> Result<(), ExecServerError> {
match (self, response) {
(PendingRequest::Initialize(tx), ExecServerResponseMessage::Initialize(response)) => {
let _ = tx.send(Ok(response));
}
(PendingRequest::Exec(tx), ExecServerResponseMessage::Exec(response)) => {
let _ = tx.send(Ok(response));
}
(PendingRequest::Read(tx), ExecServerResponseMessage::Read(response)) => {
let _ = tx.send(Ok(response));
}
(PendingRequest::Write(tx), ExecServerResponseMessage::Write(response)) => {
let _ = tx.send(Ok(response));
}
(PendingRequest::Terminate(tx), ExecServerResponseMessage::Terminate(response)) => {
let _ = tx.send(Ok(response));
}
(_, response) => {
return Err(ExecServerError::Protocol(format!(
"unexpected in-process response kind: {response:?}"
)));
}
}
Ok(())
}
fn resolve_error(self, error: JSONRPCErrorError) {
match self {
PendingRequest::Initialize(tx) => {
@@ -261,7 +215,7 @@ enum ClientBackend {
write_tx: mpsc::Sender<JSONRPCMessage>,
},
InProcess {
write_tx: mpsc::Sender<ExecServerInboundMessage>,
handler: Arc<Mutex<ExecServerHandler>>,
},
}
@@ -271,15 +225,11 @@ struct Inner {
events_tx: broadcast::Sender<ExecServerEvent>,
next_request_id: AtomicI64,
reader_task: JoinHandle<()>,
server_task: Option<JoinHandle<()>>,
}
impl Drop for Inner {
fn drop(&mut self) {
self.reader_task.abort();
if let Some(server_task) = &self.server_task {
server_task.abort();
}
}
}
@@ -316,19 +266,8 @@ impl ExecServerClient {
pub async fn connect_in_process(
options: ExecServerClientConnectOptions,
) -> Result<Self, ExecServerError> {
let (write_tx, mut inbound_rx) = mpsc::channel::<ExecServerInboundMessage>(256);
let (outbound_tx, mut outgoing_rx) = mpsc::channel::<ExecServerOutboundMessage>(256);
let server_task = tokio::spawn(async move {
let mut handler = ExecServerHandler::new(outbound_tx);
while let Some(message) = inbound_rx.recv().await {
if let Err(err) = handler.handle_message(message).await {
warn!("in-process exec-server handler stopped after protocol error: {err}");
break;
}
}
handler.shutdown().await;
});
let handler = Arc::new(Mutex::new(ExecServerHandler::new(outbound_tx)));
let inner = Arc::new_cyclic(|weak| {
let weak = weak.clone();
@@ -337,7 +276,9 @@ impl ExecServerClient {
if let Some(inner) = weak.upgrade()
&& let Err(err) = handle_in_process_outbound_message(&inner, message).await
{
warn!("in-process exec-server client closing after protocol error: {err}");
warn!(
"in-process exec-server client closing after unexpected response: {err}"
);
handle_transport_shutdown(&inner).await;
return;
}
@@ -349,12 +290,11 @@ impl ExecServerClient {
});
Inner {
backend: ClientBackend::InProcess { write_tx },
backend: ClientBackend::InProcess { handler },
pending: Mutex::new(HashMap::new()),
events_tx: broadcast::channel(256).0,
next_request_id: AtomicI64::new(1),
reader_task,
server_task: Some(server_task),
}
});
@@ -447,7 +387,6 @@ impl ExecServerClient {
events_tx: broadcast::channel(256).0,
next_request_id: AtomicI64::new(1),
reader_task,
server_task: None,
}
});
@@ -548,6 +487,10 @@ impl ExecServerClient {
}
async fn request_exec(&self, params: ExecParams) -> Result<ExecResponse, ExecServerError> {
if let ClientBackend::InProcess { handler } = &self.inner.backend {
return server_result_to_client(handler.lock().await.exec(params).await);
}
let request_id = self.next_request_id();
let (response_tx, response_rx) = oneshot::channel();
self.inner
@@ -555,29 +498,20 @@ impl ExecServerClient {
.lock()
.await
.insert(request_id.clone(), PendingRequest::Exec(response_tx));
let send_result = match &self.inner.backend {
ClientBackend::JsonRpc { write_tx } => {
send_jsonrpc_request(write_tx, request_id.clone(), EXEC_METHOD, &params).await
}
ClientBackend::InProcess { write_tx } => {
send_in_process_request(
write_tx,
ExecServerInboundMessage::Request(ExecServerRequest::Exec {
request_id: request_id.clone(),
params,
}),
)
.await
}
let ClientBackend::JsonRpc { write_tx } = &self.inner.backend else {
unreachable!("in-process exec requests return before JSON-RPC setup");
};
if let Err(err) = send_result {
self.inner.pending.lock().await.remove(&request_id);
return Err(err);
}
receive_typed_response(response_rx).await
let send_result =
send_jsonrpc_request(write_tx, request_id.clone(), EXEC_METHOD, &params).await;
self.finish_request(request_id, send_result, response_rx)
.await
}
async fn write_process(&self, params: WriteParams) -> Result<WriteResponse, ExecServerError> {
if let ClientBackend::InProcess { handler } = &self.inner.backend {
return server_result_to_client(handler.lock().await.write(params).await);
}
let request_id = self.next_request_id();
let (response_tx, response_rx) = oneshot::channel();
self.inner
@@ -585,29 +519,20 @@ impl ExecServerClient {
.lock()
.await
.insert(request_id.clone(), PendingRequest::Write(response_tx));
let send_result = match &self.inner.backend {
ClientBackend::JsonRpc { write_tx } => {
send_jsonrpc_request(write_tx, request_id.clone(), EXEC_WRITE_METHOD, &params).await
}
ClientBackend::InProcess { write_tx } => {
send_in_process_request(
write_tx,
ExecServerInboundMessage::Request(ExecServerRequest::Write {
request_id: request_id.clone(),
params,
}),
)
.await
}
let ClientBackend::JsonRpc { write_tx } = &self.inner.backend else {
unreachable!("in-process write requests return before JSON-RPC setup");
};
if let Err(err) = send_result {
self.inner.pending.lock().await.remove(&request_id);
return Err(err);
}
receive_typed_response(response_rx).await
let send_result =
send_jsonrpc_request(write_tx, request_id.clone(), EXEC_WRITE_METHOD, &params).await;
self.finish_request(request_id, send_result, response_rx)
.await
}
async fn request_read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
if let ClientBackend::InProcess { handler } = &self.inner.backend {
return server_result_to_client(handler.lock().await.read(params).await);
}
let request_id = self.next_request_id();
let (response_tx, response_rx) = oneshot::channel();
self.inner
@@ -615,26 +540,13 @@ impl ExecServerClient {
.lock()
.await
.insert(request_id.clone(), PendingRequest::Read(response_tx));
let send_result = match &self.inner.backend {
ClientBackend::JsonRpc { write_tx } => {
send_jsonrpc_request(write_tx, request_id.clone(), EXEC_READ_METHOD, &params).await
}
ClientBackend::InProcess { write_tx } => {
send_in_process_request(
write_tx,
ExecServerInboundMessage::Request(ExecServerRequest::Read {
request_id: request_id.clone(),
params,
}),
)
.await
}
let ClientBackend::JsonRpc { write_tx } = &self.inner.backend else {
unreachable!("in-process read requests return before JSON-RPC setup");
};
if let Err(err) = send_result {
self.inner.pending.lock().await.remove(&request_id);
return Err(err);
}
receive_typed_response(response_rx).await
let send_result =
send_jsonrpc_request(write_tx, request_id.clone(), EXEC_READ_METHOD, &params).await;
self.finish_request(request_id, send_result, response_rx)
.await
}
async fn terminate_session(
@@ -644,6 +556,10 @@ impl ExecServerClient {
let params = TerminateParams {
process_id: process_id.to_string(),
};
if let ClientBackend::InProcess { handler } = &self.inner.backend {
return server_result_to_client(handler.lock().await.terminate(params).await);
}
let request_id = self.next_request_id();
let (response_tx, response_rx) = oneshot::channel();
self.inner
@@ -651,27 +567,14 @@ impl ExecServerClient {
.lock()
.await
.insert(request_id.clone(), PendingRequest::Terminate(response_tx));
let send_result = match &self.inner.backend {
ClientBackend::JsonRpc { write_tx } => {
send_jsonrpc_request(write_tx, request_id.clone(), EXEC_TERMINATE_METHOD, &params)
.await
}
ClientBackend::InProcess { write_tx } => {
send_in_process_request(
write_tx,
ExecServerInboundMessage::Request(ExecServerRequest::Terminate {
request_id: request_id.clone(),
params,
}),
)
.await
}
let ClientBackend::JsonRpc { write_tx } = &self.inner.backend else {
unreachable!("in-process terminate requests return before JSON-RPC setup");
};
if let Err(err) = send_result {
self.inner.pending.lock().await.remove(&request_id);
return Err(err);
}
receive_typed_response(response_rx).await
let send_result =
send_jsonrpc_request(write_tx, request_id.clone(), EXEC_TERMINATE_METHOD, &params)
.await;
self.finish_request(request_id, send_result, response_rx)
.await
}
async fn notify<P: Serialize>(&self, method: &str, params: &P) -> Result<(), ExecServerError> {
@@ -686,22 +589,16 @@ impl ExecServerClient {
.await
.map_err(|_| ExecServerError::Closed)
}
ClientBackend::InProcess { write_tx } => {
let message = match method {
INITIALIZED_METHOD => ExecServerInboundMessage::Notification(
ExecServerClientNotification::Initialized,
),
other => {
return Err(ExecServerError::Protocol(format!(
"unsupported in-process notification method `{other}`"
)));
}
};
write_tx
.send(message)
ClientBackend::InProcess { handler } => match method {
INITIALIZED_METHOD => handler
.lock()
.await
.map_err(|_| ExecServerError::Closed)
}
.initialized()
.map_err(ExecServerError::Protocol),
other => Err(ExecServerError::Protocol(format!(
"unsupported in-process notification method `{other}`"
))),
},
}
}
@@ -709,6 +606,10 @@ impl ExecServerClient {
&self,
params: InitializeParams,
) -> Result<InitializeResponse, ExecServerError> {
if let ClientBackend::InProcess { handler } = &self.inner.backend {
return server_result_to_client(handler.lock().await.initialize());
}
let request_id = self.next_request_id();
let (response_tx, response_rx) = oneshot::channel();
self.inner
@@ -716,31 +617,31 @@ impl ExecServerClient {
.lock()
.await
.insert(request_id.clone(), PendingRequest::Initialize(response_tx));
let send_result = match &self.inner.backend {
ClientBackend::JsonRpc { write_tx } => {
send_jsonrpc_request(write_tx, request_id.clone(), INITIALIZE_METHOD, &params).await
}
ClientBackend::InProcess { write_tx } => {
send_in_process_request(
write_tx,
ExecServerInboundMessage::Request(ExecServerRequest::Initialize {
request_id: request_id.clone(),
params,
}),
)
.await
}
let ClientBackend::JsonRpc { write_tx } = &self.inner.backend else {
unreachable!("in-process initialize requests return before JSON-RPC setup");
};
let send_result =
send_jsonrpc_request(write_tx, request_id.clone(), INITIALIZE_METHOD, &params).await;
self.finish_request(request_id, send_result, response_rx)
.await
}
fn next_request_id(&self) -> RequestId {
RequestId::Integer(self.inner.next_request_id.fetch_add(1, Ordering::SeqCst))
}
async fn finish_request<T>(
&self,
request_id: RequestId,
send_result: Result<(), ExecServerError>,
response_rx: oneshot::Receiver<Result<T, JSONRPCErrorError>>,
) -> Result<T, ExecServerError> {
if let Err(err) = send_result {
self.inner.pending.lock().await.remove(&request_id);
return Err(err);
}
receive_typed_response(response_rx).await
}
fn next_request_id(&self) -> RequestId {
RequestId::Integer(self.inner.next_request_id.fetch_add(1, Ordering::SeqCst))
}
}
async fn receive_typed_response<T>(
@@ -756,6 +657,16 @@ async fn receive_typed_response<T>(
}
}
fn server_result_to_client<T>(result: Result<T, JSONRPCErrorError>) -> Result<T, ExecServerError> {
match result {
Ok(response) => Ok(response),
Err(error) => Err(ExecServerError::Server {
code: error.code,
message: error.message,
}),
}
}
async fn send_jsonrpc_request<P: Serialize>(
write_tx: &mpsc::Sender<JSONRPCMessage>,
request_id: RequestId,
@@ -774,33 +685,15 @@ async fn send_jsonrpc_request<P: Serialize>(
.map_err(|_| ExecServerError::Closed)
}
async fn send_in_process_request(
write_tx: &mpsc::Sender<ExecServerInboundMessage>,
message: ExecServerInboundMessage,
) -> Result<(), ExecServerError> {
write_tx
.send(message)
.await
.map_err(|_| ExecServerError::Closed)
}
async fn handle_in_process_outbound_message(
inner: &Arc<Inner>,
message: ExecServerOutboundMessage,
) -> Result<(), ExecServerError> {
match message {
ExecServerOutboundMessage::Response {
request_id,
response,
} => {
if let Some(pending) = inner.pending.lock().await.remove(&request_id) {
pending.resolve_typed(response)?;
}
}
ExecServerOutboundMessage::Error { request_id, error } => {
if let Some(pending) = inner.pending.lock().await.remove(&request_id) {
pending.resolve_error(error);
}
ExecServerOutboundMessage::Response { .. } | ExecServerOutboundMessage::Error { .. } => {
return Err(ExecServerError::Protocol(
"unexpected in-process RPC response".to_string(),
));
}
ExecServerOutboundMessage::Notification(notification) => {
handle_in_process_notification(inner, notification).await;