Mirror of https://github.com/openai/codex.git (synced 2026-05-22 03:54:18 +00:00)
Compare commits: rust-v0.13...starr/exec (12 commits)
| Author | SHA1 | Date | |
|---|---|---|---|
| | ee0db3246a | | |
| | edd0d6a66e | | |
| | d82bb0d81e | | |
| | b25c4b53c9 | | |
| | ef8267fb69 | | |
| | 90804bb2eb | | |
| | 6f9640533a | | |
| | 6d6cdeb128 | | |
| | 9215e15ee3 | | |
| | d94300d782 | | |
| | 3673b69a2a | | |
| | 8a9300e92a | | |
@@ -7,13 +7,112 @@ JSON-RPC server for spawning and controlling subprocesses through
 It provides:

 - a CLI entrypoint: `codex exec-server`
-- a Rust client: `ExecServerClient`
+- a Rust connection client: `ExecServerConnection`
 - a small protocol module with shared request/response types

 This crate owns the transport, protocol, and filesystem/process handlers. The
 top-level `codex` binary owns hidden helper dispatch for sandboxed
 filesystem operations and `codex-linux-sandbox`.

+## Client And Session Ownership
+
+Remote environments expose one logical exec-server client to the rest of Codex.
+That client is not the same thing as one websocket connection.
+
+```text
+Environment
+`- one RemoteExecServerClient
+   |- RemoteExecServerSession
+   |  |- logical_session_id
+   |  |- current ExecServerConnection?
+   |  |- one in-flight reconnect attempt
+   |  |- terminal reconnect error?
+   |  `- tracked process_sessions: HashMap<ProcessId, Weak<ProcessSession>>
+   |- RemoteProcess -> RemoteExecProcess -> ProcessSessionHandle
+   |- RemoteFileSystem
+   `- HttpClient for RemoteExecServerClient
+
+ExecServerConnection
+`- Inner
+   |- RpcClient
+   |- reader task
+   |- disconnect latch
+   |- connection-local process_session_routes
+   |- connection-local HTTP body stream routes
+   `- initialized session_id for this live binding
+
+ProcessSessionHandle
+|- process_id
+|- Arc<ProcessSession>
+`- ProcessSessionControl
+
+ProcessSession
+|- wake channel
+|- event log
+|- ordered event buffer
+|- failure state
+`- disconnect policy
+
+ProcessSessionControl
+|- Connection(ExecServerConnection) for direct/test one-shot sessions
+`- RemoteClient(RemoteExecServerClient) for reconnecting remote environments
+```
+
+The main roles are:
+
+- `RemoteExecServerClient`: environment-owned logical client. `Environment`
+  clones this into the remote process backend, remote filesystem, and remote
+  HTTP capability so all remote APIs share one reconnecting session.
+- `RemoteExecServerSession`: durable logical-session state behind the client.
+  It remembers the resumable session id, current live connection, one shared
+  reconnect attempt, weak references to process sessions that may need rebinding,
+  and any terminal resume error.
+- `ExecServerConnection`: one live JSON-RPC transport binding. It owns
+  connection-local routing for notifications and streamed HTTP response bodies.
+- `Inner`: private per-connection machinery behind `ExecServerConnection`.
+  It owns the `RpcClient`, reader task, disconnect latch, connection-local
+  process notification routes, connection-local HTTP body stream routes, and
+  initialized session id for that live binding.
+- `ProcessSession`: durable per-process client state owned by the live process
+  handle. It keeps the local event log, wake cursor, and failure state that must
+  survive connection replacement while that handle still exists.
+- `ProcessSessionHandle`: process-facing handle used by `RemoteExecProcess`.
+  It routes reads, writes, terminate, and unregister through either a focused
+  direct connection test path or the logical reconnecting client path.
+- `ProcessSessionControl`: small command-path enum for a
+  `ProcessSessionHandle`. It is not an owner; it only chooses direct
+  `ExecServerConnection` versus reconnecting `RemoteExecServerClient`.
+- `RemoteProcess`, `RemoteFileSystem`, and `HttpClient`: thin capability
+  adapters. They should not own reconnect state themselves.
+
+Reconnect invariants:
+
+- There is one shared reconnect attempt per `RemoteExecServerClient`, not one
+  reconnect loop per API surface.
+- Reconnect resumes the same logical session id and rebinds tracked
+  `ProcessSession` routes onto the replacement `ExecServerConnection`.
+- When a reconnecting process session loses its transport, it emits a local
+  `ResyncRequired` event and wake so callers blocked on pushed events or wake
+  notifications know to recover through `process/read(afterSeq)`.
+- `process/read` may retry once after a transport-close race because its
+  `afterSeq` cursor makes the replay read-only and recoverable.
+- `process/start`, `process/write`, `process/terminate`, filesystem RPCs, and
+  `http/request` are not replayed after an ambiguous mid-request disconnect.
+  They reconnect before later calls, but an in-flight call that may already
+  have reached the server returns an error instead of risking duplicate side
+  effects.
+- Streamed HTTP bodies are connection-local. A reconnect can start a later
+  HTTP request, but it cannot resume body-delta delivery for an already-open
+  stream.
+- Rendezvous uses the same logical split. The relay websocket and relay
+  `stream_id` are transport beneath the exec-server logical session for the
+  first reconnect slice. If a rendezvous websocket dies, the harness/client
+  may establish a fresh relay stream, then re-run exec-server initialize with
+  the prior `session_id` as `resume_session_id`. Existing process state recovers
+  through `process/read(after_seq)`, not relay-frame replay. Full same-stream
+  relay resume/replay remains a later protocol slice that requires endpoint-held
+  seq/ack/replay state.
+
 ## Transport

 The server speaks the shared `codex-app-server-protocol` message envelope on
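
The reconnect invariants in the README hunk above distinguish between calls that may be retried after a transport-close race and calls that must surface an error. A minimal sketch of that split, assuming hypothetical names (`ReplayPolicy`, `replay_policy`, `max_attempts`) that do not appear in the diff, and matching on the method strings the README mentions:

```rust
// Hypothetical sketch: these types and functions are illustrative only and
// are not taken from the crate.

/// What a client may do with a call that was in flight when the transport closed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReplayPolicy {
    /// Safe to retry once: the request is read-only and carries a cursor.
    RetryOnce,
    /// Not replayed: the server may already have applied a side effect.
    ReturnError,
}

/// Maps a JSON-RPC method name to a replay policy.
/// Only `process/read` is treated as retryable, per the README invariants.
pub fn replay_policy(method: &str) -> ReplayPolicy {
    match method {
        "process/read" => ReplayPolicy::RetryOnce,
        // process/start, process/write, process/terminate, filesystem RPCs,
        // and http/request all fall through to the conservative default.
        _ => ReplayPolicy::ReturnError,
    }
}

/// Total number of attempts a caller would make for one logical call.
pub fn max_attempts(method: &str) -> u32 {
    match replay_policy(method) {
        ReplayPolicy::RetryOnce => 2,
        ReplayPolicy::ReturnError => 1,
    }
}
```

Defaulting unknown methods to `ReturnError` is a design choice in this sketch: a method is only retried when it is known to be free of side effects.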
File diff suppressed because it is too large
@@ -3,7 +3,7 @@
 //! This module is the facade for the environment-owned [`crate::HttpClient`]
 //! capability:
 //! - [`ReqwestHttpClient`] executes requests directly with `reqwest`
-//! - [`ExecServerClient`] forwards requests over the JSON-RPC transport
+//! - [`ExecServerConnection`] forwards requests over the JSON-RPC transport
 //! - [`HttpResponseBodyStream`] presents buffered local bodies and streamed
 //!   remote `http/request/bodyDelta` notifications through one byte-stream API
 //!
@@ -11,7 +11,7 @@
 //! - orchestrator process: holds an `Arc<dyn HttpClient>` and chooses local or
 //!   remote execution
 //! - remote runtime: serves the `http/request` RPC and runs the concrete local
-//!   HTTP request there when the orchestrator uses [`ExecServerClient`]
+//!   HTTP request there when the orchestrator uses [`ExecServerConnection`]

 #[path = "reqwest_http_client.rs"]
 mod reqwest_http_client;
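
The module docs above describe an orchestrator that holds an `Arc<dyn HttpClient>` and chooses local or remote execution. A simplified, synchronous sketch of that pattern follows. The trait here is a stand-in: the real `HttpClient` in the diff returns a `BoxFuture` and takes `HttpRequestParams`, while `LocalHttpClient`, `ForwardingHttpClient`, and `select_http_client` are hypothetical names used only for illustration.

```rust
use std::sync::Arc;

/// Simplified stand-in for the crate's `HttpClient` trait (assumption: the
/// real trait is async and uses richer request/response types).
pub trait HttpClient {
    fn http_request(&self, url: &str) -> String;
}

/// Hypothetical local implementation: runs the request in this process.
pub struct LocalHttpClient;

/// Hypothetical remote implementation: forwards the request over a connection.
pub struct ForwardingHttpClient {
    pub connection_label: String,
}

impl HttpClient for LocalHttpClient {
    fn http_request(&self, url: &str) -> String {
        format!("local GET {url}")
    }
}

impl HttpClient for ForwardingHttpClient {
    fn http_request(&self, url: &str) -> String {
        // A real adapter would send an `http/request` RPC here.
        format!("http/request via {} for {url}", self.connection_label)
    }
}

/// Picks the capability once; callers only ever see the trait object.
pub fn select_http_client(remote_label: Option<&str>) -> Arc<dyn HttpClient> {
    match remote_label {
        Some(label) => Arc::new(ForwardingHttpClient {
            connection_label: label.to_string(),
        }),
        None => Arc::new(LocalHttpClient),
    }
}
```

Callers stay unaware of which implementation they hold, which is what lets the diff rename the remote implementor from `ExecServerClient` to `ExecServerConnection` without touching call sites that go through the trait.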
1077
codex-rs/exec-server/src/client/reconnect_tests.rs
Normal file
File diff suppressed because it is too large
@@ -14,7 +14,7 @@ use tokio::sync::mpsc;
 use super::HttpResponseBodyStream;
 use super::response_body_stream::HttpBodyStreamRegistration;
 use crate::HttpClient;
-use crate::client::ExecServerClient;
+use crate::client::ExecServerConnection;
 use crate::client::ExecServerError;
 use crate::protocol::HTTP_REQUEST_METHOD;
 use crate::protocol::HttpRequestParams;
@@ -23,7 +23,7 @@ use crate::protocol::HttpRequestResponse;
 /// Maximum queued body frames per streamed HTTP response.
 const HTTP_BODY_DELTA_CHANNEL_CAPACITY: usize = 256;

-impl ExecServerClient {
+impl ExecServerConnection {
     /// Performs an HTTP request and buffers the response body.
     pub async fn http_request(
         &self,
@@ -67,14 +67,14 @@ impl ExecServerClient {
     }
 }

-impl HttpClient for ExecServerClient {
+impl HttpClient for ExecServerConnection {
     /// Orchestrator-side adapter that forwards buffered HTTP requests to the
     /// remote runtime over the shared JSON-RPC connection.
     fn http_request(
         &self,
         params: HttpRequestParams,
     ) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
-        async move { ExecServerClient::http_request(self, params).await }.boxed()
+        async move { ExecServerConnection::http_request(self, params).await }.boxed()
     }

     /// Orchestrator-side adapter that forwards streamed HTTP requests to the
@@ -83,6 +83,6 @@ impl HttpClient for ExecServerClient {
         &self,
         params: HttpRequestParams,
     ) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
-        async move { ExecServerClient::http_request_stream(self, params).await }.boxed()
+        async move { ExecServerConnection::http_request_stream(self, params).await }.boxed()
     }
 }
@@ -9,7 +9,7 @@ use tracing::warn;

 use codex_utils_rustls_provider::ensure_rustls_crypto_provider;

-use crate::ExecServerClient;
+use crate::ExecServerConnection;
 use crate::ExecServerError;
 use crate::client_api::RemoteExecServerConnectArgs;
 use crate::client_api::StdioExecServerCommand;
@@ -17,9 +17,9 @@ use crate::client_api::StdioExecServerConnectArgs;
 use crate::connection::JsonRpcConnection;
 use crate::relay::harness_connection_from_websocket;

-const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
+pub(crate) const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";

-impl ExecServerClient {
+impl ExecServerConnection {
     pub(crate) async fn connect_for_transport(
         transport_params: crate::client_api::ExecServerTransportParams,
     ) -> Result<Self, ExecServerError> {
@@ -323,37 +323,20 @@ impl JsonRpcConnection {
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let (websocket_writer, websocket_reader) = stream.split();
|
||||
Self::from_websocket_parts(
|
||||
websocket_writer,
|
||||
websocket_reader,
|
||||
connection_label,
|
||||
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
|
||||
)
|
||||
Self::from_websocket_stream(stream, connection_label, /*ping_interval*/ None)
|
||||
}
|
||||
|
||||
pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self {
|
||||
let (websocket_writer, websocket_reader) = stream.split();
|
||||
Self::from_websocket_parts(
|
||||
websocket_writer,
|
||||
websocket_reader,
|
||||
connection_label,
|
||||
// Axum only wraps inbound exec-server websocket accepts. Outbound websocket clients
|
||||
// own keepalive pings so one side does not accidentally create redundant traffic.
|
||||
/*keepalive_interval*/
|
||||
None,
|
||||
)
|
||||
Self::from_websocket_stream(stream, connection_label, Some(WEBSOCKET_KEEPALIVE_INTERVAL))
|
||||
}
|
||||
|
||||
fn from_websocket_parts<W, R, M, E>(
|
||||
mut websocket_writer: W,
|
||||
mut websocket_reader: R,
|
||||
fn from_websocket_stream<T, M, E>(
|
||||
mut websocket: T,
|
||||
connection_label: String,
|
||||
keepalive_interval: Option<Duration>,
|
||||
ping_interval: Option<Duration>,
|
||||
) -> Self
|
||||
where
|
||||
W: Sink<M, Error = E> + Unpin + Send + 'static,
|
||||
R: Stream<Item = Result<M, E>> + Unpin + Send + 'static,
|
||||
T: Sink<M, Error = E> + Stream<Item = Result<M, E>> + Unpin + Send + 'static,
|
||||
M: JsonRpcWebSocketMessage,
|
||||
E: std::fmt::Display + Send + 'static,
|
||||
{
|
||||
@@ -361,118 +344,106 @@ impl JsonRpcConnection {
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
let disconnected_tx_for_reader = disconnected_tx.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let websocket_task = tokio::spawn(async move {
|
||||
let mut ping_interval = ping_interval.map(|ping_interval| {
|
||||
let mut interval = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + ping_interval,
|
||||
ping_interval,
|
||||
);
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
interval
|
||||
});
|
||||
|
||||
loop {
|
||||
match websocket_reader.next().await {
|
||||
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
|
||||
Ok(JsonRpcWebSocketFrame::Message(message)) => {
|
||||
if incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
tokio::select! {
|
||||
maybe_message = outgoing_rx.recv() => {
|
||||
let Some(message) = maybe_message else {
|
||||
break;
|
||||
};
|
||||
if let Err(reason) = send_websocket_jsonrpc_message(
|
||||
&mut websocket,
|
||||
&connection_label,
|
||||
&message,
|
||||
)
|
||||
.await
|
||||
{
|
||||
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
send_malformed_message(
|
||||
&incoming_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to parse websocket JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
_ = async {
|
||||
match ping_interval.as_mut() {
|
||||
Some(interval) => interval.tick().await,
|
||||
None => std::future::pending().await,
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Close) => {
|
||||
} => {
|
||||
if let Err(err) = websocket.send(M::ping()).await {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write websocket ping to {connection_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Ignore) => {}
|
||||
},
|
||||
Some(Err(err)) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to read websocket JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
if let Some(keepalive_interval) = keepalive_interval {
|
||||
let mut keepalive = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + keepalive_interval,
|
||||
keepalive_interval,
|
||||
);
|
||||
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tokio::select! {
|
||||
maybe_message = outgoing_rx.recv() => {
|
||||
let Some(message) = maybe_message else {
|
||||
break;
|
||||
};
|
||||
if let Err(reason) = send_websocket_jsonrpc_message(
|
||||
&mut websocket_writer,
|
||||
&connection_label,
|
||||
&message,
|
||||
)
|
||||
.await
|
||||
{
|
||||
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = keepalive.tick() => {
|
||||
if let Err(err) = websocket_writer.send(M::ping()).await {
|
||||
incoming_message = websocket.next() => {
|
||||
match incoming_message {
|
||||
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
|
||||
Ok(JsonRpcWebSocketFrame::Message(message)) => {
|
||||
if incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Close) => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Ignore) => {}
|
||||
Err(err) => {
|
||||
send_malformed_message(
|
||||
&incoming_tx,
|
||||
Some(format!(
|
||||
"failed to parse websocket JSON-RPC message from {connection_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
},
|
||||
Some(Err(err)) => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write websocket ping to {connection_label}: {err}"
|
||||
"failed to read websocket JSON-RPC message from {connection_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
while let Some(message) = outgoing_rx.recv().await {
|
||||
if let Err(reason) = send_websocket_jsonrpc_message(
|
||||
&mut websocket_writer,
|
||||
&connection_label,
|
||||
&message,
|
||||
)
|
||||
.await
|
||||
{
|
||||
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -480,7 +451,7 @@ impl JsonRpcConnection {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
task_handles: vec![websocket_task],
|
||||
transport: JsonRpcTransport::Plain,
|
||||
}
|
||||
}
|
||||
@@ -619,34 +590,250 @@ fn serialize_jsonrpc_message(message: &JSONRPCMessage) -> Result<String, serde_j
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use futures::channel::mpsc as futures_mpsc;
|
||||
use futures::stream;
|
||||
use futures::task::Context;
|
||||
use futures::task::Poll;
|
||||
use futures::task::AtomicWaker;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::accept_async;
|
||||
use tokio_tungstenite::connect_async;
|
||||
|
||||
use super::*;
|
||||
|
||||
struct TestWebSocketSink {
|
||||
message_tx: futures_mpsc::UnboundedSender<Message>,
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_sends_configured_ping() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let connection = JsonRpcConnection::from_websocket_stream(
|
||||
client_websocket,
|
||||
"test".into(),
|
||||
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
|
||||
);
|
||||
|
||||
let message = timeout(Duration::from_secs(1), server_websocket.next())
|
||||
.await?
|
||||
.expect("websocket should stay open")?;
|
||||
assert!(matches!(message, Message::Ping(_)));
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl Sink<Message> for TestWebSocketSink {
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_ignores_server_pong() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
|
||||
|
||||
server_websocket
|
||||
.send(Message::Pong(b"check".to_vec().into()))
|
||||
.await?;
|
||||
assert!(
|
||||
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_reports_server_close() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
|
||||
|
||||
server_websocket.close(None).await?;
|
||||
assert!(matches!(
|
||||
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
|
||||
Some(JsonRpcConnectionEvent::Disconnected { reason: None })
|
||||
));
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_accepts_binary_jsonrpc_message() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
|
||||
let message = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(1),
|
||||
method: "test".to_string(),
|
||||
params: None,
|
||||
trace: None,
|
||||
});
|
||||
|
||||
server_websocket
|
||||
.send(Message::Binary(serde_json::to_vec(&message)?.into()))
|
||||
.await?;
|
||||
assert!(matches!(
|
||||
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
|
||||
Some(JsonRpcConnectionEvent::Message(actual)) if actual == message
|
||||
));
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_keeps_outbound_message_while_send_is_backpressured()
|
||||
-> anyhow::Result<()> {
|
||||
let (websocket, control, mut outbound_rx) =
|
||||
ControlledWebSocket::new(/*write_ready*/ false);
|
||||
let mut connection = JsonRpcConnection::from_websocket_stream(
|
||||
websocket,
|
||||
"test".into(),
|
||||
/*ping_interval*/ None,
|
||||
);
|
||||
let message = test_jsonrpc_message();
|
||||
|
||||
connection.outgoing_tx.send(message.clone()).await?;
|
||||
control.wait_for_blocked_write().await?;
|
||||
control.send_inbound(Message::Pong(b"check".to_vec().into()))?;
|
||||
assert!(
|
||||
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
control.set_write_ready();
|
||||
assert!(matches!(
|
||||
timeout(Duration::from_secs(1), outbound_rx.next()).await?,
|
||||
Some(Message::Text(text)) if serde_json::from_str::<JSONRPCMessage>(&text)? == message
|
||||
));
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn websocket_pair() -> anyhow::Result<(
|
||||
WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
|
||||
WebSocketStream<tokio::net::TcpStream>,
|
||||
)> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await?;
|
||||
let websocket_url = format!("ws://{}", listener.local_addr()?);
|
||||
let server_task = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
accept_async(stream).await.map_err(anyhow::Error::from)
|
||||
});
|
||||
let (client_websocket, _) = connect_async(websocket_url).await?;
|
||||
let server_websocket = server_task.await??;
|
||||
Ok((client_websocket, server_websocket))
|
||||
}
|
||||
|
||||
fn test_jsonrpc_message() -> JSONRPCMessage {
|
||||
JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(1),
|
||||
method: "test".to_string(),
|
||||
params: None,
|
||||
trace: None,
|
||||
})
|
||||
}
|
||||
|
||||
struct ControlledWebSocket {
|
||||
inbound_rx: futures_mpsc::UnboundedReceiver<Result<Message, std::convert::Infallible>>,
|
||||
outbound_tx: futures_mpsc::UnboundedSender<Message>,
|
||||
write_ready: Arc<AtomicBool>,
|
||||
write_blocked: Arc<AtomicBool>,
|
||||
write_blocked_waker: Arc<AtomicWaker>,
|
||||
write_waker: Arc<AtomicWaker>,
|
||||
}
|
||||
|
||||
struct ControlledWebSocketHandle {
|
||||
inbound_tx: futures_mpsc::UnboundedSender<Result<Message, std::convert::Infallible>>,
|
||||
write_ready: Arc<AtomicBool>,
|
||||
write_blocked: Arc<AtomicBool>,
|
||||
write_blocked_waker: Arc<AtomicWaker>,
|
||||
write_waker: Arc<AtomicWaker>,
|
||||
}
|
||||
|
||||
impl ControlledWebSocket {
|
||||
fn new(
|
||||
write_ready: bool,
|
||||
) -> (
|
||||
Self,
|
||||
ControlledWebSocketHandle,
|
||||
futures_mpsc::UnboundedReceiver<Message>,
|
||||
) {
|
||||
let (inbound_tx, inbound_rx) = futures_mpsc::unbounded();
|
||||
let (outbound_tx, outbound_rx) = futures_mpsc::unbounded();
|
||||
let write_ready = Arc::new(AtomicBool::new(write_ready));
|
||||
let write_blocked = Arc::new(AtomicBool::new(false));
|
||||
let write_blocked_waker = Arc::new(AtomicWaker::new());
|
||||
let write_waker = Arc::new(AtomicWaker::new());
|
||||
(
|
||||
Self {
|
||||
inbound_rx,
|
||||
outbound_tx,
|
||||
write_ready: Arc::clone(&write_ready),
|
||||
write_blocked: Arc::clone(&write_blocked),
|
||||
write_blocked_waker: Arc::clone(&write_blocked_waker),
|
||||
write_waker: Arc::clone(&write_waker),
|
||||
},
|
||||
ControlledWebSocketHandle {
|
||||
inbound_tx,
|
||||
write_ready,
|
||||
write_blocked,
|
||||
write_blocked_waker,
|
||||
write_waker,
|
||||
},
|
||||
outbound_rx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ControlledWebSocketHandle {
|
||||
fn send_inbound(&self, message: Message) -> anyhow::Result<()> {
|
||||
self.inbound_tx
|
||||
.unbounded_send(Ok(message))
|
||||
.map_err(anyhow::Error::from)
|
||||
}
|
||||
|
||||
fn set_write_ready(&self) {
|
||||
self.write_ready.store(true, Ordering::Release);
|
||||
self.write_waker.wake();
|
||||
}
|
||||
|
||||
async fn wait_for_blocked_write(&self) -> anyhow::Result<()> {
|
||||
timeout(
|
||||
Duration::from_secs(1),
|
||||
futures::future::poll_fn(|cx| {
|
||||
if self.write_blocked.load(Ordering::Acquire) {
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
self.write_blocked_waker.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<Message> for ControlledWebSocket {
|
||||
type Error = std::convert::Infallible;
|
||||
|
||||
fn poll_ready(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
if self.write_ready.load(Ordering::Acquire) {
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
self.write_blocked.store(true, Ordering::Release);
|
||||
self.write_blocked_waker.wake();
|
||||
self.write_waker.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||
self.get_mut()
|
||||
.message_tx
|
||||
self.outbound_tx
|
||||
.unbounded_send(item)
|
||||
.expect("test websocket receiver should stay open");
|
||||
.expect("test outbound receiver should stay open");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -665,24 +852,11 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_sends_keepalive_ping() {
|
||||
let (message_tx, mut message_rx) = futures_mpsc::unbounded::<Message>();
|
||||
let websocket_writer = TestWebSocketSink { message_tx };
|
||||
let websocket_reader = stream::pending::<Result<Message, std::convert::Infallible>>();
|
||||
let connection = JsonRpcConnection::from_websocket_parts(
|
||||
websocket_writer,
|
||||
websocket_reader,
|
||||
"test".into(),
|
||||
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
|
||||
);
|
||||
impl Stream for ControlledWebSocket {
|
||||
type Item = Result<Message, std::convert::Infallible>;
|
||||
|
||||
let message = timeout(Duration::from_secs(1), message_rx.next())
|
||||
.await
|
||||
.expect("keepalive ping should arrive before timeout")
|
||||
.expect("keepalive ping should be sent");
|
||||
assert!(matches!(message, Message::Ping(_)));
|
||||
|
||||
drop(connection);
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
Pin::new(&mut self.inbound_rx).poll_next(cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::ExecServerError;
 use crate::ExecServerRuntimePaths;
 use crate::ExecutorFileSystem;
 use crate::HttpClient;
-use crate::client::LazyRemoteExecServerClient;
+use crate::client::RemoteExecServerClient;
 use crate::client::http_client::ReqwestHttpClient;
 use crate::client_api::ExecServerTransportParams;
 use crate::environment_provider::DefaultEnvironmentProvider;
@@ -403,7 +403,7 @@ impl Environment {
             } => Some(exec_server_url.clone()),
             ExecServerTransportParams::StdioCommand { .. } => None,
         };
-        let client = LazyRemoteExecServerClient::new(remote_transport.clone());
+        let client = RemoteExecServerClient::new(remote_transport.clone());
         let exec_backend: Arc<dyn ExecBackend> = Arc::new(RemoteProcess::new(client.clone()));
         let filesystem: Arc<dyn ExecutorFileSystem> =
             Arc::new(RemoteFileSystem::new(client.clone()));
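
The hunk above clones one `RemoteExecServerClient` into `RemoteProcess` and `RemoteFileSystem`, which is how the adapters end up sharing a single reconnecting session. A minimal std-only sketch of that ownership shape, under the assumption that the shared state sits behind an `Arc<Mutex<..>>`; `SharedClient`, `SessionState`, and the two adapter structs are hypothetical names, and the real client is async:

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical shared session state: one per environment.
struct SessionState {
    connected: bool,
    /// Counts how many times a connection has been (re)established.
    generation: u64,
}

/// Hypothetical cloneable handle; every clone points at the same session.
#[derive(Clone)]
pub struct SharedClient {
    session: Arc<Mutex<SessionState>>,
}

impl SharedClient {
    pub fn new() -> Self {
        Self {
            session: Arc::new(Mutex::new(SessionState {
                connected: false,
                generation: 0,
            })),
        }
    }

    /// Simulates losing the transport.
    pub fn mark_disconnected(&self) {
        self.session.lock().unwrap().connected = false;
    }

    /// Connects only if needed, so concurrent adapters share one reconnect.
    /// Returns the connection generation now in use.
    pub fn ensure_connected(&self) -> u64 {
        let mut session = self.session.lock().unwrap();
        if !session.connected {
            session.connected = true;
            session.generation += 1;
        }
        session.generation
    }
}

/// Thin adapters: they hold a clone and own no reconnect state themselves.
pub struct ProcessAdapter {
    pub client: SharedClient,
}

pub struct FileSystemAdapter {
    pub client: SharedClient,
}
```

Because the adapters only hold clones, a reconnect triggered through one of them is visible to the other, and neither needs its own reconnect loop.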
@@ -23,8 +23,9 @@ mod runtime_paths;
 mod sandboxed_file_system;
 mod server;

-pub use client::ExecServerClient;
+pub use client::ExecServerConnection;
 pub use client::ExecServerError;
+pub type ExecServerClient = ExecServerConnection;
 pub use client::http_client::HttpResponseBodyStream;
 pub use client::http_client::ReqwestHttpClient;
 pub use client_api::ExecServerClientConnectOptions;
@@ -23,13 +23,16 @@ pub struct StartedExecProcess {
 /// The stream is scoped to one [`ExecProcess`] handle. `Output` events carry
 /// stdout, stderr, or pty bytes. `Exited` reports the process exit status, while
 /// `Closed` means all output streams have ended and no more output events will
-/// arrive. `Failed` is used when the process session cannot continue, for
-/// example because the remote executor connection disconnected.
+/// arrive. `ResyncRequired` means a reconnecting remote process session should
+/// recover through [`ExecProcess::read`] using its last delivered sequence.
+/// `Failed` is used when the process session cannot continue, for example
+/// because a direct one-shot executor connection disconnected.
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum ExecProcessEvent {
     Output(ProcessOutputChunk),
     Exited { seq: u64, exit_code: i32 },
     Closed { seq: u64 },
+    ResyncRequired,
     Failed(String),
 }

@@ -60,13 +63,14 @@ struct ExecProcessEventHistory {
 impl ExecProcessEvent {
     /// Sequence number used to order process-owned events.
     ///
-    /// `Failed` is intentionally unsequenced because it is synthesized by the
-    /// client when the session or transport fails, not emitted by the process.
+    /// `ResyncRequired` and `Failed` are intentionally unsequenced because they
+    /// are synthesized by the client when the session or transport changes,
+    /// not emitted by the process.
     pub(crate) fn seq(&self) -> Option<u64> {
         match self {
             ExecProcessEvent::Output(chunk) => Some(chunk.seq),
             ExecProcessEvent::Exited { seq, .. } | ExecProcessEvent::Closed { seq } => Some(*seq),
-            ExecProcessEvent::Failed(_) => None,
+            ExecProcessEvent::ResyncRequired | ExecProcessEvent::Failed(_) => None,
         }
     }

@@ -74,7 +78,9 @@ impl ExecProcessEvent {
         match self {
             ExecProcessEvent::Output(chunk) => chunk.chunk.0.len(),
             ExecProcessEvent::Failed(message) => message.len(),
-            ExecProcessEvent::Exited { .. } | ExecProcessEvent::Closed { .. } => 0,
+            ExecProcessEvent::Exited { .. }
+            | ExecProcessEvent::Closed { .. }
+            | ExecProcessEvent::ResyncRequired => 0,
         }
     }
 }
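
The doc comments in the hunks above say a consumer should react to `ResyncRequired` by reading again from its last delivered sequence. One way a consumer might track that cursor is sketched below. It uses a simplified mirror of the enum (`Event`) rather than the crate's `ExecProcessEvent`; `Cursor` and `Action` are hypothetical, and the sketch assumes sequence numbers start at 1 so that `0` can mean "nothing delivered yet".

```rust
/// Simplified mirror of the event enum in the diff (payload types reduced).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    Output { seq: u64, bytes: Vec<u8> },
    Exited { seq: u64, exit_code: i32 },
    Closed { seq: u64 },
    ResyncRequired,
    Failed(String),
}

impl Event {
    /// Sequenced events come from the process; the others are client-synthesized.
    pub fn seq(&self) -> Option<u64> {
        match self {
            Event::Output { seq, .. } | Event::Exited { seq, .. } | Event::Closed { seq } => {
                Some(*seq)
            }
            Event::ResyncRequired | Event::Failed(_) => None,
        }
    }
}

/// Hypothetical decision a consumer makes for each incoming event.
#[derive(Debug, PartialEq, Eq)]
pub enum Action {
    Deliver,
    SkipDuplicate,
    /// Re-read from the server, asking only for events after this sequence.
    ReadAfter(u64),
    Stop,
}

/// Hypothetical cursor holding the last delivered sequence (0 = none yet).
#[derive(Default)]
pub struct Cursor {
    last_seq: u64,
}

impl Cursor {
    pub fn on_event(&mut self, event: &Event) -> Action {
        match event.seq() {
            // A replay after reconnect may resend events already delivered.
            Some(seq) if seq <= self.last_seq => Action::SkipDuplicate,
            Some(seq) => {
                self.last_seq = seq;
                Action::Deliver
            }
            None => match event {
                Event::ResyncRequired => Action::ReadAfter(self.last_seq),
                _ => Action::Stop,
            },
        }
    }
}
```

Keeping the unsequenced variants out of the cursor update matches the comment on `seq()`: `ResyncRequired` and `Failed` never advance the position, so a resync always resumes from the last event the process itself emitted.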
@@ -1,7 +1,9 @@
 use std::collections::HashMap;

 use codex_app_server_protocol::JSONRPCMessage;
+use futures::Sink;
 use futures::SinkExt;
+use futures::Stream;
 use futures::StreamExt;
 use prost::Message as ProstMessage;
 use tokio::io::AsyncRead;
@@ -140,121 +142,25 @@ fn jsonrpc_payload(message: &JSONRPCMessage) -> Result<Vec<u8>, ExecServerError>
     serde_json::to_vec(message).map_err(ExecServerError::Json)
 }

-pub(crate) fn harness_connection_from_websocket<S>(
-    stream: WebSocketStream<S>,
+pub(crate) fn harness_connection_from_websocket<T, E>(
+    stream: T,
     connection_label: String,
 ) -> JsonRpcConnection
 where
-    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
+    T: Sink<Message, Error = E> + Stream<Item = Result<Message, E>> + Unpin + Send + 'static,
+    E: std::fmt::Display + Send + 'static,
 {
     let stream_id = Uuid::new_v4().to_string();
-    let (mut websocket_writer, mut websocket_reader) = stream.split();
     let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
     let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
     let (disconnected_tx, disconnected_rx) = watch::channel(false);

-    let reader_label = connection_label;
-    let reader_stream_id = stream_id.clone();
-    let incoming_tx_for_reader = incoming_tx;
-    let disconnected_tx_for_reader = disconnected_tx.clone();
-    let reader_task = tokio::spawn(async move {
-        loop {
-            match websocket_reader.next().await {
-                Some(Ok(Message::Binary(payload))) => {
-                    let frame = match decode_relay_message_frame(payload.as_ref()) {
-                        Ok(frame) => frame,
-                        Err(err) => {
-                            let _ = incoming_tx_for_reader
-                                .send(JsonRpcConnectionEvent::MalformedMessage {
-                                    reason: format!(
-                                        "failed to parse relay message frame from {reader_label}: {err}"
-                                    ),
-                                })
-                                .await;
-                            continue;
-                        }
-                    };
-                    if frame.stream_id != reader_stream_id {
-                        continue;
-                    }
-                    let kind = match frame.validate() {
-                        Ok(kind) => kind,
-                        Err(err) => {
-                            let _ = incoming_tx_for_reader
-                                .send(JsonRpcConnectionEvent::MalformedMessage {
-                                    reason: err.to_string(),
-                                })
-                                .await;
-                            continue;
-                        }
-                    };
-                    match kind {
-                        RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
-                            Ok(message) => {
-                                if incoming_tx_for_reader
-                                    .send(JsonRpcConnectionEvent::Message(message))
-                                    .await
-                                    .is_err()
-                                {
-                                    break;
-                                }
-                            }
-                            Err(err) => {
-                                let _ = incoming_tx_for_reader
-                                    .send(JsonRpcConnectionEvent::MalformedMessage {
-                                        reason: err.to_string(),
-                                    })
-                                    .await;
-                            }
-                        },
-                        RelayFrameBodyKind::Reset => {
-                            let _ = disconnected_tx_for_reader.send(true);
-                            let _ = incoming_tx_for_reader
-                                .send(JsonRpcConnectionEvent::Disconnected {
-                                    reason: frame.into_reset_reason(),
-                                })
-                                .await;
-                            break;
-                        }
-                        RelayFrameBodyKind::Ack
-                        | RelayFrameBodyKind::Resume
-                        | RelayFrameBodyKind::Heartbeat => {}
-                    }
-                }
-                Some(Ok(Message::Close(_))) | None => {
-                    let _ = disconnected_tx_for_reader.send(true);
-                    let _ = incoming_tx_for_reader
-                        .send(JsonRpcConnectionEvent::Disconnected { reason: None })
-                        .await;
-                    break;
-                }
-                Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {}
-                Some(Ok(Message::Text(_))) => {
-                    let _ = incoming_tx_for_reader
-                        .send(JsonRpcConnectionEvent::MalformedMessage {
-                            reason: "relay exec-server transport expects binary protobuf frames"
-                                .to_string(),
-                        })
-                        .await;
-                }
-                Some(Err(err)) => {
-                    let _ = disconnected_tx_for_reader.send(true);
-                    let _ = incoming_tx_for_reader
-                        .send(JsonRpcConnectionEvent::Disconnected {
-                            reason: Some(format!(
-                                "failed to read relay websocket frame from {reader_label}: {err}"
-                            )),
-                        })
-                        .await;
-                    break;
-                }
-            }
-        }
-    });
-
-    let writer_task = tokio::spawn(async move {
+    let websocket_task = tokio::spawn(async move {
+        let mut websocket = stream;
+        let reader_label = connection_label;
+        let reader_stream_id = stream_id.clone();
         let resume = RelayMessageFrame::resume(stream_id.clone());
-        if websocket_writer
+        if websocket
             .send(Message::Binary(encode_relay_message_frame(&resume).into()))
             .await
             .is_err()
@@ -263,12 +169,12 @@ where
             return;
         }

-        let mut next_seq = 0u32;
         let mut keepalive = tokio::time::interval_at(
             tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
             WEBSOCKET_KEEPALIVE_INTERVAL,
         );
         keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
+        let mut next_seq = 0u32;
         loop {
             tokio::select! {
                 maybe_message = outgoing_rx.recv() => {
@@ -284,7 +190,7 @@ where
                     };
                     let frame = RelayMessageFrame::data(stream_id.clone(), next_seq, payload);
                     next_seq = next_seq.wrapping_add(1);
-                    if websocket_writer
+                    if websocket
                         .send(Message::Binary(encode_relay_message_frame(&frame).into()))
                         .await
                         .is_err()
@@ -294,11 +200,103 @@ where
                     }
                 }
                 _ = keepalive.tick() => {
-                    if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
+                    if websocket.send(Message::Ping(Vec::new().into())).await.is_err() {
                         let _ = disconnected_tx.send(true);
                         break;
                     }
                 }
+                incoming_message = websocket.next() => {
+                    match incoming_message {
+                        Some(Ok(Message::Binary(payload))) => {
+                            let frame = match decode_relay_message_frame(payload.as_ref()) {
+                                Ok(frame) => frame,
+                                Err(err) => {
+                                    let _ = incoming_tx
+                                        .send(JsonRpcConnectionEvent::MalformedMessage {
+                                            reason: format!(
+                                                "failed to parse relay message frame from {reader_label}: {err}"
+                                            ),
+                                        })
+                                        .await;
+                                    continue;
+                                }
+                            };
+                            if frame.stream_id != reader_stream_id {
+                                continue;
+                            }
+                            let kind = match frame.validate() {
+                                Ok(kind) => kind,
+                                Err(err) => {
+                                    let _ = incoming_tx
+                                        .send(JsonRpcConnectionEvent::MalformedMessage {
+                                            reason: err.to_string(),
+                                        })
+                                        .await;
+                                    continue;
+                                }
+                            };
+                            match kind {
+                                RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
+                                    Ok(message) => {
+                                        if incoming_tx
+                                            .send(JsonRpcConnectionEvent::Message(message))
+                                            .await
+                                            .is_err()
+                                        {
+                                            break;
+                                        }
+                                    }
+                                    Err(err) => {
+                                        let _ = incoming_tx
+                                            .send(JsonRpcConnectionEvent::MalformedMessage {
+                                                reason: err.to_string(),
+                                            })
+                                            .await;
+                                    }
+                                },
+                                RelayFrameBodyKind::Reset => {
+                                    let _ = disconnected_tx.send(true);
+                                    let _ = incoming_tx
+                                        .send(JsonRpcConnectionEvent::Disconnected {
+                                            reason: frame.into_reset_reason(),
+                                        })
+                                        .await;
+                                    break;
+                                }
+                                RelayFrameBodyKind::Ack
+                                | RelayFrameBodyKind::Resume
+                                | RelayFrameBodyKind::Heartbeat => {}
+                            }
+                        }
+                        Some(Ok(Message::Close(_))) | None => {
+                            let _ = disconnected_tx.send(true);
+                            let _ = incoming_tx
+                                .send(JsonRpcConnectionEvent::Disconnected { reason: None })
+                                .await;
+                            break;
+                        }
+                        Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {}
+                        Some(Ok(Message::Text(_))) => {
+                            let _ = incoming_tx
+                                .send(JsonRpcConnectionEvent::MalformedMessage {
+                                    reason: "relay exec-server transport expects binary protobuf frames"
+                                        .to_string(),
+                                })
+                                .await;
+                        }
+                        Some(Err(err)) => {
+                            let _ = disconnected_tx.send(true);
+                            let _ = incoming_tx
+                                .send(JsonRpcConnectionEvent::Disconnected {
+                                    reason: Some(format!(
+                                        "failed to read relay websocket frame from {reader_label}: {err}"
+                                    )),
+                                })
+                                .await;
+                            break;
+                        }
+                    }
+                }
             }
         }
     });
@@ -307,7 +305,7 @@ where
         outgoing_tx,
         incoming_rx,
         disconnected_rx,
-        task_handles: vec![reader_task, writer_task],
+        task_handles: vec![websocket_task],
         transport: JsonRpcTransport::Plain,
     }
 }
@@ -318,59 +316,53 @@ pub(crate) async fn run_multiplexed_executor<S>(
 ) where
     S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
 {
-    let (mut websocket_writer, mut websocket_reader) = stream.split();
+    let mut websocket = stream;
     let (physical_outgoing_tx, mut physical_outgoing_rx) =
         mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
-    let writer_task = tokio::spawn(async move {
-        let mut keepalive = tokio::time::interval_at(
-            tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
-            WEBSOCKET_KEEPALIVE_INTERVAL,
-        );
-        keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
-        loop {
-            tokio::select! {
-                maybe_encoded = physical_outgoing_rx.recv() => {
-                    let Some(encoded) = maybe_encoded else {
-                        break;
-                    };
-                    if websocket_writer
-                        .send(Message::Binary(encoded.into()))
-                        .await
-                        .is_err()
-                    {
-                        break;
-                    }
-                }
-                _ = keepalive.tick() => {
-                    if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
-                        break;
-                    }
-                }
-            }
-        }
-    });

     let mut streams: HashMap<String, VirtualStream> = HashMap::new();
+    let mut keepalive = tokio::time::interval_at(
+        tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
+        WEBSOCKET_KEEPALIVE_INTERVAL,
+    );
+    keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
     loop {
-        let frame = match websocket_reader.next().await {
-            Some(Ok(Message::Binary(payload))) => {
-                match decode_relay_message_frame(payload.as_ref()) {
-                    Ok(frame) => frame,
-                    Err(err) => {
-                        warn!("dropping malformed relay message frame from harness: {err}");
-                        continue;
-                    }
+        let frame = tokio::select! {
+            maybe_encoded = physical_outgoing_rx.recv() => {
+                let Some(encoded) = maybe_encoded else {
+                    break;
+                };
+                if websocket.send(Message::Binary(encoded.into())).await.is_err() {
+                    break;
                 }
-            }
-            Some(Ok(Message::Close(_))) | None => break,
-            Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
-            Some(Ok(Message::Text(_))) => {
-                warn!("dropping non-binary relay message frame from harness");
                 continue;
             }
-            Some(Err(err)) => {
-                debug!("multiplexed executor websocket read failed: {err}");
-                break;
+            _ = keepalive.tick() => {
+                if websocket.send(Message::Ping(Vec::new().into())).await.is_err() {
+                    break;
+                }
+                continue;
+            }
+            incoming_message = websocket.next() => match incoming_message {
+                Some(Ok(Message::Binary(payload))) => {
+                    match decode_relay_message_frame(payload.as_ref()) {
+                        Ok(frame) => frame,
+                        Err(err) => {
+                            warn!("dropping malformed relay message frame from harness: {err}");
+                            continue;
+                        }
+                    }
+                }
+                Some(Ok(Message::Close(_))) | None => break,
+                Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
+                Some(Ok(Message::Text(_))) => {
+                    warn!("dropping non-binary relay message frame from harness");
+                    continue;
+                }
+                Some(Err(err)) => {
+                    debug!("multiplexed executor websocket read failed: {err}");
+                    break;
+                }
             }
         };

@@ -423,7 +415,6 @@ pub(crate) async fn run_multiplexed_executor<S>(
         stream.disconnect(/*reason*/ None).await;
     }
     drop(physical_outgoing_tx);
-    let _ = writer_task.await;
 }

 struct VirtualStream {
@@ -492,8 +483,20 @@ fn spawn_virtual_stream(

 #[cfg(test)]
 mod tests {
+    use std::pin::Pin;
+    use std::sync::Arc;
+    use std::sync::atomic::AtomicBool;
+    use std::sync::atomic::Ordering;
+    use std::task::Context;
+    use std::task::Poll;
     use std::time::Duration;

+    use codex_app_server_protocol::JSONRPCRequest;
+    use codex_app_server_protocol::RequestId;
+    use futures::Sink;
+    use futures::Stream;
+    use futures::channel::mpsc as futures_mpsc;
+    use futures::task::AtomicWaker;
     use tokio::net::TcpListener;
     use tokio::time::timeout;
     use tokio_tungstenite::accept_async;
@@ -502,40 +505,107 @@ mod tests {

     use super::*;

-    fn test_runtime_paths() -> anyhow::Result<crate::ExecServerRuntimePaths> {
-        crate::ExecServerRuntimePaths::new(
-            std::env::current_exe()?,
-            /*codex_linux_sandbox_exe*/ None,
-        )
-        .map_err(anyhow::Error::from)
-    }
-
     #[tokio::test]
-    async fn multiplexed_executor_sends_keepalive_ping() -> anyhow::Result<()> {
+    async fn harness_connection_receives_relay_data() -> anyhow::Result<()> {
         let (client_websocket, mut server_websocket) = websocket_pair().await?;
-        let executor_task = tokio::spawn(run_multiplexed_executor(
-            client_websocket,
-            ConnectionProcessor::new(test_runtime_paths()?),
+        let mut connection =
+            harness_connection_from_websocket(client_websocket, "test".to_string());
+        let stream_id = read_resume_stream_id(&mut server_websocket).await?;
+        let message = test_jsonrpc_message();
+
+        server_websocket
+            .send(Message::Binary(
+                encode_relay_message_frame(&RelayMessageFrame::data(
+                    stream_id,
+                    /*seq*/ 0,
+                    jsonrpc_payload(&message)?,
+                ))
+                .into(),
+            ))
+            .await?;
+        assert!(matches!(
+            timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
+            Some(JsonRpcConnectionEvent::Message(actual)) if actual == message
         ));

-        read_keepalive_ping(&mut server_websocket).await?;
-
-        executor_task.abort();
-        let _ = executor_task.await;
+        drop(connection);
         Ok(())
     }

     #[tokio::test]
-    async fn harness_connection_sends_keepalive_ping() -> anyhow::Result<()> {
+    async fn harness_connection_reports_text_frames_as_malformed() -> anyhow::Result<()> {
         let (client_websocket, mut server_websocket) = websocket_pair().await?;
-        let connection = harness_connection_from_websocket(client_websocket, "test".to_string());
+        let mut connection =
+            harness_connection_from_websocket(client_websocket, "test".to_string());

-        read_keepalive_ping(&mut server_websocket).await?;
+        read_resume_stream_id(&mut server_websocket).await?;
+        server_websocket.send(Message::Text("nope".into())).await?;
+        assert!(matches!(
+            timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
+            Some(JsonRpcConnectionEvent::MalformedMessage { reason })
+                if reason == "relay exec-server transport expects binary protobuf frames"
+        ));

         drop(connection);
         Ok(())
     }

+    #[tokio::test]
+    async fn harness_connection_reports_server_close() -> anyhow::Result<()> {
+        let (client_websocket, mut server_websocket) = websocket_pair().await?;
+        let mut connection =
+            harness_connection_from_websocket(client_websocket, "test".to_string());
+
+        read_resume_stream_id(&mut server_websocket).await?;
+        server_websocket.close(None).await?;
+        assert!(matches!(
+            timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
+            Some(JsonRpcConnectionEvent::Disconnected { reason: None })
+        ));
+
+        drop(connection);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn harness_connection_keeps_outbound_frame_while_send_is_backpressured()
+    -> anyhow::Result<()> {
+        let (websocket, control, mut outbound_rx) =
+            ControlledWebSocket::new(/*write_ready*/ true);
+        let mut connection = harness_connection_from_websocket(websocket, "test".to_string());
+        let Message::Binary(resume_payload) = timeout(Duration::from_secs(1), outbound_rx.next())
+            .await?
+            .expect("resume frame")
+        else {
+            anyhow::bail!("expected relay resume frame");
+        };
+        let stream_id = decode_relay_message_frame(resume_payload.as_ref())?.stream_id;
+        let message = test_jsonrpc_message();
+
+        control.set_write_blocked();
+        connection.outgoing_tx.send(message.clone()).await?;
+        control.wait_for_blocked_write().await?;
+        control.send_inbound(Message::Pong(b"check".to_vec().into()))?;
+        assert!(
+            timeout(Duration::from_millis(50), connection.incoming_rx.recv())
+                .await
+                .is_err()
+        );
+
+        control.set_write_ready();
+        let Message::Binary(data_payload) = timeout(Duration::from_secs(1), outbound_rx.next())
+            .await?
+            .expect("data frame")
+        else {
+            anyhow::bail!("expected relay data frame");
+        };
+        let frame = decode_relay_message_frame(data_payload.as_ref())?;
+        assert_eq!(frame.stream_id, stream_id);
+        assert_eq!(frame.into_jsonrpc_message()?, message);
+        drop(connection);
+        Ok(())
+    }
+
     async fn websocket_pair() -> anyhow::Result<(
         WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
         WebSocketStream<tokio::net::TcpStream>,
@@ -551,18 +621,155 @@ mod tests {
         Ok((client_websocket, server_websocket))
     }

-    async fn read_keepalive_ping(
+    async fn read_resume_stream_id(
         websocket: &mut WebSocketStream<tokio::net::TcpStream>,
-    ) -> anyhow::Result<()> {
-        loop {
-            let Some(message) = timeout(Duration::from_secs(1), websocket.next()).await? else {
-                anyhow::bail!("websocket closed before keepalive ping");
-            };
-            match message? {
-                Message::Ping(_) => return Ok(()),
-                Message::Binary(_) | Message::Text(_) | Message::Pong(_) | Message::Frame(_) => {}
-                Message::Close(_) => anyhow::bail!("websocket closed before keepalive ping"),
+    ) -> anyhow::Result<String> {
+        let message = timeout(Duration::from_secs(1), websocket.next())
+            .await?
+            .expect("websocket should stay open")?;
+        let Message::Binary(payload) = message else {
+            anyhow::bail!("expected relay resume frame, got {message:?}");
+        };
+        let frame = decode_relay_message_frame(payload.as_ref())?;
+        assert_eq!(frame.validate()?, RelayFrameBodyKind::Resume);
+        Ok(frame.stream_id)
+    }
+
+    fn test_jsonrpc_message() -> JSONRPCMessage {
+        JSONRPCMessage::Request(JSONRPCRequest {
+            id: RequestId::Integer(1),
+            method: "test".to_string(),
+            params: None,
+            trace: None,
+        })
+    }
+
+    struct ControlledWebSocket {
+        inbound_rx: futures_mpsc::UnboundedReceiver<Result<Message, std::convert::Infallible>>,
+        outbound_tx: futures_mpsc::UnboundedSender<Message>,
+        write_ready: Arc<AtomicBool>,
+        write_blocked: Arc<AtomicBool>,
+        write_blocked_waker: Arc<AtomicWaker>,
+        write_waker: Arc<AtomicWaker>,
+    }
+
+    struct ControlledWebSocketHandle {
+        inbound_tx: futures_mpsc::UnboundedSender<Result<Message, std::convert::Infallible>>,
+        write_ready: Arc<AtomicBool>,
+        write_blocked: Arc<AtomicBool>,
+        write_blocked_waker: Arc<AtomicWaker>,
+        write_waker: Arc<AtomicWaker>,
+    }
+
+    impl ControlledWebSocket {
+        fn new(
+            write_ready: bool,
+        ) -> (
+            Self,
+            ControlledWebSocketHandle,
+            futures_mpsc::UnboundedReceiver<Message>,
+        ) {
+            let (inbound_tx, inbound_rx) = futures_mpsc::unbounded();
+            let (outbound_tx, outbound_rx) = futures_mpsc::unbounded();
+            let write_ready = Arc::new(AtomicBool::new(write_ready));
+            let write_blocked = Arc::new(AtomicBool::new(false));
+            let write_blocked_waker = Arc::new(AtomicWaker::new());
+            let write_waker = Arc::new(AtomicWaker::new());
+            (
+                Self {
+                    inbound_rx,
+                    outbound_tx,
+                    write_ready: Arc::clone(&write_ready),
+                    write_blocked: Arc::clone(&write_blocked),
+                    write_blocked_waker: Arc::clone(&write_blocked_waker),
+                    write_waker: Arc::clone(&write_waker),
+                },
+                ControlledWebSocketHandle {
+                    inbound_tx,
+                    write_ready,
+                    write_blocked,
+                    write_blocked_waker,
+                    write_waker,
+                },
+                outbound_rx,
+            )
+        }
+    }
+
+    impl ControlledWebSocketHandle {
+        fn send_inbound(&self, message: Message) -> anyhow::Result<()> {
+            self.inbound_tx
+                .unbounded_send(Ok(message))
+                .map_err(anyhow::Error::from)
+        }
+
+        fn set_write_blocked(&self) {
+            self.write_ready.store(false, Ordering::Release);
+        }
+
+        fn set_write_ready(&self) {
+            self.write_ready.store(true, Ordering::Release);
+            self.write_waker.wake();
+        }
+
+        async fn wait_for_blocked_write(&self) -> anyhow::Result<()> {
+            timeout(
+                Duration::from_secs(1),
+                futures::future::poll_fn(|cx| {
+                    if self.write_blocked.load(Ordering::Acquire) {
+                        Poll::Ready(())
+                    } else {
+                        self.write_blocked_waker.register(cx.waker());
+                        Poll::Pending
+                    }
+                }),
+            )
+            .await?;
+            Ok(())
+        }
+    }
+
+    impl Sink<Message> for ControlledWebSocket {
+        type Error = std::convert::Infallible;
+
+        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+            if self.write_ready.load(Ordering::Acquire) {
+                Poll::Ready(Ok(()))
+            } else {
+                self.write_blocked.store(true, Ordering::Release);
+                self.write_blocked_waker.wake();
+                self.write_waker.register(cx.waker());
+                Poll::Pending
             }
+        }
+
+        fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
+            self.outbound_tx
+                .unbounded_send(item)
+                .expect("test outbound receiver should stay open");
+            Ok(())
+        }
+
+        fn poll_flush(
+            self: Pin<&mut Self>,
+            _cx: &mut Context<'_>,
+        ) -> Poll<Result<(), Self::Error>> {
+            Poll::Ready(Ok(()))
+        }
+
+        fn poll_close(
+            self: Pin<&mut Self>,
+            _cx: &mut Context<'_>,
+        ) -> Poll<Result<(), Self::Error>> {
+            Poll::Ready(Ok(()))
+        }
+    }
+
+    impl Stream for ControlledWebSocket {
+        type Item = Result<Message, std::convert::Infallible>;
+
+        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+            Pin::new(&mut self.inbound_rx).poll_next(cx)
         }
     }
 }
@@ -14,7 +14,7 @@ use crate::FileSystemResult;
 use crate::FileSystemSandboxContext;
 use crate::ReadDirectoryEntry;
 use crate::RemoveOptions;
-use crate::client::LazyRemoteExecServerClient;
+use crate::client::RemoteExecServerClient;
 use crate::protocol::FsCopyParams;
 use crate::protocol::FsCreateDirectoryParams;
 use crate::protocol::FsGetMetadataParams;
@@ -28,11 +28,11 @@ const NOT_FOUND_ERROR_CODE: i64 = -32004;

 #[derive(Clone)]
 pub(crate) struct RemoteFileSystem {
-    client: LazyRemoteExecServerClient,
+    client: RemoteExecServerClient,
 }

 impl RemoteFileSystem {
-    pub(crate) fn new(client: LazyRemoteExecServerClient) -> Self {
+    pub(crate) fn new(client: RemoteExecServerClient) -> Self {
         trace!("remote fs new");
         Self { client }
     }
@@ -46,8 +46,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
         sandbox: Option<&FileSystemSandboxContext>,
     ) -> FileSystemResult<Vec<u8>> {
         trace!("remote fs read_file");
-        let client = self.client.get().await.map_err(map_remote_error)?;
-        let response = client
+        let connection = self.client.connection().await.map_err(map_remote_error)?;
+        let response = connection
             .fs_read_file(FsReadFileParams {
                 path: path.clone(),
                 sandbox: remote_sandbox_context(sandbox),
@@ -69,8 +69,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
         sandbox: Option<&FileSystemSandboxContext>,
     ) -> FileSystemResult<()> {
         trace!("remote fs write_file");
-        let client = self.client.get().await.map_err(map_remote_error)?;
-        client
+        let connection = self.client.connection().await.map_err(map_remote_error)?;
+        connection
             .fs_write_file(FsWriteFileParams {
                 path: path.clone(),
                 data_base64: STANDARD.encode(contents),
@@ -88,8 +88,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
         sandbox: Option<&FileSystemSandboxContext>,
     ) -> FileSystemResult<()> {
         trace!("remote fs create_directory");
-        let client = self.client.get().await.map_err(map_remote_error)?;
-        client
+        let connection = self.client.connection().await.map_err(map_remote_error)?;
+        connection
             .fs_create_directory(FsCreateDirectoryParams {
                 path: path.clone(),
                 recursive: Some(options.recursive),
@@ -106,8 +106,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
         sandbox: Option<&FileSystemSandboxContext>,
     ) -> FileSystemResult<FileMetadata> {
         trace!("remote fs get_metadata");
-        let client = self.client.get().await.map_err(map_remote_error)?;
-        let response = client
+        let connection = self.client.connection().await.map_err(map_remote_error)?;
+        let response = connection
             .fs_get_metadata(FsGetMetadataParams {
                 path: path.clone(),
                 sandbox: remote_sandbox_context(sandbox),
@@ -129,8 +129,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
         sandbox: Option<&FileSystemSandboxContext>,
     ) -> FileSystemResult<Vec<ReadDirectoryEntry>> {
         trace!("remote fs read_directory");
-        let client = self.client.get().await.map_err(map_remote_error)?;
-        let response = client
+        let connection = self.client.connection().await.map_err(map_remote_error)?;
+        let response = connection
             .fs_read_directory(FsReadDirectoryParams {
                 path: path.clone(),
                 sandbox: remote_sandbox_context(sandbox),
@@ -155,8 +155,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
         sandbox: Option<&FileSystemSandboxContext>,
     ) -> FileSystemResult<()> {
         trace!("remote fs remove");
-        let client = self.client.get().await.map_err(map_remote_error)?;
-        client
+        let connection = self.client.connection().await.map_err(map_remote_error)?;
+        connection
             .fs_remove(FsRemoveParams {
                 path: path.clone(),
                 recursive: Some(options.recursive),
@@ -176,8 +176,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
         sandbox: Option<&FileSystemSandboxContext>,
     ) -> FileSystemResult<()> {
         trace!("remote fs copy");
-        let client = self.client.get().await.map_err(map_remote_error)?;
-        client
+        let connection = self.client.connection().await.map_err(map_remote_error)?;
+        connection
             .fs_copy(FsCopyParams {
                 source_path: source_path.clone(),
                 destination_path: destination_path.clone(),
@@ -1,31 +1,33 @@
 use std::sync::Arc;

 use async_trait::async_trait;
+use tokio::runtime::Handle;
 use tokio::sync::watch;
 use tracing::trace;
+use tracing::warn;

 use crate::ExecBackend;
 use crate::ExecProcess;
 use crate::ExecProcessEventReceiver;
 use crate::ExecServerError;
 use crate::StartedExecProcess;
-use crate::client::LazyRemoteExecServerClient;
-use crate::client::Session;
+use crate::client::ProcessSessionHandle;
+use crate::client::RemoteExecServerClient;
 use crate::protocol::ExecParams;
 use crate::protocol::ReadResponse;
 use crate::protocol::WriteResponse;

 #[derive(Clone)]
 pub(crate) struct RemoteProcess {
-    client: LazyRemoteExecServerClient,
+    client: RemoteExecServerClient,
 }

 struct RemoteExecProcess {
-    session: Session,
+    session: ProcessSessionHandle,
 }

 impl RemoteProcess {
-    pub(crate) fn new(client: LazyRemoteExecServerClient) -> Self {
+    pub(crate) fn new(client: RemoteExecServerClient) -> Self {
         trace!("remote process new");
         Self { client }
     }
@@ -35,9 +37,8 @@ impl RemoteProcess {
 impl ExecBackend for RemoteProcess {
     async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError> {
         let process_id = params.process_id.clone();
-        let client = self.client.get().await?;
-        let session = client.register_session(&process_id).await?;
-        if let Err(err) = client.exec(params).await {
+        let (session, connection) = self.client.register_process_session(&process_id).await?;
+        if let Err(err) = connection.exec(params).await {
             session.unregister().await;
             return Err(err);
         }
@@ -85,8 +86,14 @@ impl ExecProcess for RemoteExecProcess {
 impl Drop for RemoteExecProcess {
     fn drop(&mut self) {
         let session = self.session.clone();
-        tokio::spawn(async move {
+        let Ok(handle) = Handle::try_current() else {
+            warn!(
+                "Could not schedule remote exec process unregister on drop: no Tokio runtime is available"
+            );
+            return;
+        };
+        std::mem::drop(handle.spawn(async move {
             session.unregister().await;
-        });
+        }));
     }
 }
@@ -166,6 +166,7 @@ async fn collect_process_output_from_events(
|
||||
drop(session);
|
||||
return Ok((stdout, stderr, exit_code, true));
|
||||
}
|
||||
ExecProcessEvent::ResyncRequired => continue,
|
||||
ExecProcessEvent::Failed(message) => {
|
||||
anyhow::bail!("process failed before closed state: {message}");
|
||||
}
|
||||
@@ -189,6 +190,7 @@ async fn collect_process_event_snapshots(
|
||||
ProcessEventSnapshot::Exited { seq, exit_code }
|
||||
}
|
||||
ExecProcessEvent::Closed { seq } => ProcessEventSnapshot::Closed { seq },
|
||||
ExecProcessEvent::ResyncRequired => continue,
|
||||
ExecProcessEvent::Failed(message) => {
|
||||
anyhow::bail!("process failed before closed state: {message}");
|
||||
}
|
||||
@@ -541,7 +543,7 @@ async fn assert_exec_process_preserves_queued_events_before_subscribe(
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
// Serialize tests that launch a real exec-server process through the full CLI.
|
||||
#[serial_test::serial(remote_exec_server)]
|
||||
async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
|
||||
async fn remote_exec_process_surfaces_reconnect_failure_after_server_shutdown() -> Result<()> {
|
||||
let mut context = create_process_context(/*use_remote*/ true).await?;
|
||||
let session = context
|
||||
.backend
|
||||
@@ -562,7 +564,6 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
|
||||
.await?;
|
||||
|
||||
let process = Arc::clone(&session.process);
|
||||
let mut events = process.subscribe_events();
|
||||
let process_for_pending_read = Arc::clone(&process);
|
||||
let pending_read = tokio::spawn(async move {
|
||||
process_for_pending_read
|
||||
@@ -579,36 +580,33 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
         .expect("remote context should include exec-server harness");
     server.shutdown().await?;

-    let event = timeout(Duration::from_secs(2), events.recv()).await??;
-    let ExecProcessEvent::Failed(event_message) = event else {
-        anyhow::bail!("expected process failure event, got {event:?}");
-    };
+    let pending_error = timeout(Duration::from_secs(2), pending_read)
+        .await
+        .context("timed out waiting for pending read after reconnect failure")??
+        .expect_err("pending read should fail after reconnect fails");
     assert!(
-        event_message.starts_with("exec-server transport disconnected"),
-        "unexpected failure event: {event_message}"
+        pending_error
+            .to_string()
+            .starts_with("failed to connect to exec-server websocket"),
+        "unexpected pending read error: {pending_error}"
     );

-    let pending_response = timeout(Duration::from_secs(2), pending_read).await???;
-    let pending_message = pending_response
-        .failure
-        .expect("pending read should surface disconnect as a failure");
+    let read_error = timeout(
+        Duration::from_secs(2),
+        process.read(
+            /*after_seq*/ None,
+            /*max_bytes*/ None,
+            /*wait_ms*/ Some(0),
+        ),
+    )
+    .await
+    .context("timed out waiting for read after reconnect failure")?
+    .expect_err("read after reconnect failure should fail");
     assert!(
-        pending_message.starts_with("exec-server transport disconnected"),
-        "unexpected pending failure message: {pending_message}"
-    );
-
-    let mut wake_rx = process.subscribe_wake();
-    let response = read_process_until_change(process, &mut wake_rx, /*after_seq*/ None).await?;
-    let message = response
-        .failure
-        .expect("disconnect should surface as a failure");
-    assert!(
-        message.starts_with("exec-server transport disconnected"),
-        "unexpected failure message: {message}"
-    );
-    assert!(
-        response.closed,
-        "disconnect should close the process session"
+        read_error
+            .to_string()
+            .starts_with("failed to connect to exec-server websocket"),
+        "unexpected read error: {read_error}"
     );

     let write_result = timeout(
@@ -621,7 +619,7 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
     assert!(
         write_error
             .to_string()
-            .starts_with("exec-server transport disconnected"),
+            .starts_with("failed to connect to exec-server websocket"),
         "unexpected write error: {write_error}"
     );

@@ -193,6 +193,15 @@ impl ExecutorProcessTransport {
                 self.note_seq(seq);
                 self.closed = true;
             }
+            Ok(ExecProcessEvent::ResyncRequired) => {
+                if let Err(error) = self.recover_process_events().await {
+                    warn!(
+                        "Failed to resync remote MCP server output stream ({}): {error}",
+                        self.program_name
+                    );
+                    self.closed = true;
+                }
+            }
             Ok(ExecProcessEvent::Failed(message)) => {
                 warn!(
                     "Remote MCP server process failed ({}): {message}",
@@ -205,7 +214,7 @@ impl ExecutorProcessTransport {
                     "Remote MCP server output stream lagged ({}): skipped {skipped} events",
                     self.program_name
                 );
-                if let Err(error) = self.recover_lagged_events().await {
+                if let Err(error) = self.recover_process_events().await {
                     warn!(
                         "Failed to recover remote MCP server output stream ({}): {error}",
                         self.program_name
@@ -232,7 +241,7 @@ impl ExecutorProcessTransport {
         true
     }

-    async fn recover_lagged_events(&mut self) -> io::Result<()> {
+    async fn recover_process_events(&mut self) -> io::Result<()> {
         let response = self
             .process
             .read(