Files
codex/codex-rs/exec-server/src/rpc.rs
starr-openai 3ff901257a Flatten JSON-RPC connection state
Drop the separate JsonRpcConnectionRuntime wrapper so JsonRpcConnection directly owns the channels, disconnect watch, transport tasks, and transport guard. This keeps the lifetime model explicit without helper extraction methods.

Co-authored-by: Codex <noreply@openai.com>
2026-05-06 19:19:45 -07:00

736 lines
25 KiB
Rust

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCErrorError;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use crate::connection::JsonRpcConnection;
use crate::connection::JsonRpcConnectionEvent;
use crate::connection::JsonRpcTransport;
#[derive(Debug)]
pub(crate) enum RpcCallError {
/// The underlying JSON-RPC transport closed before this call completed.
Closed,
/// The response bytes were valid JSON-RPC but not the expected result type.
Json(serde_json::Error),
/// The executor returned a JSON-RPC error response for this call.
Server(JSONRPCErrorError),
}
type PendingRequest = oneshot::Sender<Result<Value, RpcCallError>>;
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
type RequestRoute<S> = Box<
dyn Fn(Arc<S>, JSONRPCRequest) -> BoxFuture<Option<RpcServerOutboundMessage>> + Send + Sync,
>;
type NotificationRoute<S> =
Box<dyn Fn(Arc<S>, JSONRPCNotification) -> BoxFuture<Result<(), String>> + Send + Sync>;
#[derive(Debug)]
pub(crate) enum RpcClientEvent {
Notification(JSONRPCNotification),
Disconnected { reason: Option<String> },
}
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum RpcServerOutboundMessage {
Response {
request_id: RequestId,
result: Value,
},
Error {
request_id: RequestId,
error: JSONRPCErrorError,
},
Notification(JSONRPCNotification),
}
#[derive(Clone)]
pub(crate) struct RpcNotificationSender {
outgoing_tx: mpsc::Sender<RpcServerOutboundMessage>,
}
impl RpcNotificationSender {
pub(crate) fn new(outgoing_tx: mpsc::Sender<RpcServerOutboundMessage>) -> Self {
Self { outgoing_tx }
}
pub(crate) async fn response(
&self,
request_id: RequestId,
result: Value,
) -> Result<(), JSONRPCErrorError> {
self.outgoing_tx
.send(RpcServerOutboundMessage::Response { request_id, result })
.await
.map_err(|_| internal_error("RPC connection closed while sending response".into()))
}
pub(crate) async fn notify<P: Serialize>(
&self,
method: &str,
params: &P,
) -> Result<(), JSONRPCErrorError> {
let params = serde_json::to_value(params).map_err(|err| internal_error(err.to_string()))?;
self.outgoing_tx
.send(RpcServerOutboundMessage::Notification(
JSONRPCNotification {
method: method.to_string(),
params: Some(params),
},
))
.await
.map_err(|_| internal_error("RPC connection closed while sending notification".into()))
}
}
pub(crate) struct RpcRouter<S> {
request_routes: HashMap<&'static str, RequestRoute<S>>,
notification_routes: HashMap<&'static str, NotificationRoute<S>>,
}
impl<S> Default for RpcRouter<S> {
fn default() -> Self {
Self {
request_routes: HashMap::new(),
notification_routes: HashMap::new(),
}
}
}
impl<S> RpcRouter<S>
where
S: Send + Sync + 'static,
{
pub(crate) fn new() -> Self {
Self::default()
}
pub(crate) fn request<P, R, F, Fut>(&mut self, method: &'static str, handler: F)
where
P: DeserializeOwned + Send + 'static,
R: Serialize + Send + 'static,
F: Fn(Arc<S>, P) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<R, JSONRPCErrorError>> + Send + 'static,
{
self.request_routes.insert(
method,
Box::new(move |state, request| {
let request_id = request.id;
let params = request.params;
let response =
decode_request_params::<P>(params).map(|params| handler(state, params));
Box::pin(async move {
let response = match response {
Ok(response) => response.await,
Err(error) => {
return Some(RpcServerOutboundMessage::Error { request_id, error });
}
};
Some(match response {
Ok(result) => match serde_json::to_value(result) {
Ok(result) => RpcServerOutboundMessage::Response { request_id, result },
Err(err) => RpcServerOutboundMessage::Error {
request_id,
error: internal_error(err.to_string()),
},
},
Err(error) => RpcServerOutboundMessage::Error { request_id, error },
})
})
}),
);
}
pub(crate) fn request_with_id<P, F, Fut>(&mut self, method: &'static str, handler: F)
where
P: DeserializeOwned + Send + 'static,
F: Fn(Arc<S>, RequestId, P) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(), JSONRPCErrorError>> + Send + 'static,
{
self.request_routes.insert(
method,
Box::new(move |state, request| {
let request_id = request.id;
let params = decode_request_params::<P>(request.params)
.map(|params| handler(state, request_id.clone(), params));
Box::pin(async move {
let response = match params {
Ok(response) => response.await,
Err(error) => {
return Some(RpcServerOutboundMessage::Error { request_id, error });
}
};
match response {
Ok(()) => None,
Err(error) => Some(RpcServerOutboundMessage::Error { request_id, error }),
}
})
}),
);
}
pub(crate) fn notification<P, F, Fut>(&mut self, method: &'static str, handler: F)
where
P: DeserializeOwned + Send + 'static,
F: Fn(Arc<S>, P) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(), String>> + Send + 'static,
{
self.notification_routes.insert(
method,
Box::new(move |state, notification| {
let params = decode_notification_params::<P>(notification.params)
.map(|params| handler(state, params));
Box::pin(async move {
let handler = match params {
Ok(handler) => handler,
Err(err) => return Err(err),
};
handler.await
})
}),
);
}
pub(crate) fn request_route(&self, method: &str) -> Option<&RequestRoute<S>> {
self.request_routes.get(method)
}
pub(crate) fn notification_route(&self, method: &str) -> Option<&NotificationRoute<S>> {
self.notification_routes.get(method)
}
}
pub(crate) struct RpcClient {
write_tx: mpsc::Sender<JSONRPCMessage>,
pending: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
// This flips before the ordered RPC reader drains pending requests, so new
// calls fail instead of registering work that can never complete.
closed: Arc<AtomicBool>,
next_request_id: AtomicI64,
transport_tasks: Vec<JoinHandle<()>>,
_transport: JsonRpcTransport,
reader_task: JoinHandle<()>,
}
impl RpcClient {
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
let JsonRpcConnection {
outgoing_tx: write_tx,
incoming_rx: mut incoming_rx,
disconnected_rx: _,
task_handles: transport_tasks,
transport,
} = connection;
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
let (event_tx, event_rx) = mpsc::channel(128);
let closed = Arc::new(AtomicBool::new(false));
let pending_for_reader = Arc::clone(&pending);
let closed_for_reader = Arc::clone(&closed);
let reader_task = tokio::spawn(async move {
let reason = loop {
let Some(event) = incoming_rx.recv().await else {
break None;
};
match event {
JsonRpcConnectionEvent::Message(message) => {
if let Err(err) =
handle_server_message(&pending_for_reader, &event_tx, message).await
{
break Some(err);
}
}
JsonRpcConnectionEvent::MalformedMessage { reason } => {
break Some(reason);
}
JsonRpcConnectionEvent::Disconnected { reason } => break reason,
}
};
closed_for_reader.store(true, Ordering::SeqCst);
let _ = event_tx.send(RpcClientEvent::Disconnected { reason }).await;
drain_pending(&pending_for_reader).await;
});
(
Self {
write_tx,
pending,
closed,
next_request_id: AtomicI64::new(1),
transport_tasks,
_transport: transport,
reader_task,
},
event_rx,
)
}
pub(crate) async fn notify<P: Serialize>(
&self,
method: &str,
params: &P,
) -> Result<(), serde_json::Error> {
let params = serde_json::to_value(params)?;
if self.closed.load(Ordering::SeqCst) {
return Err(serde_json::Error::io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"JSON-RPC transport closed",
)));
}
self.write_tx
.send(JSONRPCMessage::Notification(JSONRPCNotification {
method: method.to_string(),
params: Some(params),
}))
.await
.map_err(|_| {
serde_json::Error::io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"JSON-RPC transport closed",
))
})
}
pub(crate) async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, RpcCallError>
where
P: Serialize,
T: DeserializeOwned,
{
let request_id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::SeqCst));
let (response_tx, response_rx) = oneshot::channel();
{
let mut pending = self.pending.lock().await;
// Registering the pending request and checking disconnect must be
// atomic with the reader's drain_pending path. Otherwise a call
// can sneak in after the drain and wait forever.
if self.closed.load(Ordering::SeqCst) {
return Err(RpcCallError::Closed);
}
pending.insert(request_id.clone(), response_tx);
}
let params = match serde_json::to_value(params) {
Ok(params) => params,
Err(err) => {
self.pending.lock().await.remove(&request_id);
return Err(RpcCallError::Json(err));
}
};
if self
.write_tx
.send(JSONRPCMessage::Request(JSONRPCRequest {
id: request_id.clone(),
method: method.to_string(),
params: Some(params),
trace: None,
}))
.await
.is_err()
{
self.pending.lock().await.remove(&request_id);
return Err(RpcCallError::Closed);
}
// Do not race in-flight requests directly against the transport-close
// watch value. The connection reader receives JSON-RPC messages and
// the terminal disconnect event on one ordered queue, then drains any
// still-pending requests. Awaiting this receiver preserves that order:
// responses already read before EOF still win, and truly pending calls
// are failed once the reader observes the disconnect.
let result: Result<Value, RpcCallError> =
response_rx.await.map_err(|_| RpcCallError::Closed)?;
let response = match result {
Ok(response) => response,
Err(error) => return Err(error),
};
serde_json::from_value(response).map_err(RpcCallError::Json)
}
#[cfg(test)]
pub(crate) async fn pending_request_count(&self) -> usize {
self.pending.lock().await.len()
}
}
impl Drop for RpcClient {
fn drop(&mut self) {
for task in &self.transport_tasks {
task.abort();
}
self.reader_task.abort();
}
}
pub(crate) fn encode_server_message(
message: RpcServerOutboundMessage,
) -> Result<JSONRPCMessage, serde_json::Error> {
match message {
RpcServerOutboundMessage::Response { request_id, result } => {
Ok(JSONRPCMessage::Response(JSONRPCResponse {
id: request_id,
result,
}))
}
RpcServerOutboundMessage::Error { request_id, error } => {
Ok(JSONRPCMessage::Error(JSONRPCError {
id: request_id,
error,
}))
}
RpcServerOutboundMessage::Notification(notification) => {
Ok(JSONRPCMessage::Notification(notification))
}
}
}
pub(crate) fn invalid_request(message: String) -> JSONRPCErrorError {
JSONRPCErrorError {
code: -32600,
data: None,
message,
}
}
pub(crate) fn method_not_found(message: String) -> JSONRPCErrorError {
JSONRPCErrorError {
code: -32601,
data: None,
message,
}
}
pub(crate) fn invalid_params(message: String) -> JSONRPCErrorError {
JSONRPCErrorError {
code: -32602,
data: None,
message,
}
}
pub(crate) fn not_found(message: String) -> JSONRPCErrorError {
JSONRPCErrorError {
code: -32004,
data: None,
message,
}
}
pub(crate) fn internal_error(message: String) -> JSONRPCErrorError {
JSONRPCErrorError {
code: -32603,
data: None,
message,
}
}
fn decode_request_params<P>(params: Option<Value>) -> Result<P, JSONRPCErrorError>
where
P: DeserializeOwned,
{
decode_params(params).map_err(|err| invalid_params(err.to_string()))
}
fn decode_notification_params<P>(params: Option<Value>) -> Result<P, String>
where
P: DeserializeOwned,
{
decode_params(params).map_err(|err| err.to_string())
}
fn decode_params<P>(params: Option<Value>) -> Result<P, serde_json::Error>
where
P: DeserializeOwned,
{
let params = params.unwrap_or(Value::Null);
match serde_json::from_value(params.clone()) {
Ok(params) => Ok(params),
Err(err) => {
if matches!(params, Value::Object(ref map) if map.is_empty()) {
serde_json::from_value(Value::Null).map_err(|_| err)
} else {
Err(err)
}
}
}
}
async fn handle_server_message(
pending: &Mutex<HashMap<RequestId, PendingRequest>>,
event_tx: &mpsc::Sender<RpcClientEvent>,
message: JSONRPCMessage,
) -> Result<(), String> {
match message {
JSONRPCMessage::Response(JSONRPCResponse { id, result }) => {
if let Some(pending) = pending.lock().await.remove(&id) {
let _ = pending.send(Ok(result));
}
}
JSONRPCMessage::Error(JSONRPCError { id, error }) => {
if let Some(pending) = pending.lock().await.remove(&id) {
let _ = pending.send(Err(RpcCallError::Server(error)));
}
}
JSONRPCMessage::Notification(notification) => {
let _ = event_tx
.send(RpcClientEvent::Notification(notification))
.await;
}
JSONRPCMessage::Request(request) => {
return Err(format!(
"unexpected JSON-RPC request from remote server: {}",
request.method
));
}
}
Ok(())
}
async fn drain_pending(pending: &Mutex<HashMap<RequestId, PendingRequest>>) {
let pending = {
let mut pending = pending.lock().await;
pending
.drain()
.map(|(_, pending)| pending)
.collect::<Vec<_>>()
};
for pending in pending {
let _ = pending.send(Err(RpcCallError::Closed));
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use pretty_assertions::assert_eq;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::time::timeout;
use super::RpcClient;
use super::RpcClientEvent;
use crate::connection::JsonRpcConnection;
async fn read_jsonrpc_line<R>(lines: &mut tokio::io::Lines<BufReader<R>>) -> JSONRPCMessage
where
R: tokio::io::AsyncRead + Unpin,
{
let next_line = timeout(Duration::from_secs(1), lines.next_line()).await;
let line_result = match next_line {
Ok(line_result) => line_result,
Err(err) => panic!("timed out waiting for JSON-RPC line: {err}"),
};
let maybe_line = match line_result {
Ok(maybe_line) => maybe_line,
Err(err) => panic!("failed to read JSON-RPC line: {err}"),
};
let line = match maybe_line {
Some(line) => line,
None => panic!("server connection closed before JSON-RPC line arrived"),
};
match serde_json::from_str::<JSONRPCMessage>(&line) {
Ok(message) => message,
Err(err) => panic!("failed to parse JSON-RPC line: {err}"),
}
}
async fn write_jsonrpc_line<W>(writer: &mut W, message: JSONRPCMessage)
where
W: tokio::io::AsyncWrite + Unpin,
{
let encoded = match serde_json::to_string(&message) {
Ok(encoded) => encoded,
Err(err) => panic!("failed to encode JSON-RPC message: {err}"),
};
if let Err(err) = writer.write_all(format!("{encoded}\n").as_bytes()).await {
panic!("failed to write JSON-RPC line: {err}");
}
}
#[tokio::test]
async fn rpc_client_matches_out_of_order_responses_by_request_id() {
let (client_stdin, server_reader) = tokio::io::duplex(4096);
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
let connection =
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
let (client, _events_rx) = RpcClient::new(connection);
let server = tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();
let first = read_jsonrpc_line(&mut lines).await;
let second = read_jsonrpc_line(&mut lines).await;
let (slow_request, fast_request) = match (first, second) {
(
JSONRPCMessage::Request(first_request),
JSONRPCMessage::Request(second_request),
) if first_request.method == "slow" && second_request.method == "fast" => {
(first_request, second_request)
}
(
JSONRPCMessage::Request(first_request),
JSONRPCMessage::Request(second_request),
) if first_request.method == "fast" && second_request.method == "slow" => {
(second_request, first_request)
}
_ => panic!("expected slow and fast requests"),
};
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Response(JSONRPCResponse {
id: fast_request.id,
result: serde_json::json!({ "value": "fast" }),
}),
)
.await;
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Response(JSONRPCResponse {
id: slow_request.id,
result: serde_json::json!({ "value": "slow" }),
}),
)
.await;
});
let slow_params = serde_json::json!({ "n": 1 });
let fast_params = serde_json::json!({ "n": 2 });
let (slow, fast) = tokio::join!(
client.call::<_, serde_json::Value>("slow", &slow_params),
client.call::<_, serde_json::Value>("fast", &fast_params),
);
let slow = slow.unwrap_or_else(|err| panic!("slow request failed: {err:?}"));
let fast = fast.unwrap_or_else(|err| panic!("fast request failed: {err:?}"));
assert_eq!(slow, serde_json::json!({ "value": "slow" }));
assert_eq!(fast, serde_json::json!({ "value": "fast" }));
assert_eq!(client.pending_request_count().await, 0);
if let Err(err) = server.await {
panic!("server task failed: {err}");
}
}
#[tokio::test]
async fn rpc_client_rejects_new_calls_after_reader_protocol_error() {
let (client_stdin, _server_reader) = tokio::io::duplex(4096);
let (mut server_writer, client_stdout) = tokio::io::duplex(4096);
let connection =
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
let (client, mut events_rx) = RpcClient::new(connection);
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(1),
method: "server/request".to_string(),
params: None,
trace: None,
}),
)
.await;
let event = timeout(Duration::from_secs(1), events_rx.recv())
.await
.expect("timed out waiting for disconnect event");
match event {
Some(RpcClientEvent::Disconnected { reason }) => {
assert!(
reason
.as_deref()
.is_some_and(|reason| reason.contains("unexpected JSON-RPC request")),
"unexpected disconnect reason: {reason:?}"
);
}
event => panic!("expected disconnect event, got {event:?}"),
}
let result = timeout(
Duration::from_secs(1),
client.call::<_, serde_json::Value>("after-close", &serde_json::json!({})),
)
.await
.expect("timed out waiting for closed call");
assert!(matches!(result, Err(super::RpcCallError::Closed)));
assert_eq!(client.pending_request_count().await, 0);
}
#[tokio::test]
async fn rpc_client_drains_pending_call_on_transport_eof() {
let (client_stdin, server_reader) = tokio::io::duplex(4096);
let (server_writer, client_stdout) = tokio::io::duplex(4096);
let connection =
JsonRpcConnection::from_stdio(client_stdout, client_stdin, "test-rpc".to_string());
let (client, mut events_rx) = RpcClient::new(connection);
let server = tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();
let request = read_jsonrpc_line(&mut lines).await;
match request {
JSONRPCMessage::Request(request) if request.method == "will-close" => {}
other => panic!("expected will-close request, got {other:?}"),
}
drop(server_writer);
});
let result = timeout(
Duration::from_secs(1),
client.call::<_, serde_json::Value>("will-close", &serde_json::json!({})),
)
.await
.expect("timed out waiting for closed call");
assert!(matches!(result, Err(super::RpcCallError::Closed)));
let event = timeout(Duration::from_secs(1), events_rx.recv())
.await
.expect("timed out waiting for disconnect event");
assert!(matches!(
event,
Some(RpcClientEvent::Disconnected { reason: None })
));
assert_eq!(client.pending_request_count().await, 0);
let result = timeout(
Duration::from_secs(1),
client.call::<_, serde_json::Value>("after-close", &serde_json::json!({})),
)
.await
.expect("timed out waiting for fast closed call");
assert!(matches!(result, Err(super::RpcCallError::Closed)));
let notify = client.notify("after-close", &serde_json::json!({})).await;
assert!(notify.is_err());
server.await.expect("server task should finish");
}
}