Compare commits

...

12 Commits

Author | SHA1 | Message | Date
starr-openai | ee0db3246a | Fix reconnect test lint | 2026-05-18 23:50:45 -07:00
starr-openai | edd0d6a66e | Harden exec-server reconnect recovery | 2026-05-18 23:33:43 -07:00
starr-openai | d82bb0d81e | Document exec-server reconnect hierarchy | 2026-05-18 21:17:01 -07:00
starr-openai | b25c4b53c9 | Document and test exec-server reconnect invariants | 2026-05-18 20:53:13 -07:00
starr-openai | ef8267fb69 | Add exec-server websocket reconnect foundation | 2026-05-18 20:05:53 -07:00
starr-openai | 90804bb2eb | Remove manual websocket pong handling | 2026-05-18 16:14:53 -07:00
starr-openai | 6f9640533a | Restore server-owned websocket keepalive | 2026-05-18 16:00:04 -07:00
starr-openai | 6d6cdeb128 | Fix websocket backpressure regression tests | 2026-05-18 15:36:41 -07:00
starr-openai | 9215e15ee3 | Add websocket backpressure regression tests | 2026-05-18 15:24:27 -07:00
starr-openai | d94300d782 | Preserve exec-server websocket keepalive ownership | 2026-05-18 10:42:43 -07:00
starr-openai | 3673b69a2a | Add exec-server websocket pump tests | 2026-05-18 10:31:48 -07:00
starr-openai | 8a9300e92a | Refactor exec-server websocket pump | 2026-05-18 10:04:34 -07:00
15 changed files with 2516 additions and 538 deletions


@@ -7,13 +7,112 @@ JSON-RPC server for spawning and controlling subprocesses through
It provides:
- a CLI entrypoint: `codex exec-server`
- a Rust client: `ExecServerClient`
- a Rust connection client: `ExecServerConnection`
- a small protocol module with shared request/response types
This crate owns the transport, protocol, and filesystem/process handlers. The
top-level `codex` binary owns hidden helper dispatch for sandboxed
filesystem operations and `codex-linux-sandbox`.
## Client And Session Ownership
Remote environments expose one logical exec-server client to the rest of Codex.
That client is not the same thing as one websocket connection.
```text
Environment
`- one RemoteExecServerClient
|- RemoteExecServerSession
| |- logical_session_id
| |- current ExecServerConnection?
| |- one in-flight reconnect attempt
| |- terminal reconnect error?
| `- tracked process_sessions: HashMap<ProcessId, Weak<ProcessSession>>
|- RemoteProcess -> RemoteExecProcess -> ProcessSessionHandle
|- RemoteFileSystem
`- HttpClient for RemoteExecServerClient
ExecServerConnection
`- Inner
|- RpcClient
|- reader task
|- disconnect latch
|- connection-local process_session_routes
|- connection-local HTTP body stream routes
`- initialized session_id for this live binding
ProcessSessionHandle
|- process_id
|- Arc<ProcessSession>
`- ProcessSessionControl
ProcessSession
|- wake channel
|- event log
|- ordered event buffer
|- failure state
`- disconnect policy
ProcessSessionControl
|- Connection(ExecServerConnection) for direct/test one-shot sessions
`- RemoteClient(RemoteExecServerClient) for reconnecting remote environments
```
The main roles are:
- `RemoteExecServerClient`: environment-owned logical client. `Environment`
clones this into the remote process backend, remote filesystem, and remote
HTTP capability so all remote APIs share one reconnecting session.
- `RemoteExecServerSession`: durable logical-session state behind the client.
It remembers the resumable session id, current live connection, one shared
reconnect attempt, weak references to process sessions that may need rebinding,
and any terminal resume error.
- `ExecServerConnection`: one live JSON-RPC transport binding. It owns
connection-local routing for notifications and streamed HTTP response bodies.
- `Inner`: private per-connection machinery behind `ExecServerConnection`.
It owns the `RpcClient`, reader task, disconnect latch, connection-local
process notification routes, connection-local HTTP body stream routes, and
initialized session id for that live binding.
- `ProcessSession`: durable per-process client state owned by the live process
handle. It keeps the local event log, wake cursor, and failure state that must
survive connection replacement while that handle still exists.
- `ProcessSessionHandle`: process-facing handle used by `RemoteExecProcess`.
It routes reads, writes, terminate, and unregister through either a focused
direct connection test path or the logical reconnecting client path.
- `ProcessSessionControl`: small command-path enum for a
`ProcessSessionHandle`. It is not an owner; it only chooses direct
`ExecServerConnection` versus reconnecting `RemoteExecServerClient`.
- `RemoteProcess`, `RemoteFileSystem`, and `HttpClient`: thin capability
adapters. They should not own reconnect state themselves.
Reconnect invariants:
- There is one shared reconnect attempt per `RemoteExecServerClient`, not one
reconnect loop per API surface.
- Reconnect resumes the same logical session id and rebinds tracked
`ProcessSession` routes onto the replacement `ExecServerConnection`.
- When a reconnecting process session loses its transport, it emits a local
`ResyncRequired` event and wake so callers blocked on pushed events or wake
notifications know to recover through `process/read(afterSeq)`.
- `process/read` may retry once after a transport-close race because its
`afterSeq` cursor makes the replay read-only and recoverable.
- `process/start`, `process/write`, `process/terminate`, filesystem RPCs, and
`http/request` are not replayed after an ambiguous mid-request disconnect.
They reconnect before later calls, but an in-flight call that may already
have reached the server returns an error instead of risking duplicate side
effects.
- Streamed HTTP bodies are connection-local. A reconnect can start a later
HTTP request, but it cannot resume body-delta delivery for an already-open
stream.
- Rendezvous uses the same logical split. The relay websocket and relay
`stream_id` are transport beneath the exec-server logical session for the
first reconnect slice. If a rendezvous websocket dies, the harness/client
may establish a fresh relay stream, then re-run exec-server initialize with
the prior `session_id` as `resume_session_id`. Existing process state recovers
through `process/read(after_seq)`, not relay-frame replay. Full same-stream
relay resume/replay remains a later protocol slice that requires endpoint-held
seq/ack/replay state.
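The retry invariants above can be read as a per-method policy. The following is a minimal sketch of that idea, not the crate's implementation: `RetryPolicy`, `retry_policy`, `CallError`, and `call_with_policy` are hypothetical names introduced only for illustration, and the reconnect step is reduced to a comment.

```rust
/// Hypothetical classification of exec-server RPC methods by whether a call
/// may be re-sent after an ambiguous mid-request disconnect.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RetryPolicy {
    /// May be re-sent once: the `afterSeq` cursor keeps the call read-only.
    RetryOnceAfterReconnect,
    /// Never re-sent: the server may already have applied the side effect.
    FailInFlight,
}

/// Assumed mapping based on the invariants listed above. Anything that is not
/// `process/read` is treated conservatively as side-effecting.
fn retry_policy(method: &str) -> RetryPolicy {
    match method {
        "process/read" => RetryPolicy::RetryOnceAfterReconnect,
        _ => RetryPolicy::FailInFlight,
    }
}

/// Hypothetical transport error used only by this sketch.
#[derive(Debug, PartialEq, Eq)]
enum CallError {
    TransportClosed,
}

/// Runs `send`, re-sending at most once and only when the method's policy
/// allows it. All other failures are returned to the caller unchanged.
fn call_with_policy<T>(
    method: &str,
    mut send: impl FnMut() -> Result<T, CallError>,
) -> Result<T, CallError> {
    match send() {
        Err(CallError::TransportClosed)
            if retry_policy(method) == RetryPolicy::RetryOnceAfterReconnect =>
        {
            // A real client would wait for the shared reconnect attempt here
            // before issuing the second call.
            send()
        }
        result => result,
    }
}
```

In this sketch a `process/write` that hits a closed transport surfaces the error after one attempt, while `process/read` gets exactly one more try. Later calls on either method would still go through the reconnected session.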
## Transport
The server speaks the shared `codex-app-server-protocol` message envelope on

File diff suppressed because it is too large


@@ -3,7 +3,7 @@
//! This module is the facade for the environment-owned [`crate::HttpClient`]
//! capability:
//! - [`ReqwestHttpClient`] executes requests directly with `reqwest`
-//! - [`ExecServerClient`] forwards requests over the JSON-RPC transport
+//! - [`ExecServerConnection`] forwards requests over the JSON-RPC transport
//! - [`HttpResponseBodyStream`] presents buffered local bodies and streamed
//! remote `http/request/bodyDelta` notifications through one byte-stream API
//!
@@ -11,7 +11,7 @@
//! - orchestrator process: holds an `Arc<dyn HttpClient>` and chooses local or
//! remote execution
//! - remote runtime: serves the `http/request` RPC and runs the concrete local
-//! HTTP request there when the orchestrator uses [`ExecServerClient`]
+//! HTTP request there when the orchestrator uses [`ExecServerConnection`]
#[path = "reqwest_http_client.rs"]
mod reqwest_http_client;

File diff suppressed because it is too large


@@ -14,7 +14,7 @@ use tokio::sync::mpsc;
use super::HttpResponseBodyStream;
use super::response_body_stream::HttpBodyStreamRegistration;
use crate::HttpClient;
-use crate::client::ExecServerClient;
+use crate::client::ExecServerConnection;
use crate::client::ExecServerError;
use crate::protocol::HTTP_REQUEST_METHOD;
use crate::protocol::HttpRequestParams;
@@ -23,7 +23,7 @@ use crate::protocol::HttpRequestResponse;
/// Maximum queued body frames per streamed HTTP response.
const HTTP_BODY_DELTA_CHANNEL_CAPACITY: usize = 256;
-impl ExecServerClient {
+impl ExecServerConnection {
/// Performs an HTTP request and buffers the response body.
pub async fn http_request(
&self,
@@ -67,14 +67,14 @@ impl ExecServerClient {
}
}
-impl HttpClient for ExecServerClient {
+impl HttpClient for ExecServerConnection {
/// Orchestrator-side adapter that forwards buffered HTTP requests to the
/// remote runtime over the shared JSON-RPC connection.
fn http_request(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
-async move { ExecServerClient::http_request(self, params).await }.boxed()
+async move { ExecServerConnection::http_request(self, params).await }.boxed()
}
/// Orchestrator-side adapter that forwards streamed HTTP requests to the
@@ -83,6 +83,6 @@ impl HttpClient for ExecServerClient {
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
-async move { ExecServerClient::http_request_stream(self, params).await }.boxed()
+async move { ExecServerConnection::http_request_stream(self, params).await }.boxed()
}
}


@@ -9,7 +9,7 @@ use tracing::warn;
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
-use crate::ExecServerClient;
+use crate::ExecServerConnection;
use crate::ExecServerError;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::client_api::StdioExecServerCommand;
@@ -17,9 +17,9 @@ use crate::client_api::StdioExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
use crate::relay::harness_connection_from_websocket;
-const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
+pub(crate) const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
-impl ExecServerClient {
+impl ExecServerConnection {
pub(crate) async fn connect_for_transport(
transport_params: crate::client_api::ExecServerTransportParams,
) -> Result<Self, ExecServerError> {


@@ -323,37 +323,20 @@ impl JsonRpcConnection {
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
-let (websocket_writer, websocket_reader) = stream.split();
-Self::from_websocket_parts(
-websocket_writer,
-websocket_reader,
-connection_label,
-Some(WEBSOCKET_KEEPALIVE_INTERVAL),
-)
+Self::from_websocket_stream(stream, connection_label, /*ping_interval*/ None)
}
pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self {
-let (websocket_writer, websocket_reader) = stream.split();
-Self::from_websocket_parts(
-websocket_writer,
-websocket_reader,
-connection_label,
-// Axum only wraps inbound exec-server websocket accepts. Outbound websocket clients
-// own keepalive pings so one side does not accidentally create redundant traffic.
-/*keepalive_interval*/
-None,
-)
+Self::from_websocket_stream(stream, connection_label, Some(WEBSOCKET_KEEPALIVE_INTERVAL))
}
-fn from_websocket_parts<W, R, M, E>(
-mut websocket_writer: W,
-mut websocket_reader: R,
+fn from_websocket_stream<T, M, E>(
+mut websocket: T,
connection_label: String,
-keepalive_interval: Option<Duration>,
+ping_interval: Option<Duration>,
) -> Self
where
-W: Sink<M, Error = E> + Unpin + Send + 'static,
-R: Stream<Item = Result<M, E>> + Unpin + Send + 'static,
+T: Sink<M, Error = E> + Stream<Item = Result<M, E>> + Unpin + Send + 'static,
M: JsonRpcWebSocketMessage,
E: std::fmt::Display + Send + 'static,
{
@@ -361,118 +344,106 @@ impl JsonRpcConnection {
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (disconnected_tx, disconnected_rx) = watch::channel(false);
let reader_label = connection_label.clone();
let incoming_tx_for_reader = incoming_tx.clone();
let disconnected_tx_for_reader = disconnected_tx.clone();
let reader_task = tokio::spawn(async move {
let websocket_task = tokio::spawn(async move {
let mut ping_interval = ping_interval.map(|ping_interval| {
let mut interval = tokio::time::interval_at(
tokio::time::Instant::now() + ping_interval,
ping_interval,
);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
interval
});
loop {
match websocket_reader.next().await {
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
Ok(JsonRpcWebSocketFrame::Message(message)) => {
if incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
tokio::select! {
maybe_message = outgoing_rx.recv() => {
let Some(message) = maybe_message else {
break;
};
if let Err(reason) = send_websocket_jsonrpc_message(
&mut websocket,
&connection_label,
&message,
)
.await
{
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
break;
}
Err(err) => {
send_malformed_message(
&incoming_tx_for_reader,
Some(format!(
"failed to parse websocket JSON-RPC message from {reader_label}: {err}"
)),
)
.await;
}
_ = async {
match ping_interval.as_mut() {
Some(interval) => interval.tick().await,
None => std::future::pending().await,
}
Ok(JsonRpcWebSocketFrame::Close) => {
} => {
if let Err(err) = websocket.send(M::ping()).await {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
/*reason*/ None,
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write websocket ping to {connection_label}: {err}"
)),
)
.await;
break;
}
Ok(JsonRpcWebSocketFrame::Ignore) => {}
},
Some(Err(err)) => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
Some(format!(
"failed to read websocket JSON-RPC message from {reader_label}: {err}"
)),
)
.await;
break;
}
None => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
/*reason*/ None,
)
.await;
break;
}
}
}
});
let writer_task = tokio::spawn(async move {
if let Some(keepalive_interval) = keepalive_interval {
let mut keepalive = tokio::time::interval_at(
tokio::time::Instant::now() + keepalive_interval,
keepalive_interval,
);
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
maybe_message = outgoing_rx.recv() => {
let Some(message) = maybe_message else {
break;
};
if let Err(reason) = send_websocket_jsonrpc_message(
&mut websocket_writer,
&connection_label,
&message,
)
.await
{
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
break;
}
}
_ = keepalive.tick() => {
if let Err(err) = websocket_writer.send(M::ping()).await {
incoming_message = websocket.next() => {
match incoming_message {
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
Ok(JsonRpcWebSocketFrame::Message(message)) => {
if incoming_tx
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
}
Ok(JsonRpcWebSocketFrame::Close) => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
/*reason*/ None,
)
.await;
break;
}
Ok(JsonRpcWebSocketFrame::Ignore) => {}
Err(err) => {
send_malformed_message(
&incoming_tx,
Some(format!(
"failed to parse websocket JSON-RPC message from {connection_label}: {err}"
)),
)
.await;
}
},
Some(Err(err)) => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write websocket ping to {connection_label}: {err}"
"failed to read websocket JSON-RPC message from {connection_label}: {err}"
)),
)
.await;
break;
}
None => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
/*reason*/ None,
)
.await;
break;
}
}
}
}
} else {
while let Some(message) = outgoing_rx.recv().await {
if let Err(reason) = send_websocket_jsonrpc_message(
&mut websocket_writer,
&connection_label,
&message,
)
.await
{
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
break;
}
}
}
});
@@ -480,7 +451,7 @@ impl JsonRpcConnection {
outgoing_tx,
incoming_rx,
disconnected_rx,
-task_handles: vec![reader_task, writer_task],
+task_handles: vec![websocket_task],
transport: JsonRpcTransport::Plain,
}
}
@@ -619,34 +590,250 @@ fn serialize_jsonrpc_message(message: &JSONRPCMessage) -> Result<String, serde_j
#[cfg(test)]
mod tests {
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::RequestId;
use futures::channel::mpsc as futures_mpsc;
use futures::stream;
use futures::task::Context;
use futures::task::Poll;
use futures::task::AtomicWaker;
use tokio::net::TcpListener;
use tokio::time::timeout;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::connect_async;
use super::*;
struct TestWebSocketSink {
message_tx: futures_mpsc::UnboundedSender<Message>,
#[tokio::test]
async fn websocket_connection_sends_configured_ping() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let connection = JsonRpcConnection::from_websocket_stream(
client_websocket,
"test".into(),
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
);
let message = timeout(Duration::from_secs(1), server_websocket.next())
.await?
.expect("websocket should stay open")?;
assert!(matches!(message, Message::Ping(_)));
drop(connection);
Ok(())
}
impl Sink<Message> for TestWebSocketSink {
#[tokio::test]
async fn websocket_connection_ignores_server_pong() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
server_websocket
.send(Message::Pong(b"check".to_vec().into()))
.await?;
assert!(
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
.await
.is_err()
);
drop(connection);
Ok(())
}
#[tokio::test]
async fn websocket_connection_reports_server_close() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
server_websocket.close(None).await?;
assert!(matches!(
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
Some(JsonRpcConnectionEvent::Disconnected { reason: None })
));
drop(connection);
Ok(())
}
#[tokio::test]
async fn websocket_connection_accepts_binary_jsonrpc_message() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
let message = JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(1),
method: "test".to_string(),
params: None,
trace: None,
});
server_websocket
.send(Message::Binary(serde_json::to_vec(&message)?.into()))
.await?;
assert!(matches!(
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
Some(JsonRpcConnectionEvent::Message(actual)) if actual == message
));
drop(connection);
Ok(())
}
#[tokio::test]
async fn websocket_connection_keeps_outbound_message_while_send_is_backpressured()
-> anyhow::Result<()> {
let (websocket, control, mut outbound_rx) =
ControlledWebSocket::new(/*write_ready*/ false);
let mut connection = JsonRpcConnection::from_websocket_stream(
websocket,
"test".into(),
/*ping_interval*/ None,
);
let message = test_jsonrpc_message();
connection.outgoing_tx.send(message.clone()).await?;
control.wait_for_blocked_write().await?;
control.send_inbound(Message::Pong(b"check".to_vec().into()))?;
assert!(
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
.await
.is_err()
);
control.set_write_ready();
assert!(matches!(
timeout(Duration::from_secs(1), outbound_rx.next()).await?,
Some(Message::Text(text)) if serde_json::from_str::<JSONRPCMessage>(&text)? == message
));
drop(connection);
Ok(())
}
async fn websocket_pair() -> anyhow::Result<(
WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
WebSocketStream<tokio::net::TcpStream>,
)> {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let websocket_url = format!("ws://{}", listener.local_addr()?);
let server_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await?;
accept_async(stream).await.map_err(anyhow::Error::from)
});
let (client_websocket, _) = connect_async(websocket_url).await?;
let server_websocket = server_task.await??;
Ok((client_websocket, server_websocket))
}
fn test_jsonrpc_message() -> JSONRPCMessage {
JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(1),
method: "test".to_string(),
params: None,
trace: None,
})
}
struct ControlledWebSocket {
inbound_rx: futures_mpsc::UnboundedReceiver<Result<Message, std::convert::Infallible>>,
outbound_tx: futures_mpsc::UnboundedSender<Message>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
struct ControlledWebSocketHandle {
inbound_tx: futures_mpsc::UnboundedSender<Result<Message, std::convert::Infallible>>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
impl ControlledWebSocket {
fn new(
write_ready: bool,
) -> (
Self,
ControlledWebSocketHandle,
futures_mpsc::UnboundedReceiver<Message>,
) {
let (inbound_tx, inbound_rx) = futures_mpsc::unbounded();
let (outbound_tx, outbound_rx) = futures_mpsc::unbounded();
let write_ready = Arc::new(AtomicBool::new(write_ready));
let write_blocked = Arc::new(AtomicBool::new(false));
let write_blocked_waker = Arc::new(AtomicWaker::new());
let write_waker = Arc::new(AtomicWaker::new());
(
Self {
inbound_rx,
outbound_tx,
write_ready: Arc::clone(&write_ready),
write_blocked: Arc::clone(&write_blocked),
write_blocked_waker: Arc::clone(&write_blocked_waker),
write_waker: Arc::clone(&write_waker),
},
ControlledWebSocketHandle {
inbound_tx,
write_ready,
write_blocked,
write_blocked_waker,
write_waker,
},
outbound_rx,
)
}
}
impl ControlledWebSocketHandle {
fn send_inbound(&self, message: Message) -> anyhow::Result<()> {
self.inbound_tx
.unbounded_send(Ok(message))
.map_err(anyhow::Error::from)
}
fn set_write_ready(&self) {
self.write_ready.store(true, Ordering::Release);
self.write_waker.wake();
}
async fn wait_for_blocked_write(&self) -> anyhow::Result<()> {
timeout(
Duration::from_secs(1),
futures::future::poll_fn(|cx| {
if self.write_blocked.load(Ordering::Acquire) {
Poll::Ready(())
} else {
self.write_blocked_waker.register(cx.waker());
Poll::Pending
}
}),
)
.await?;
Ok(())
}
}
impl Sink<Message> for ControlledWebSocket {
type Error = std::convert::Infallible;
fn poll_ready(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.write_ready.load(Ordering::Acquire) {
Poll::Ready(Ok(()))
} else {
self.write_blocked.store(true, Ordering::Release);
self.write_blocked_waker.wake();
self.write_waker.register(cx.waker());
Poll::Pending
}
}
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
self.get_mut()
.message_tx
self.outbound_tx
.unbounded_send(item)
.expect("test websocket receiver should stay open");
.expect("test outbound receiver should stay open");
Ok(())
}
@@ -665,24 +852,11 @@ mod tests {
}
}
#[tokio::test]
async fn websocket_connection_sends_keepalive_ping() {
let (message_tx, mut message_rx) = futures_mpsc::unbounded::<Message>();
let websocket_writer = TestWebSocketSink { message_tx };
let websocket_reader = stream::pending::<Result<Message, std::convert::Infallible>>();
let connection = JsonRpcConnection::from_websocket_parts(
websocket_writer,
websocket_reader,
"test".into(),
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
);
impl Stream for ControlledWebSocket {
type Item = Result<Message, std::convert::Infallible>;
let message = timeout(Duration::from_secs(1), message_rx.next())
.await
.expect("keepalive ping should arrive before timeout")
.expect("keepalive ping should be sent");
assert!(matches!(message, Message::Ping(_)));
drop(connection);
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inbound_rx).poll_next(cx)
}
}
}


@@ -6,7 +6,7 @@ use crate::ExecServerError;
use crate::ExecServerRuntimePaths;
use crate::ExecutorFileSystem;
use crate::HttpClient;
-use crate::client::LazyRemoteExecServerClient;
+use crate::client::RemoteExecServerClient;
use crate::client::http_client::ReqwestHttpClient;
use crate::client_api::ExecServerTransportParams;
use crate::environment_provider::DefaultEnvironmentProvider;
@@ -403,7 +403,7 @@ impl Environment {
} => Some(exec_server_url.clone()),
ExecServerTransportParams::StdioCommand { .. } => None,
};
-let client = LazyRemoteExecServerClient::new(remote_transport.clone());
+let client = RemoteExecServerClient::new(remote_transport.clone());
let exec_backend: Arc<dyn ExecBackend> = Arc::new(RemoteProcess::new(client.clone()));
let filesystem: Arc<dyn ExecutorFileSystem> =
Arc::new(RemoteFileSystem::new(client.clone()));


@@ -23,8 +23,9 @@ mod runtime_paths;
mod sandboxed_file_system;
mod server;
-pub use client::ExecServerClient;
+pub use client::ExecServerConnection;
pub use client::ExecServerError;
+pub type ExecServerClient = ExecServerConnection;
pub use client::http_client::HttpResponseBodyStream;
pub use client::http_client::ReqwestHttpClient;
pub use client_api::ExecServerClientConnectOptions;


@@ -23,13 +23,16 @@ pub struct StartedExecProcess {
/// The stream is scoped to one [`ExecProcess`] handle. `Output` events carry
/// stdout, stderr, or pty bytes. `Exited` reports the process exit status, while
/// `Closed` means all output streams have ended and no more output events will
-/// arrive. `Failed` is used when the process session cannot continue, for
-/// example because the remote executor connection disconnected.
+/// arrive. `ResyncRequired` means a reconnecting remote process session should
+/// recover through [`ExecProcess::read`] using its last delivered sequence.
+/// `Failed` is used when the process session cannot continue, for example
+/// because a direct one-shot executor connection disconnected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecProcessEvent {
Output(ProcessOutputChunk),
Exited { seq: u64, exit_code: i32 },
Closed { seq: u64 },
+ResyncRequired,
Failed(String),
}
@@ -60,13 +63,14 @@ struct ExecProcessEventHistory {
impl ExecProcessEvent {
/// Sequence number used to order process-owned events.
///
-/// `Failed` is intentionally unsequenced because it is synthesized by the
-/// client when the session or transport fails, not emitted by the process.
+/// `ResyncRequired` and `Failed` are intentionally unsequenced because they
+/// are synthesized by the client when the session or transport changes,
+/// not emitted by the process.
pub(crate) fn seq(&self) -> Option<u64> {
match self {
ExecProcessEvent::Output(chunk) => Some(chunk.seq),
ExecProcessEvent::Exited { seq, .. } | ExecProcessEvent::Closed { seq } => Some(*seq),
-ExecProcessEvent::Failed(_) => None,
+ExecProcessEvent::ResyncRequired | ExecProcessEvent::Failed(_) => None,
}
}
@@ -74,7 +78,9 @@ impl ExecProcessEvent {
match self {
ExecProcessEvent::Output(chunk) => chunk.chunk.0.len(),
ExecProcessEvent::Failed(message) => message.len(),
-ExecProcessEvent::Exited { .. } | ExecProcessEvent::Closed { .. } => 0,
+ExecProcessEvent::Exited { .. }
+| ExecProcessEvent::Closed { .. }
+| ExecProcessEvent::ResyncRequired => 0,
}
}
}
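
The hunks above add an unsequenced `ResyncRequired` event next to `Failed`. As a rough illustration of how a consumer might react to it, here is a minimal sketch under stated assumptions: `Event` is a simplified stand-in for `ExecProcessEvent` (output reduced to a sequence number and bytes), and `Consumer` with its fields is hypothetical. Dropping events at or below the cursor is a defensive choice in this sketch, not something the diff states.

```rust
/// Simplified, hypothetical mirror of the event enum shown in the diff.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Event {
    Output { seq: u64, bytes: Vec<u8> },
    Exited { seq: u64, exit_code: i32 },
    Closed { seq: u64 },
    ResyncRequired,
    Failed(String),
}

impl Event {
    /// Sequenced events come from the process; the rest are client-synthesized.
    fn seq(&self) -> Option<u64> {
        match self {
            Event::Output { seq, .. } | Event::Exited { seq, .. } | Event::Closed { seq } => {
                Some(*seq)
            }
            Event::ResyncRequired | Event::Failed(_) => None,
        }
    }
}

/// Hypothetical consumer that tracks the last delivered sequence number.
#[derive(Default)]
struct Consumer {
    last_seq: Option<u64>,
    output: Vec<u8>,
    exit_code: Option<i32>,
    needs_read: bool,
    failure: Option<String>,
}

impl Consumer {
    fn on_event(&mut self, event: Event) {
        match event.seq() {
            // Already delivered: a replay after resync may overlap the cursor.
            Some(seq) if self.last_seq.is_some_and(|last| seq <= last) => {}
            Some(seq) => {
                self.last_seq = Some(seq);
                match event {
                    Event::Output { bytes, .. } => self.output.extend(bytes),
                    Event::Exited { exit_code, .. } => self.exit_code = Some(exit_code),
                    _ => {}
                }
            }
            None => match event {
                // Signal that the caller should issue a read from `last_seq`.
                Event::ResyncRequired => self.needs_read = true,
                Event::Failed(message) => self.failure = Some(message),
                _ => {}
            },
        }
    }
}
```

Because `ResyncRequired` carries no sequence number, it never advances the cursor. The consumer keeps `last_seq` as the resume point for the follow-up read.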


@@ -1,7 +1,9 @@
use std::collections::HashMap;
use codex_app_server_protocol::JSONRPCMessage;
+use futures::Sink;
use futures::SinkExt;
+use futures::Stream;
use futures::StreamExt;
use prost::Message as ProstMessage;
use tokio::io::AsyncRead;
@@ -140,121 +142,25 @@ fn jsonrpc_payload(message: &JSONRPCMessage) -> Result<Vec<u8>, ExecServerError>
serde_json::to_vec(message).map_err(ExecServerError::Json)
}
pub(crate) fn harness_connection_from_websocket<S>(
stream: WebSocketStream<S>,
pub(crate) fn harness_connection_from_websocket<T, E>(
stream: T,
connection_label: String,
) -> JsonRpcConnection
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
T: Sink<Message, Error = E> + Stream<Item = Result<Message, E>> + Unpin + Send + 'static,
E: std::fmt::Display + Send + 'static,
{
let stream_id = Uuid::new_v4().to_string();
let (mut websocket_writer, mut websocket_reader) = stream.split();
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (disconnected_tx, disconnected_rx) = watch::channel(false);
let reader_label = connection_label;
let reader_stream_id = stream_id.clone();
let incoming_tx_for_reader = incoming_tx;
let disconnected_tx_for_reader = disconnected_tx.clone();
let reader_task = tokio::spawn(async move {
loop {
match websocket_reader.next().await {
Some(Ok(Message::Binary(payload))) => {
let frame = match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: format!(
"failed to parse relay message frame from {reader_label}: {err}"
),
})
.await;
continue;
}
};
if frame.stream_id != reader_stream_id {
continue;
}
let kind = match frame.validate() {
Ok(kind) => kind,
Err(err) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
continue;
}
};
match kind {
RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
Ok(message) => {
if incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
}
Err(err) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
}
},
RelayFrameBodyKind::Reset => {
let _ = disconnected_tx_for_reader.send(true);
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Disconnected {
reason: frame.into_reset_reason(),
})
.await;
break;
}
RelayFrameBodyKind::Ack
| RelayFrameBodyKind::Resume
| RelayFrameBodyKind::Heartbeat => {}
}
}
Some(Ok(Message::Close(_))) | None => {
let _ = disconnected_tx_for_reader.send(true);
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Disconnected { reason: None })
.await;
break;
}
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {}
Some(Ok(Message::Text(_))) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: "relay exec-server transport expects binary protobuf frames"
.to_string(),
})
.await;
}
Some(Err(err)) => {
let _ = disconnected_tx_for_reader.send(true);
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Disconnected {
reason: Some(format!(
"failed to read relay websocket frame from {reader_label}: {err}"
)),
})
.await;
break;
}
}
}
});
let writer_task = tokio::spawn(async move {
let websocket_task = tokio::spawn(async move {
let mut websocket = stream;
let reader_label = connection_label;
let reader_stream_id = stream_id.clone();
let resume = RelayMessageFrame::resume(stream_id.clone());
if websocket_writer
if websocket
.send(Message::Binary(encode_relay_message_frame(&resume).into()))
.await
.is_err()
@@ -263,12 +169,12 @@ where
return;
}
let mut next_seq = 0u32;
let mut keepalive = tokio::time::interval_at(
tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
WEBSOCKET_KEEPALIVE_INTERVAL,
);
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut next_seq = 0u32;
loop {
tokio::select! {
maybe_message = outgoing_rx.recv() => {
@@ -284,7 +190,7 @@ where
};
let frame = RelayMessageFrame::data(stream_id.clone(), next_seq, payload);
next_seq = next_seq.wrapping_add(1);
if websocket_writer
if websocket
.send(Message::Binary(encode_relay_message_frame(&frame).into()))
.await
.is_err()
@@ -294,11 +200,103 @@ where
}
}
_ = keepalive.tick() => {
if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
if websocket.send(Message::Ping(Vec::new().into())).await.is_err() {
let _ = disconnected_tx.send(true);
break;
}
}
incoming_message = websocket.next() => {
match incoming_message {
Some(Ok(Message::Binary(payload))) => {
let frame = match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: format!(
"failed to parse relay message frame from {reader_label}: {err}"
),
})
.await;
continue;
}
};
if frame.stream_id != reader_stream_id {
continue;
}
let kind = match frame.validate() {
Ok(kind) => kind,
Err(err) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
continue;
}
};
match kind {
RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
Ok(message) => {
if incoming_tx
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
}
Err(err) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
}
},
RelayFrameBodyKind::Reset => {
let _ = disconnected_tx.send(true);
let _ = incoming_tx
.send(JsonRpcConnectionEvent::Disconnected {
reason: frame.into_reset_reason(),
})
.await;
break;
}
RelayFrameBodyKind::Ack
| RelayFrameBodyKind::Resume
| RelayFrameBodyKind::Heartbeat => {}
}
}
Some(Ok(Message::Close(_))) | None => {
let _ = disconnected_tx.send(true);
let _ = incoming_tx
.send(JsonRpcConnectionEvent::Disconnected { reason: None })
.await;
break;
}
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {}
Some(Ok(Message::Text(_))) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: "relay exec-server transport expects binary protobuf frames"
.to_string(),
})
.await;
}
Some(Err(err)) => {
let _ = disconnected_tx.send(true);
let _ = incoming_tx
.send(JsonRpcConnectionEvent::Disconnected {
reason: Some(format!(
"failed to read relay websocket frame from {reader_label}: {err}"
)),
})
.await;
break;
}
}
}
}
}
});
@@ -307,7 +305,7 @@ where
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
task_handles: vec![websocket_task],
transport: JsonRpcTransport::Plain,
}
}
@@ -318,59 +316,53 @@ pub(crate) async fn run_multiplexed_executor<S>(
) where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let (mut websocket_writer, mut websocket_reader) = stream.split();
let mut websocket = stream;
let (physical_outgoing_tx, mut physical_outgoing_rx) =
mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
let writer_task = tokio::spawn(async move {
let mut keepalive = tokio::time::interval_at(
tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
WEBSOCKET_KEEPALIVE_INTERVAL,
);
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
maybe_encoded = physical_outgoing_rx.recv() => {
let Some(encoded) = maybe_encoded else {
break;
};
if websocket_writer
.send(Message::Binary(encoded.into()))
.await
.is_err()
{
break;
}
}
_ = keepalive.tick() => {
if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
break;
}
}
}
}
});
let mut streams: HashMap<String, VirtualStream> = HashMap::new();
let mut keepalive = tokio::time::interval_at(
tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
WEBSOCKET_KEEPALIVE_INTERVAL,
);
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
let frame = match websocket_reader.next().await {
Some(Ok(Message::Binary(payload))) => {
match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
warn!("dropping malformed relay message frame from harness: {err}");
continue;
}
let frame = tokio::select! {
maybe_encoded = physical_outgoing_rx.recv() => {
let Some(encoded) = maybe_encoded else {
break;
};
if websocket.send(Message::Binary(encoded.into())).await.is_err() {
break;
}
}
Some(Ok(Message::Close(_))) | None => break,
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
Some(Ok(Message::Text(_))) => {
warn!("dropping non-binary relay message frame from harness");
continue;
}
Some(Err(err)) => {
debug!("multiplexed executor websocket read failed: {err}");
break;
_ = keepalive.tick() => {
if websocket.send(Message::Ping(Vec::new().into())).await.is_err() {
break;
}
continue;
}
incoming_message = websocket.next() => match incoming_message {
Some(Ok(Message::Binary(payload))) => {
match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
warn!("dropping malformed relay message frame from harness: {err}");
continue;
}
}
}
Some(Ok(Message::Close(_))) | None => break,
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
Some(Ok(Message::Text(_))) => {
warn!("dropping non-binary relay message frame from harness");
continue;
}
Some(Err(err)) => {
debug!("multiplexed executor websocket read failed: {err}");
break;
}
}
};
@@ -423,7 +415,6 @@ pub(crate) async fn run_multiplexed_executor<S>(
stream.disconnect(/*reason*/ None).await;
}
drop(physical_outgoing_tx);
let _ = writer_task.await;
}
struct VirtualStream {
@@ -492,8 +483,20 @@ fn spawn_virtual_stream(
#[cfg(test)]
mod tests {
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::RequestId;
use futures::Sink;
use futures::Stream;
use futures::channel::mpsc as futures_mpsc;
use futures::task::AtomicWaker;
use tokio::net::TcpListener;
use tokio::time::timeout;
use tokio_tungstenite::accept_async;
@@ -502,40 +505,107 @@ mod tests {
use super::*;
fn test_runtime_paths() -> anyhow::Result<crate::ExecServerRuntimePaths> {
crate::ExecServerRuntimePaths::new(
std::env::current_exe()?,
/*codex_linux_sandbox_exe*/ None,
)
.map_err(anyhow::Error::from)
}
#[tokio::test]
async fn multiplexed_executor_sends_keepalive_ping() -> anyhow::Result<()> {
async fn harness_connection_receives_relay_data() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let executor_task = tokio::spawn(run_multiplexed_executor(
client_websocket,
ConnectionProcessor::new(test_runtime_paths()?),
let mut connection =
harness_connection_from_websocket(client_websocket, "test".to_string());
let stream_id = read_resume_stream_id(&mut server_websocket).await?;
let message = test_jsonrpc_message();
server_websocket
.send(Message::Binary(
encode_relay_message_frame(&RelayMessageFrame::data(
stream_id,
/*seq*/ 0,
jsonrpc_payload(&message)?,
))
.into(),
))
.await?;
assert!(matches!(
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
Some(JsonRpcConnectionEvent::Message(actual)) if actual == message
));
read_keepalive_ping(&mut server_websocket).await?;
executor_task.abort();
let _ = executor_task.await;
drop(connection);
Ok(())
}
#[tokio::test]
async fn harness_connection_sends_keepalive_ping() -> anyhow::Result<()> {
async fn harness_connection_reports_text_frames_as_malformed() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let connection = harness_connection_from_websocket(client_websocket, "test".to_string());
let mut connection =
harness_connection_from_websocket(client_websocket, "test".to_string());
read_keepalive_ping(&mut server_websocket).await?;
read_resume_stream_id(&mut server_websocket).await?;
server_websocket.send(Message::Text("nope".into())).await?;
assert!(matches!(
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
Some(JsonRpcConnectionEvent::MalformedMessage { reason })
if reason == "relay exec-server transport expects binary protobuf frames"
));
drop(connection);
Ok(())
}
#[tokio::test]
async fn harness_connection_reports_server_close() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let mut connection =
harness_connection_from_websocket(client_websocket, "test".to_string());
read_resume_stream_id(&mut server_websocket).await?;
server_websocket.close(None).await?;
assert!(matches!(
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
Some(JsonRpcConnectionEvent::Disconnected { reason: None })
));
drop(connection);
Ok(())
}
#[tokio::test]
async fn harness_connection_keeps_outbound_frame_while_send_is_backpressured()
-> anyhow::Result<()> {
let (websocket, control, mut outbound_rx) =
ControlledWebSocket::new(/*write_ready*/ true);
let mut connection = harness_connection_from_websocket(websocket, "test".to_string());
let Message::Binary(resume_payload) = timeout(Duration::from_secs(1), outbound_rx.next())
.await?
.expect("resume frame")
else {
anyhow::bail!("expected relay resume frame");
};
let stream_id = decode_relay_message_frame(resume_payload.as_ref())?.stream_id;
let message = test_jsonrpc_message();
control.set_write_blocked();
connection.outgoing_tx.send(message.clone()).await?;
control.wait_for_blocked_write().await?;
control.send_inbound(Message::Pong(b"check".to_vec().into()))?;
assert!(
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
.await
.is_err()
);
control.set_write_ready();
let Message::Binary(data_payload) = timeout(Duration::from_secs(1), outbound_rx.next())
.await?
.expect("data frame")
else {
anyhow::bail!("expected relay data frame");
};
let frame = decode_relay_message_frame(data_payload.as_ref())?;
assert_eq!(frame.stream_id, stream_id);
assert_eq!(frame.into_jsonrpc_message()?, message);
drop(connection);
Ok(())
}
async fn websocket_pair() -> anyhow::Result<(
WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
WebSocketStream<tokio::net::TcpStream>,
@@ -551,18 +621,155 @@ mod tests {
Ok((client_websocket, server_websocket))
}
async fn read_keepalive_ping(
async fn read_resume_stream_id(
websocket: &mut WebSocketStream<tokio::net::TcpStream>,
) -> anyhow::Result<()> {
loop {
let Some(message) = timeout(Duration::from_secs(1), websocket.next()).await? else {
anyhow::bail!("websocket closed before keepalive ping");
};
match message? {
Message::Ping(_) => return Ok(()),
Message::Binary(_) | Message::Text(_) | Message::Pong(_) | Message::Frame(_) => {}
Message::Close(_) => anyhow::bail!("websocket closed before keepalive ping"),
) -> anyhow::Result<String> {
let message = timeout(Duration::from_secs(1), websocket.next())
.await?
.expect("websocket should stay open")?;
let Message::Binary(payload) = message else {
anyhow::bail!("expected relay resume frame, got {message:?}");
};
let frame = decode_relay_message_frame(payload.as_ref())?;
assert_eq!(frame.validate()?, RelayFrameBodyKind::Resume);
Ok(frame.stream_id)
}
fn test_jsonrpc_message() -> JSONRPCMessage {
JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(1),
method: "test".to_string(),
params: None,
trace: None,
})
}
struct ControlledWebSocket {
inbound_rx: futures_mpsc::UnboundedReceiver<Result<Message, std::convert::Infallible>>,
outbound_tx: futures_mpsc::UnboundedSender<Message>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
struct ControlledWebSocketHandle {
inbound_tx: futures_mpsc::UnboundedSender<Result<Message, std::convert::Infallible>>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
impl ControlledWebSocket {
fn new(
write_ready: bool,
) -> (
Self,
ControlledWebSocketHandle,
futures_mpsc::UnboundedReceiver<Message>,
) {
let (inbound_tx, inbound_rx) = futures_mpsc::unbounded();
let (outbound_tx, outbound_rx) = futures_mpsc::unbounded();
let write_ready = Arc::new(AtomicBool::new(write_ready));
let write_blocked = Arc::new(AtomicBool::new(false));
let write_blocked_waker = Arc::new(AtomicWaker::new());
let write_waker = Arc::new(AtomicWaker::new());
(
Self {
inbound_rx,
outbound_tx,
write_ready: Arc::clone(&write_ready),
write_blocked: Arc::clone(&write_blocked),
write_blocked_waker: Arc::clone(&write_blocked_waker),
write_waker: Arc::clone(&write_waker),
},
ControlledWebSocketHandle {
inbound_tx,
write_ready,
write_blocked,
write_blocked_waker,
write_waker,
},
outbound_rx,
)
}
}
impl ControlledWebSocketHandle {
fn send_inbound(&self, message: Message) -> anyhow::Result<()> {
self.inbound_tx
.unbounded_send(Ok(message))
.map_err(anyhow::Error::from)
}
fn set_write_blocked(&self) {
self.write_ready.store(false, Ordering::Release);
}
fn set_write_ready(&self) {
self.write_ready.store(true, Ordering::Release);
self.write_waker.wake();
}
async fn wait_for_blocked_write(&self) -> anyhow::Result<()> {
timeout(
Duration::from_secs(1),
futures::future::poll_fn(|cx| {
if self.write_blocked.load(Ordering::Acquire) {
Poll::Ready(())
} else {
self.write_blocked_waker.register(cx.waker());
Poll::Pending
}
}),
)
.await?;
Ok(())
}
}
impl Sink<Message> for ControlledWebSocket {
type Error = std::convert::Infallible;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.write_ready.load(Ordering::Acquire) {
Poll::Ready(Ok(()))
} else {
self.write_blocked.store(true, Ordering::Release);
self.write_blocked_waker.wake();
self.write_waker.register(cx.waker());
Poll::Pending
}
}
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
self.outbound_tx
.unbounded_send(item)
.expect("test outbound receiver should stay open");
Ok(())
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
}
impl Stream for ControlledWebSocket {
type Item = Result<Message, std::convert::Infallible>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inbound_rx).poll_next(cx)
}
}
}


@@ -14,7 +14,7 @@ use crate::FileSystemResult;
use crate::FileSystemSandboxContext;
use crate::ReadDirectoryEntry;
use crate::RemoveOptions;
use crate::client::LazyRemoteExecServerClient;
use crate::client::RemoteExecServerClient;
use crate::protocol::FsCopyParams;
use crate::protocol::FsCreateDirectoryParams;
use crate::protocol::FsGetMetadataParams;
@@ -28,11 +28,11 @@ const NOT_FOUND_ERROR_CODE: i64 = -32004;
#[derive(Clone)]
pub(crate) struct RemoteFileSystem {
client: LazyRemoteExecServerClient,
client: RemoteExecServerClient,
}
impl RemoteFileSystem {
pub(crate) fn new(client: LazyRemoteExecServerClient) -> Self {
pub(crate) fn new(client: RemoteExecServerClient) -> Self {
trace!("remote fs new");
Self { client }
}
@@ -46,8 +46,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<Vec<u8>> {
trace!("remote fs read_file");
let client = self.client.get().await.map_err(map_remote_error)?;
let response = client
let connection = self.client.connection().await.map_err(map_remote_error)?;
let response = connection
.fs_read_file(FsReadFileParams {
path: path.clone(),
sandbox: remote_sandbox_context(sandbox),
@@ -69,8 +69,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<()> {
trace!("remote fs write_file");
let client = self.client.get().await.map_err(map_remote_error)?;
client
let connection = self.client.connection().await.map_err(map_remote_error)?;
connection
.fs_write_file(FsWriteFileParams {
path: path.clone(),
data_base64: STANDARD.encode(contents),
@@ -88,8 +88,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<()> {
trace!("remote fs create_directory");
let client = self.client.get().await.map_err(map_remote_error)?;
client
let connection = self.client.connection().await.map_err(map_remote_error)?;
connection
.fs_create_directory(FsCreateDirectoryParams {
path: path.clone(),
recursive: Some(options.recursive),
@@ -106,8 +106,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<FileMetadata> {
trace!("remote fs get_metadata");
let client = self.client.get().await.map_err(map_remote_error)?;
let response = client
let connection = self.client.connection().await.map_err(map_remote_error)?;
let response = connection
.fs_get_metadata(FsGetMetadataParams {
path: path.clone(),
sandbox: remote_sandbox_context(sandbox),
@@ -129,8 +129,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<Vec<ReadDirectoryEntry>> {
trace!("remote fs read_directory");
let client = self.client.get().await.map_err(map_remote_error)?;
let response = client
let connection = self.client.connection().await.map_err(map_remote_error)?;
let response = connection
.fs_read_directory(FsReadDirectoryParams {
path: path.clone(),
sandbox: remote_sandbox_context(sandbox),
@@ -155,8 +155,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<()> {
trace!("remote fs remove");
let client = self.client.get().await.map_err(map_remote_error)?;
client
let connection = self.client.connection().await.map_err(map_remote_error)?;
connection
.fs_remove(FsRemoveParams {
path: path.clone(),
recursive: Some(options.recursive),
@@ -176,8 +176,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<()> {
trace!("remote fs copy");
let client = self.client.get().await.map_err(map_remote_error)?;
client
let connection = self.client.connection().await.map_err(map_remote_error)?;
connection
.fs_copy(FsCopyParams {
source_path: source_path.clone(),
destination_path: destination_path.clone(),


@@ -1,31 +1,33 @@
use std::sync::Arc;
use async_trait::async_trait;
use tokio::runtime::Handle;
use tokio::sync::watch;
use tracing::trace;
use tracing::warn;
use crate::ExecBackend;
use crate::ExecProcess;
use crate::ExecProcessEventReceiver;
use crate::ExecServerError;
use crate::StartedExecProcess;
use crate::client::LazyRemoteExecServerClient;
use crate::client::Session;
use crate::client::ProcessSessionHandle;
use crate::client::RemoteExecServerClient;
use crate::protocol::ExecParams;
use crate::protocol::ReadResponse;
use crate::protocol::WriteResponse;
#[derive(Clone)]
pub(crate) struct RemoteProcess {
client: LazyRemoteExecServerClient,
client: RemoteExecServerClient,
}
struct RemoteExecProcess {
session: Session,
session: ProcessSessionHandle,
}
impl RemoteProcess {
pub(crate) fn new(client: LazyRemoteExecServerClient) -> Self {
pub(crate) fn new(client: RemoteExecServerClient) -> Self {
trace!("remote process new");
Self { client }
}
@@ -35,9 +37,8 @@ impl RemoteProcess {
impl ExecBackend for RemoteProcess {
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError> {
let process_id = params.process_id.clone();
let client = self.client.get().await?;
let session = client.register_session(&process_id).await?;
if let Err(err) = client.exec(params).await {
let (session, connection) = self.client.register_process_session(&process_id).await?;
if let Err(err) = connection.exec(params).await {
session.unregister().await;
return Err(err);
}
@@ -85,8 +86,14 @@ impl ExecProcess for RemoteExecProcess {
impl Drop for RemoteExecProcess {
fn drop(&mut self) {
let session = self.session.clone();
tokio::spawn(async move {
let Ok(handle) = Handle::try_current() else {
warn!(
"Could not schedule remote exec process unregister on drop: no Tokio runtime is available"
);
return;
};
std::mem::drop(handle.spawn(async move {
session.unregister().await;
});
}));
}
}


@@ -166,6 +166,7 @@ async fn collect_process_output_from_events(
drop(session);
return Ok((stdout, stderr, exit_code, true));
}
ExecProcessEvent::ResyncRequired => continue,
ExecProcessEvent::Failed(message) => {
anyhow::bail!("process failed before closed state: {message}");
}
@@ -189,6 +190,7 @@ async fn collect_process_event_snapshots(
ProcessEventSnapshot::Exited { seq, exit_code }
}
ExecProcessEvent::Closed { seq } => ProcessEventSnapshot::Closed { seq },
ExecProcessEvent::ResyncRequired => continue,
ExecProcessEvent::Failed(message) => {
anyhow::bail!("process failed before closed state: {message}");
}
@@ -541,7 +543,7 @@ async fn assert_exec_process_preserves_queued_events_before_subscribe(
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
// Serialize tests that launch a real exec-server process through the full CLI.
#[serial_test::serial(remote_exec_server)]
async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
async fn remote_exec_process_surfaces_reconnect_failure_after_server_shutdown() -> Result<()> {
let mut context = create_process_context(/*use_remote*/ true).await?;
let session = context
.backend
@@ -562,7 +564,6 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
.await?;
let process = Arc::clone(&session.process);
let mut events = process.subscribe_events();
let process_for_pending_read = Arc::clone(&process);
let pending_read = tokio::spawn(async move {
process_for_pending_read
@@ -579,36 +580,33 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
.expect("remote context should include exec-server harness");
server.shutdown().await?;
let event = timeout(Duration::from_secs(2), events.recv()).await??;
let ExecProcessEvent::Failed(event_message) = event else {
anyhow::bail!("expected process failure event, got {event:?}");
};
let pending_error = timeout(Duration::from_secs(2), pending_read)
.await
.context("timed out waiting for pending read after reconnect failure")??
.expect_err("pending read should fail after reconnect fails");
assert!(
event_message.starts_with("exec-server transport disconnected"),
"unexpected failure event: {event_message}"
pending_error
.to_string()
.starts_with("failed to connect to exec-server websocket"),
"unexpected pending read error: {pending_error}"
);
let pending_response = timeout(Duration::from_secs(2), pending_read).await???;
let pending_message = pending_response
.failure
.expect("pending read should surface disconnect as a failure");
let read_error = timeout(
Duration::from_secs(2),
process.read(
/*after_seq*/ None,
/*max_bytes*/ None,
/*wait_ms*/ Some(0),
),
)
.await
.context("timed out waiting for read after reconnect failure")?
.expect_err("read after reconnect failure should fail");
assert!(
pending_message.starts_with("exec-server transport disconnected"),
"unexpected pending failure message: {pending_message}"
);
let mut wake_rx = process.subscribe_wake();
let response = read_process_until_change(process, &mut wake_rx, /*after_seq*/ None).await?;
let message = response
.failure
.expect("disconnect should surface as a failure");
assert!(
message.starts_with("exec-server transport disconnected"),
"unexpected failure message: {message}"
);
assert!(
response.closed,
"disconnect should close the process session"
read_error
.to_string()
.starts_with("failed to connect to exec-server websocket"),
"unexpected read error: {read_error}"
);
let write_result = timeout(
@@ -621,7 +619,7 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
assert!(
write_error
.to_string()
.starts_with("exec-server transport disconnected"),
.starts_with("failed to connect to exec-server websocket"),
"unexpected write error: {write_error}"
);


@@ -193,6 +193,15 @@ impl ExecutorProcessTransport {
self.note_seq(seq);
self.closed = true;
}
Ok(ExecProcessEvent::ResyncRequired) => {
if let Err(error) = self.recover_process_events().await {
warn!(
"Failed to resync remote MCP server output stream ({}): {error}",
self.program_name
);
self.closed = true;
}
}
Ok(ExecProcessEvent::Failed(message)) => {
warn!(
"Remote MCP server process failed ({}): {message}",
@@ -205,7 +214,7 @@ impl ExecutorProcessTransport {
"Remote MCP server output stream lagged ({}): skipped {skipped} events",
self.program_name
);
if let Err(error) = self.recover_lagged_events().await {
if let Err(error) = self.recover_process_events().await {
warn!(
"Failed to recover remote MCP server output stream ({}): {error}",
self.program_name
@@ -232,7 +241,7 @@ impl ExecutorProcessTransport {
true
}
async fn recover_lagged_events(&mut self) -> io::Result<()> {
async fn recover_process_events(&mut self) -> io::Result<()> {
let response = self
.process
.read(