Add exec-server websocket reconnect foundation

This commit is contained in:
starr-openai
2026-05-18 18:25:18 -07:00
parent 90804bb2eb
commit ef8267fb69
9 changed files with 527 additions and 194 deletions

View File

@@ -12,7 +12,7 @@ use futures::FutureExt;
use futures::future::BoxFuture;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::OnceCell;
use tokio::sync::Notify;
use tokio::sync::mpsc;
use tokio::sync::watch;
@@ -25,6 +25,7 @@ use crate::client_api::ExecServerTransportParams;
use crate::client_api::HttpClient;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::client_api::StdioExecServerConnectArgs;
use crate::client_transport::ENVIRONMENT_CLIENT_NAME;
use crate::connection::JsonRpcConnection;
use crate::process::ExecProcessEvent;
use crate::process::ExecProcessEventLog;
@@ -128,11 +129,12 @@ impl RemoteExecServerConnectArgs {
}
}
pub(crate) struct SessionState {
pub(crate) struct ProcessSession {
wake_tx: watch::Sender<u64>,
events: ExecProcessEventLog,
ordered_events: StdMutex<OrderedSessionEvents>,
failure: Mutex<Option<String>>,
disconnect_behavior: ProcessSessionDisconnectBehavior,
}
#[derive(Default)]
@@ -145,30 +147,46 @@ struct OrderedSessionEvents {
}
#[derive(Clone)]
pub(crate) struct Session {
client: ExecServerClient,
pub(crate) struct ProcessSessionHandle {
control: ProcessSessionControl,
process_id: ProcessId,
state: Arc<SessionState>,
session: Arc<ProcessSession>,
}
#[derive(Clone)]
enum ProcessSessionControl {
#[cfg(test)]
// Direct connections are used by one-shot callers and focused client tests.
Connection(ExecServerConnection),
// Remote environments use the logical client so process sessions survive
// connection replacement across reconnect.
RemoteClient(RemoteExecServerClient),
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum ProcessSessionDisconnectBehavior {
Fail,
Preserve,
}
struct Inner {
client: RpcClient,
rpc_client: RpcClient,
// The remote transport delivers one shared notification stream for every
// process on the connection. Keep a local process_id -> session registry so
// we can turn those connection-global notifications into process wakeups
// process on the connection. Keep a local process_id -> session route map
// so we can turn those connection-global notifications into process wakeups
// without making notifications the source of truth for output delivery.
sessions: ArcSwap<HashMap<ProcessId, Arc<SessionState>>>,
process_session_routes: ArcSwap<HashMap<ProcessId, Arc<ProcessSession>>>,
// ArcSwap makes reads cheap on the hot notification path, but writes still
// need serialization so concurrent register/remove operations do not
// overwrite each other's copy-on-write updates.
sessions_write_lock: Mutex<()>,
process_session_routes_write_lock: Mutex<()>,
// Once the transport closes, every executor operation should fail quickly
// with the same canonical message. This client never reconnects, so the
// latch only moves from unset to set once.
// with the same canonical message. This connection never reconnects, so
// the latch only moves from unset to set once.
disconnected: OnceLock<String>,
// Streaming HTTP responses are keyed by a client-generated request id
// because they share the same connection-global notification channel as
// process output. Keep the routing table local to the client so higher
// process output. Keep the routing table local to the connection so higher
// layers can consume body chunks like a normal byte stream.
http_body_streams: ArcSwap<HashMap<String, mpsc::Sender<HttpRequestBodyDeltaNotification>>>,
http_body_stream_failures: ArcSwap<HashMap<String, String>>,
@@ -185,43 +203,273 @@ impl Drop for Inner {
}
#[derive(Clone)]
pub struct ExecServerClient {
pub struct ExecServerConnection {
inner: Arc<Inner>,
}
#[derive(Clone)]
pub(crate) struct LazyRemoteExecServerClient {
pub(crate) struct RemoteExecServerClient {
transport_params: ExecServerTransportParams,
client: Arc<OnceCell<ExecServerClient>>,
session: Arc<StdMutex<RemoteExecServerSession>>,
}
impl LazyRemoteExecServerClient {
// Shared state for one logical remote exec-server client. The logical client
// owns resumable session identity and durable process session state; individual
// ExecServerConnection values only bind that state to one live transport.
struct RemoteExecServerSession {
connection: Option<ExecServerConnection>,
connection_attempt: Option<Arc<Notify>>,
logical_session_id: Option<String>,
terminal_error: Option<TerminalReconnectError>,
process_sessions: HashMap<ProcessId, Arc<ProcessSession>>,
}
enum RemoteExecServerConnectionAction {
Ready(ExecServerConnection),
Wait(BoxFuture<'static, ()>),
Connect {
connection_attempt: Arc<Notify>,
resume_session_id: Option<String>,
process_sessions: Vec<(ProcessId, Arc<ProcessSession>)>,
},
}
#[derive(Clone)]
struct TerminalReconnectError {
code: i64,
message: String,
}
impl RemoteExecServerClient {
pub(crate) fn new(transport_params: ExecServerTransportParams) -> Self {
Self {
transport_params,
client: Arc::new(OnceCell::new()),
session: Arc::new(StdMutex::new(RemoteExecServerSession {
connection: None,
connection_attempt: None,
logical_session_id: None,
terminal_error: None,
process_sessions: HashMap::new(),
})),
}
}
pub(crate) async fn get(&self) -> Result<ExecServerClient, ExecServerError> {
self.client
// TODO: Add reconnect/disconnect handling here instead of reusing
// the first successfully initialized connection forever.
.get_or_try_init(|| {
let transport_params = self.transport_params.clone();
async move { ExecServerClient::connect_for_transport(transport_params).await }
})
pub(crate) async fn connection(&self) -> Result<ExecServerConnection, ExecServerError> {
loop {
match self.next_connection_action()? {
RemoteExecServerConnectionAction::Ready(connection) => return Ok(connection),
RemoteExecServerConnectionAction::Wait(connection_attempt) => {
connection_attempt.await;
}
RemoteExecServerConnectionAction::Connect {
connection_attempt,
resume_session_id,
process_sessions,
} => {
let connection = self
.connect_and_rebind(resume_session_id.clone(), process_sessions)
.await;
return self.finish_connection_attempt(
connection_attempt,
resume_session_id,
connection,
);
}
}
}
}
pub(crate) async fn register_process_session(
&self,
process_id: &ProcessId,
) -> Result<ProcessSessionHandle, ExecServerError> {
let process_session = Arc::new(ProcessSession::new(
ProcessSessionDisconnectBehavior::Preserve,
));
{
let mut session = self.lock_session();
if session.process_sessions.contains_key(process_id) {
return Err(ExecServerError::Protocol(format!(
"session already registered for process {process_id}"
)));
}
session
.process_sessions
.insert(process_id.clone(), Arc::clone(&process_session));
}
let connection = self.connection().await?;
if let Err(err) = connection
.register_process_session_route(process_id, Arc::clone(&process_session))
.await
.cloned()
{
self.unregister_process_session(process_id).await;
return Err(err);
}
Ok(ProcessSessionHandle {
control: ProcessSessionControl::RemoteClient(self.clone()),
process_id: process_id.clone(),
session: process_session,
})
}
async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
let connection = self.connection().await?;
match connection.read(params.clone()).await {
Ok(response) => Ok(response),
Err(err) if is_transport_closed_error(&err) && self.supports_reconnect() => {
self.connection().await?.read(params).await
}
Err(err) => Err(err),
}
}
async fn write(
&self,
process_id: &ProcessId,
chunk: Vec<u8>,
) -> Result<WriteResponse, ExecServerError> {
self.connection().await?.write(process_id, chunk).await
}
async fn terminate(
&self,
process_id: &ProcessId,
) -> Result<TerminateResponse, ExecServerError> {
self.connection().await?.terminate(process_id).await
}
async fn unregister_process_session(&self, process_id: &ProcessId) {
let connection = {
let mut session = self.lock_session();
session.process_sessions.remove(process_id);
session.connection.clone()
};
if let Some(connection) = connection {
connection.unregister_process_session(process_id).await;
}
}
fn next_connection_action(&self) -> Result<RemoteExecServerConnectionAction, ExecServerError> {
let mut session = self.lock_session();
if let Some(error) = &session.terminal_error {
return Err(error.to_exec_server_error());
}
if let Some(connection) = &session.connection {
if let Some(error) = connection.disconnected_error() {
if !self.supports_reconnect() {
return Err(error);
}
} else {
return Ok(RemoteExecServerConnectionAction::Ready(connection.clone()));
}
}
if let Some(connection_attempt) = &session.connection_attempt {
let connection_attempt = Arc::clone(connection_attempt).notified_owned();
return Ok(RemoteExecServerConnectionAction::Wait(
connection_attempt.boxed(),
));
}
let connection_attempt = Arc::new(Notify::new());
let resume_session_id = session.logical_session_id.clone();
let process_sessions = session
.process_sessions
.iter()
.map(|(process_id, process_session)| (process_id.clone(), Arc::clone(process_session)))
.collect();
session.connection_attempt = Some(Arc::clone(&connection_attempt));
Ok(RemoteExecServerConnectionAction::Connect {
connection_attempt,
resume_session_id,
process_sessions,
})
}
async fn connect_and_rebind(
&self,
resume_session_id: Option<String>,
process_sessions: Vec<(ProcessId, Arc<ProcessSession>)>,
) -> Result<ExecServerConnection, ExecServerError> {
let connection = self.connect(resume_session_id).await?;
for (process_id, process_session) in process_sessions {
connection
.register_process_session_route(&process_id, process_session)
.await?;
}
Ok(connection)
}
fn finish_connection_attempt(
&self,
connection_attempt: Arc<Notify>,
resume_session_id: Option<String>,
connection: Result<ExecServerConnection, ExecServerError>,
) -> Result<ExecServerConnection, ExecServerError> {
let mut session = self.lock_session();
if let Err(err) = &connection {
if resume_session_id.is_some()
&& let Some(terminal_error) = TerminalReconnectError::from_error(err)
{
session.terminal_error = Some(terminal_error);
}
} else if let Ok(connection) = &connection {
session.logical_session_id = connection.session_id();
session.connection = Some(connection.clone());
}
session.connection_attempt = None;
connection_attempt.notify_waiters();
connection
}
fn lock_session(&self) -> std::sync::MutexGuard<'_, RemoteExecServerSession> {
self.session
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
async fn connect(
&self,
resume_session_id: Option<String>,
) -> Result<ExecServerConnection, ExecServerError> {
match &self.transport_params {
ExecServerTransportParams::WebSocketUrl {
websocket_url,
connect_timeout,
initialize_timeout,
} => {
ExecServerConnection::connect_websocket(RemoteExecServerConnectArgs {
websocket_url: websocket_url.clone(),
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
connect_timeout: *connect_timeout,
initialize_timeout: *initialize_timeout,
resume_session_id,
})
.await
}
ExecServerTransportParams::StdioCommand { .. } => {
ExecServerConnection::connect_for_transport(self.transport_params.clone()).await
}
}
}
fn supports_reconnect(&self) -> bool {
matches!(
&self.transport_params,
ExecServerTransportParams::WebSocketUrl { .. }
)
}
}
impl HttpClient for LazyRemoteExecServerClient {
impl HttpClient for RemoteExecServerClient {
fn http_request(
&self,
params: crate::HttpRequestParams,
) -> BoxFuture<'_, Result<crate::HttpRequestResponse, ExecServerError>> {
async move { self.get().await?.http_request(params).await }.boxed()
async move { self.connection().await?.http_request(params).await }.boxed()
}
fn http_request_stream(
@@ -231,7 +479,7 @@ impl HttpClient for LazyRemoteExecServerClient {
'_,
Result<(crate::HttpRequestResponse, crate::HttpResponseBodyStream), ExecServerError>,
> {
async move { self.get().await?.http_request_stream(params).await }.boxed()
async move { self.connection().await?.http_request_stream(params).await }.boxed()
}
}
@@ -275,7 +523,7 @@ pub enum ExecServerError {
ExecutorRegistryRequest(#[from] reqwest::Error),
}
impl ExecServerClient {
impl ExecServerConnection {
pub async fn initialize(
&self,
options: ExecServerClientConnectOptions,
@@ -289,7 +537,7 @@ impl ExecServerClient {
timeout(initialize_timeout, async {
let response: InitializeResponse = self
.inner
.client
.rpc_client
.call(
INITIALIZE_METHOD,
&InitializeParams {
@@ -397,23 +645,33 @@ impl ExecServerClient {
self.call(FS_COPY_METHOD, &params).await
}
pub(crate) async fn register_session(
#[cfg(test)]
pub(crate) async fn register_process_session(
&self,
process_id: &ProcessId,
) -> Result<Session, ExecServerError> {
let state = Arc::new(SessionState::new());
self.inner
.insert_session(process_id, Arc::clone(&state))
) -> Result<ProcessSessionHandle, ExecServerError> {
let session = Arc::new(ProcessSession::new(ProcessSessionDisconnectBehavior::Fail));
self.register_process_session_route(process_id, Arc::clone(&session))
.await?;
Ok(Session {
client: self.clone(),
Ok(ProcessSessionHandle {
control: ProcessSessionControl::Connection(self.clone()),
process_id: process_id.clone(),
state,
session,
})
}
pub(crate) async fn unregister_session(&self, process_id: &ProcessId) {
self.inner.remove_session(process_id).await;
async fn register_process_session_route(
&self,
process_id: &ProcessId,
session: Arc<ProcessSession>,
) -> Result<(), ExecServerError> {
self.inner
.insert_process_session_route(process_id, session)
.await
}
pub(crate) async fn unregister_process_session(&self, process_id: &ProcessId) {
self.inner.remove_process_session_route(process_id).await;
}
pub fn session_id(&self) -> Option<String> {
@@ -424,6 +682,10 @@ impl ExecServerClient {
.clone()
}
fn disconnected_error(&self) -> Option<ExecServerError> {
self.inner.disconnected_error()
}
pub(crate) async fn connect(
connection: JsonRpcConnection,
options: ExecServerClientConnectOptions,
@@ -462,9 +724,9 @@ impl ExecServerClient {
});
Inner {
client: rpc_client,
sessions: ArcSwap::from_pointee(HashMap::new()),
sessions_write_lock: Mutex::new(()),
rpc_client,
process_session_routes: ArcSwap::from_pointee(HashMap::new()),
process_session_routes_write_lock: Mutex::new(()),
disconnected: OnceLock::new(),
http_body_streams: ArcSwap::from_pointee(HashMap::new()),
http_body_stream_failures: ArcSwap::from_pointee(HashMap::new()),
@@ -475,14 +737,14 @@ impl ExecServerClient {
}
});
let client = Self { inner };
client.initialize(options).await?;
Ok(client)
let connection = Self { inner };
connection.initialize(options).await?;
Ok(connection)
}
async fn notify_initialized(&self) -> Result<(), ExecServerError> {
self.inner
.client
.rpc_client
.notify(INITIALIZED_METHOD, &serde_json::json!({}))
.await
.map_err(ExecServerError::Json)
@@ -500,13 +762,13 @@ impl ExecServerClient {
return Err(error);
}
match self.inner.client.call(method, params).await {
match self.inner.rpc_client.call(method, params).await {
Ok(response) => Ok(response),
Err(error) => {
let error = ExecServerError::from(error);
if is_transport_closed_error(&error) {
// A call can race with disconnect after the preflight
// check. Only the reader task drains sessions so queued
// check. Only the reader task drains routes so queued
// process notifications stay ordered before disconnect.
let message = disconnected_message(/*reason*/ None);
let message = record_disconnected(&self.inner, message);
@@ -532,8 +794,8 @@ impl From<RpcCallError> for ExecServerError {
}
}
impl SessionState {
fn new() -> Self {
impl ProcessSession {
fn new(disconnect_behavior: ProcessSessionDisconnectBehavior) -> Self {
let (wake_tx, _wake_rx) = watch::channel(0);
Self {
wake_tx,
@@ -543,6 +805,7 @@ impl SessionState {
),
ordered_events: StdMutex::new(OrderedSessionEvents::default()),
failure: Mutex::new(None),
disconnect_behavior,
}
}
@@ -638,17 +901,17 @@ impl SessionState {
}
}
impl Session {
impl ProcessSessionHandle {
pub(crate) fn process_id(&self) -> &ProcessId {
&self.process_id
}
pub(crate) fn subscribe_wake(&self) -> watch::Receiver<u64> {
self.state.subscribe()
self.session.subscribe()
}
pub(crate) fn subscribe_events(&self) -> ExecProcessEventReceiver {
self.state.subscribe_events()
self.session.subscribe_events()
}
pub(crate) async fn read(
@@ -657,41 +920,104 @@ impl Session {
max_bytes: Option<usize>,
wait_ms: Option<u64>,
) -> Result<ReadResponse, ExecServerError> {
if let Some(response) = self.state.failed_response().await {
if let Some(response) = self.session.failed_response().await {
return Ok(response);
}
match self
.client
.read(ReadParams {
process_id: self.process_id.clone(),
after_seq,
max_bytes,
wait_ms,
})
.await
{
let params = ReadParams {
process_id: self.process_id.clone(),
after_seq,
max_bytes,
wait_ms,
};
match self.control.read(params).await {
Ok(response) => Ok(response),
Err(err) if is_transport_closed_error(&err) => {
Err(err)
if is_transport_closed_error(&err)
&& self.session.disconnect_behavior
== ProcessSessionDisconnectBehavior::Fail =>
{
let message = disconnected_message(/*reason*/ None);
self.state.set_failure(message.clone()).await;
Ok(self.state.synthesized_failure(message))
self.session.set_failure(message.clone()).await;
Ok(self.session.synthesized_failure(message))
}
Err(err) => Err(err),
}
}
pub(crate) async fn write(&self, chunk: Vec<u8>) -> Result<WriteResponse, ExecServerError> {
self.client.write(&self.process_id, chunk).await
self.control.write(&self.process_id, chunk).await
}
pub(crate) async fn terminate(&self) -> Result<(), ExecServerError> {
self.client.terminate(&self.process_id).await?;
self.control.terminate(&self.process_id).await?;
Ok(())
}
pub(crate) async fn unregister(&self) {
self.client.unregister_session(&self.process_id).await;
self.control
.unregister_process_session(&self.process_id)
.await;
}
}
impl ProcessSessionControl {
async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
match self {
#[cfg(test)]
Self::Connection(connection) => connection.read(params).await,
Self::RemoteClient(client) => client.read(params).await,
}
}
async fn write(
&self,
process_id: &ProcessId,
chunk: Vec<u8>,
) -> Result<WriteResponse, ExecServerError> {
match self {
#[cfg(test)]
Self::Connection(connection) => connection.write(process_id, chunk).await,
Self::RemoteClient(client) => client.write(process_id, chunk).await,
}
}
async fn terminate(
&self,
process_id: &ProcessId,
) -> Result<TerminateResponse, ExecServerError> {
match self {
#[cfg(test)]
Self::Connection(connection) => connection.terminate(process_id).await,
Self::RemoteClient(client) => client.terminate(process_id).await,
}
}
async fn unregister_process_session(&self, process_id: &ProcessId) {
match self {
#[cfg(test)]
Self::Connection(connection) => connection.unregister_process_session(process_id).await,
Self::RemoteClient(client) => client.unregister_process_session(process_id).await,
}
}
}
impl TerminalReconnectError {
fn from_error(error: &ExecServerError) -> Option<Self> {
match error {
ExecServerError::Server { code, message } if *code == -32600 => Some(Self {
code: *code,
message: message.clone(),
}),
_ => None,
}
}
fn to_exec_server_error(&self) -> ExecServerError {
ExecServerError::Server {
code: self.code,
message: self.message.clone(),
}
}
}
@@ -710,51 +1036,57 @@ impl Inner {
}
}
fn get_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
self.sessions.load().get(process_id).cloned()
fn get_process_session_route(&self, process_id: &ProcessId) -> Option<Arc<ProcessSession>> {
self.process_session_routes.load().get(process_id).cloned()
}
async fn insert_session(
async fn insert_process_session_route(
&self,
process_id: &ProcessId,
session: Arc<SessionState>,
session: Arc<ProcessSession>,
) -> Result<(), ExecServerError> {
let _sessions_write_guard = self.sessions_write_lock.lock().await;
let _routes_write_guard = self.process_session_routes_write_lock.lock().await;
// Do not register a process session that can never receive executor
// notifications. Without this check, remote MCP startup could create a
// dead session and wait for process output that will never arrive.
if let Some(error) = self.disconnected_error() {
return Err(error);
}
let sessions = self.sessions.load();
if sessions.contains_key(process_id) {
let routes = self.process_session_routes.load();
if let Some(existing_session) = routes.get(process_id) {
if Arc::ptr_eq(existing_session, &session) {
return Ok(());
}
return Err(ExecServerError::Protocol(format!(
"session already registered for process {process_id}"
)));
}
let mut next_sessions = sessions.as_ref().clone();
next_sessions.insert(process_id.clone(), session);
self.sessions.store(Arc::new(next_sessions));
let mut next_routes = routes.as_ref().clone();
next_routes.insert(process_id.clone(), session);
self.process_session_routes.store(Arc::new(next_routes));
Ok(())
}
async fn remove_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
let _sessions_write_guard = self.sessions_write_lock.lock().await;
let sessions = self.sessions.load();
let session = sessions.get(process_id).cloned();
async fn remove_process_session_route(
&self,
process_id: &ProcessId,
) -> Option<Arc<ProcessSession>> {
let _routes_write_guard = self.process_session_routes_write_lock.lock().await;
let routes = self.process_session_routes.load();
let session = routes.get(process_id).cloned();
session.as_ref()?;
let mut next_sessions = sessions.as_ref().clone();
next_sessions.remove(process_id);
self.sessions.store(Arc::new(next_sessions));
let mut next_routes = routes.as_ref().clone();
next_routes.remove(process_id);
self.process_session_routes.store(Arc::new(next_routes));
session
}
async fn take_all_sessions(&self) -> HashMap<ProcessId, Arc<SessionState>> {
let _sessions_write_guard = self.sessions_write_lock.lock().await;
let sessions = self.sessions.load();
let drained_sessions = sessions.as_ref().clone();
self.sessions.store(Arc::new(HashMap::new()));
drained_sessions
async fn take_all_process_session_routes(&self) -> HashMap<ProcessId, Arc<ProcessSession>> {
let _routes_write_guard = self.process_session_routes_write_lock.lock().await;
let routes = self.process_session_routes.load();
let drained_routes = routes.as_ref().clone();
self.process_session_routes.store(Arc::new(HashMap::new()));
drained_routes
}
}
@@ -779,9 +1111,9 @@ fn is_transport_closed_error(error: &ExecServerError) -> bool {
}
fn record_disconnected(inner: &Arc<Inner>, message: String) -> String {
// The first observer records the canonical disconnect reason. Session
// draining stays with the reader task so it can preserve notification
// ordering before publishing the terminal failure.
// The first observer records the canonical disconnect reason. Process
// session route draining stays with the reader task so it can preserve
// notification ordering before publishing the terminal failure.
if let Some(message) = inner.set_disconnected(message.clone()) {
message
} else {
@@ -789,20 +1121,22 @@ fn record_disconnected(inner: &Arc<Inner>, message: String) -> String {
}
}
async fn fail_all_sessions(inner: &Arc<Inner>, message: String) {
let sessions = inner.take_all_sessions().await;
async fn fail_all_process_sessions(inner: &Arc<Inner>, message: String) {
let routes = inner.take_all_process_session_routes().await;
for (_, session) in sessions {
// Sessions synthesize a closed read response and emit a pushed Failed
// event. That covers both polling consumers and streaming consumers
// such as executor-backed MCP stdio.
session.set_failure(message.clone()).await;
for (_, session) in routes {
// One-shot sessions synthesize a closed read response and emit a
// pushed Failed event. Reconnecting remote sessions keep their local
// event state so a reattached client can bind them again.
if session.disconnect_behavior == ProcessSessionDisconnectBehavior::Fail {
session.set_failure(message.clone()).await;
}
}
}
/// Fails all in-flight work that depends on the shared JSON-RPC transport.
async fn fail_all_in_flight_work(inner: &Arc<Inner>, message: String) {
fail_all_sessions(inner, message.clone()).await;
fail_all_process_sessions(inner, message.clone()).await;
inner.fail_all_http_body_streams(message).await;
}
@@ -814,7 +1148,7 @@ async fn handle_server_notification(
EXEC_OUTPUT_DELTA_METHOD => {
let params: ExecOutputDeltaNotification =
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
if let Some(session) = inner.get_session(&params.process_id) {
if let Some(session) = inner.get_process_session_route(&params.process_id) {
session.note_change(params.seq);
let published_closed =
session.publish_ordered_event(ExecProcessEvent::Output(ProcessOutputChunk {
@@ -823,28 +1157,28 @@ async fn handle_server_notification(
chunk: params.chunk,
}));
if published_closed {
inner.remove_session(&params.process_id).await;
inner.remove_process_session_route(&params.process_id).await;
}
}
}
EXEC_EXITED_METHOD => {
let params: ExecExitedNotification =
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
if let Some(session) = inner.get_session(&params.process_id) {
if let Some(session) = inner.get_process_session_route(&params.process_id) {
session.note_change(params.seq);
let published_closed = session.publish_ordered_event(ExecProcessEvent::Exited {
seq: params.seq,
exit_code: params.exit_code,
});
if published_closed {
inner.remove_session(&params.process_id).await;
inner.remove_process_session_route(&params.process_id).await;
}
}
}
EXEC_CLOSED_METHOD => {
let params: ExecClosedNotification =
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
if let Some(session) = inner.get_session(&params.process_id) {
if let Some(session) = inner.get_process_session_route(&params.process_id) {
session.note_change(params.seq);
// Closed is terminal, but it can arrive before tail output or
// exited. Keep routing this process until the ordered publisher
@@ -852,7 +1186,7 @@ async fn handle_server_notification(
let published_closed =
session.publish_ordered_event(ExecProcessEvent::Closed { seq: params.seq });
if published_closed {
inner.remove_session(&params.process_id).await;
inner.remove_process_session_route(&params.process_id).await;
}
}
}
@@ -891,8 +1225,8 @@ mod tests {
use tokio::time::sleep;
use tokio::time::timeout;
use super::ExecServerClient;
use super::ExecServerClientConnectOptions;
use super::ExecServerConnection;
use crate::ProcessId;
#[cfg(not(windows))]
use crate::client_api::DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT;
@@ -940,7 +1274,7 @@ mod tests {
#[cfg(not(windows))]
#[tokio::test]
async fn connect_stdio_command_initializes_json_rpc_client() {
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
let client = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs {
command: StdioExecServerCommand {
program: "sh".to_string(),
args: vec![
@@ -963,7 +1297,7 @@ mod tests {
#[cfg(not(windows))]
#[tokio::test]
async fn connect_for_transport_initializes_stdio_command() {
let client = ExecServerClient::connect_for_transport(
let client = ExecServerConnection::connect_for_transport(
ExecServerTransportParams::StdioCommand {
command: StdioExecServerCommand {
program: "sh".to_string(),
@@ -986,7 +1320,7 @@ mod tests {
#[cfg(windows)]
#[tokio::test]
async fn connect_stdio_command_initializes_json_rpc_client_on_windows() {
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
let client = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs {
command: StdioExecServerCommand {
program: "powershell".to_string(),
args: vec![
@@ -1024,7 +1358,7 @@ mod tests {
shell_quote(child_pid_file.as_path()),
);
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
let client = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs {
command: StdioExecServerCommand {
program: "sh".to_string(),
args: vec!["-c".to_string(), stdio_script],
@@ -1067,7 +1401,7 @@ mod tests {
shell_quote(pid_file.as_path()),
);
let result = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
let result = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs {
command: StdioExecServerCommand {
program: "sh".to_string(),
args: vec!["-c".to_string(), stdio_script],
@@ -1161,7 +1495,7 @@ mod tests {
}
});
let client = ExecServerClient::connect(
let client = ExecServerConnection::connect(
JsonRpcConnection::from_stdio(
client_stdout,
client_stdin,
@@ -1174,7 +1508,7 @@ mod tests {
let process_id = ProcessId::from("reordered");
let session = client
.register_session(&process_id)
.register_process_session(&process_id)
.await
.expect("session should register");
let mut events = session.subscribe_events();
@@ -1303,7 +1637,7 @@ mod tests {
drop(server_writer);
});
let client = ExecServerClient::connect(
let client = ExecServerConnection::connect(
JsonRpcConnection::from_stdio(
client_stdout,
client_stdin,
@@ -1316,7 +1650,7 @@ mod tests {
let process_id = ProcessId::from("disconnect");
let session = client
.register_session(&process_id)
.register_process_session(&process_id)
.await
.expect("session should register");
let mut events = session.subscribe_events();
@@ -1344,7 +1678,9 @@ mod tests {
);
assert!(response.closed);
let new_session = client.register_session(&ProcessId::from("new")).await;
let new_session = client
.register_process_session(&ProcessId::from("new"))
.await;
assert!(matches!(
new_session,
Err(super::ExecServerError::Disconnected(_))
@@ -1390,7 +1726,7 @@ mod tests {
}
});
let client = ExecServerClient::connect(
let client = ExecServerConnection::connect(
JsonRpcConnection::from_stdio(
client_stdout,
client_stdin,
@@ -1404,11 +1740,11 @@ mod tests {
let noisy_process_id = ProcessId::from("noisy");
let quiet_process_id = ProcessId::from("quiet");
let _noisy_session = client
.register_session(&noisy_process_id)
.register_process_session(&noisy_process_id)
.await
.expect("noisy session should register");
let quiet_session = client
.register_session(&quiet_process_id)
.register_process_session(&quiet_process_id)
.await
.expect("quiet session should register");
let mut quiet_wake_rx = quiet_session.subscribe_wake();

View File

@@ -3,7 +3,7 @@
//! This module is the facade for the environment-owned [`crate::HttpClient`]
//! capability:
//! - [`ReqwestHttpClient`] executes requests directly with `reqwest`
//! - [`ExecServerClient`] forwards requests over the JSON-RPC transport
//! - [`ExecServerConnection`] forwards requests over the JSON-RPC transport
//! - [`HttpResponseBodyStream`] presents buffered local bodies and streamed
//! remote `http/request/bodyDelta` notifications through one byte-stream API
//!
@@ -11,7 +11,7 @@
//! - orchestrator process: holds an `Arc<dyn HttpClient>` and chooses local or
//! remote execution
//! - remote runtime: serves the `http/request` RPC and runs the concrete local
//! HTTP request there when the orchestrator uses [`ExecServerClient`]
//! HTTP request there when the orchestrator uses [`ExecServerConnection`]
#[path = "reqwest_http_client.rs"]
mod reqwest_http_client;

View File

@@ -14,7 +14,7 @@ use tokio::sync::mpsc;
use super::HttpResponseBodyStream;
use super::response_body_stream::HttpBodyStreamRegistration;
use crate::HttpClient;
use crate::client::ExecServerClient;
use crate::client::ExecServerConnection;
use crate::client::ExecServerError;
use crate::protocol::HTTP_REQUEST_METHOD;
use crate::protocol::HttpRequestParams;
@@ -23,7 +23,7 @@ use crate::protocol::HttpRequestResponse;
/// Maximum queued body frames per streamed HTTP response.
const HTTP_BODY_DELTA_CHANNEL_CAPACITY: usize = 256;
impl ExecServerClient {
impl ExecServerConnection {
/// Performs an HTTP request and buffers the response body.
pub async fn http_request(
&self,
@@ -67,14 +67,14 @@ impl ExecServerClient {
}
}
impl HttpClient for ExecServerClient {
impl HttpClient for ExecServerConnection {
/// Orchestrator-side adapter that forwards buffered HTTP requests to the
/// remote runtime over the shared JSON-RPC connection.
fn http_request(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
async move { ExecServerClient::http_request(self, params).await }.boxed()
async move { ExecServerConnection::http_request(self, params).await }.boxed()
}
/// Orchestrator-side adapter that forwards streamed HTTP requests to the
@@ -83,6 +83,6 @@ impl HttpClient for ExecServerClient {
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
async move { ExecServerClient::http_request_stream(self, params).await }.boxed()
async move { ExecServerConnection::http_request_stream(self, params).await }.boxed()
}
}

View File

@@ -9,7 +9,7 @@ use tracing::warn;
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
use crate::ExecServerClient;
use crate::ExecServerConnection;
use crate::ExecServerError;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::client_api::StdioExecServerCommand;
@@ -17,9 +17,9 @@ use crate::client_api::StdioExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
use crate::relay::harness_connection_from_websocket;
const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
pub(crate) const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
impl ExecServerClient {
impl ExecServerConnection {
pub(crate) async fn connect_for_transport(
transport_params: crate::client_api::ExecServerTransportParams,
) -> Result<Self, ExecServerError> {

View File

@@ -6,7 +6,7 @@ use crate::ExecServerError;
use crate::ExecServerRuntimePaths;
use crate::ExecutorFileSystem;
use crate::HttpClient;
use crate::client::LazyRemoteExecServerClient;
use crate::client::RemoteExecServerClient;
use crate::client::http_client::ReqwestHttpClient;
use crate::client_api::ExecServerTransportParams;
use crate::environment_provider::DefaultEnvironmentProvider;
@@ -403,7 +403,7 @@ impl Environment {
} => Some(exec_server_url.clone()),
ExecServerTransportParams::StdioCommand { .. } => None,
};
let client = LazyRemoteExecServerClient::new(remote_transport.clone());
let client = RemoteExecServerClient::new(remote_transport.clone());
let exec_backend: Arc<dyn ExecBackend> = Arc::new(RemoteProcess::new(client.clone()));
let filesystem: Arc<dyn ExecutorFileSystem> =
Arc::new(RemoteFileSystem::new(client.clone()));

View File

@@ -23,8 +23,9 @@ mod runtime_paths;
mod sandboxed_file_system;
mod server;
pub use client::ExecServerClient;
pub use client::ExecServerConnection;
pub use client::ExecServerError;
pub type ExecServerClient = ExecServerConnection;
pub use client::http_client::HttpResponseBodyStream;
pub use client::http_client::ReqwestHttpClient;
pub use client_api::ExecServerClientConnectOptions;

View File

@@ -14,7 +14,7 @@ use crate::FileSystemResult;
use crate::FileSystemSandboxContext;
use crate::ReadDirectoryEntry;
use crate::RemoveOptions;
use crate::client::LazyRemoteExecServerClient;
use crate::client::RemoteExecServerClient;
use crate::protocol::FsCopyParams;
use crate::protocol::FsCreateDirectoryParams;
use crate::protocol::FsGetMetadataParams;
@@ -28,11 +28,11 @@ const NOT_FOUND_ERROR_CODE: i64 = -32004;
#[derive(Clone)]
pub(crate) struct RemoteFileSystem {
client: LazyRemoteExecServerClient,
client: RemoteExecServerClient,
}
impl RemoteFileSystem {
pub(crate) fn new(client: LazyRemoteExecServerClient) -> Self {
pub(crate) fn new(client: RemoteExecServerClient) -> Self {
trace!("remote fs new");
Self { client }
}
@@ -46,8 +46,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<Vec<u8>> {
trace!("remote fs read_file");
let client = self.client.get().await.map_err(map_remote_error)?;
let response = client
let connection = self.client.connection().await.map_err(map_remote_error)?;
let response = connection
.fs_read_file(FsReadFileParams {
path: path.clone(),
sandbox: remote_sandbox_context(sandbox),
@@ -69,8 +69,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<()> {
trace!("remote fs write_file");
let client = self.client.get().await.map_err(map_remote_error)?;
client
let connection = self.client.connection().await.map_err(map_remote_error)?;
connection
.fs_write_file(FsWriteFileParams {
path: path.clone(),
data_base64: STANDARD.encode(contents),
@@ -88,8 +88,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<()> {
trace!("remote fs create_directory");
let client = self.client.get().await.map_err(map_remote_error)?;
client
let connection = self.client.connection().await.map_err(map_remote_error)?;
connection
.fs_create_directory(FsCreateDirectoryParams {
path: path.clone(),
recursive: Some(options.recursive),
@@ -106,8 +106,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<FileMetadata> {
trace!("remote fs get_metadata");
let client = self.client.get().await.map_err(map_remote_error)?;
let response = client
let connection = self.client.connection().await.map_err(map_remote_error)?;
let response = connection
.fs_get_metadata(FsGetMetadataParams {
path: path.clone(),
sandbox: remote_sandbox_context(sandbox),
@@ -129,8 +129,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<Vec<ReadDirectoryEntry>> {
trace!("remote fs read_directory");
let client = self.client.get().await.map_err(map_remote_error)?;
let response = client
let connection = self.client.connection().await.map_err(map_remote_error)?;
let response = connection
.fs_read_directory(FsReadDirectoryParams {
path: path.clone(),
sandbox: remote_sandbox_context(sandbox),
@@ -155,8 +155,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<()> {
trace!("remote fs remove");
let client = self.client.get().await.map_err(map_remote_error)?;
client
let connection = self.client.connection().await.map_err(map_remote_error)?;
connection
.fs_remove(FsRemoveParams {
path: path.clone(),
recursive: Some(options.recursive),
@@ -176,8 +176,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
sandbox: Option<&FileSystemSandboxContext>,
) -> FileSystemResult<()> {
trace!("remote fs copy");
let client = self.client.get().await.map_err(map_remote_error)?;
client
let connection = self.client.connection().await.map_err(map_remote_error)?;
connection
.fs_copy(FsCopyParams {
source_path: source_path.clone(),
destination_path: destination_path.clone(),

View File

@@ -9,23 +9,23 @@ use crate::ExecProcess;
use crate::ExecProcessEventReceiver;
use crate::ExecServerError;
use crate::StartedExecProcess;
use crate::client::LazyRemoteExecServerClient;
use crate::client::Session;
use crate::client::ProcessSessionHandle;
use crate::client::RemoteExecServerClient;
use crate::protocol::ExecParams;
use crate::protocol::ReadResponse;
use crate::protocol::WriteResponse;
#[derive(Clone)]
pub(crate) struct RemoteProcess {
client: LazyRemoteExecServerClient,
client: RemoteExecServerClient,
}
struct RemoteExecProcess {
session: Session,
session: ProcessSessionHandle,
}
impl RemoteProcess {
pub(crate) fn new(client: LazyRemoteExecServerClient) -> Self {
pub(crate) fn new(client: RemoteExecServerClient) -> Self {
trace!("remote process new");
Self { client }
}
@@ -35,9 +35,9 @@ impl RemoteProcess {
impl ExecBackend for RemoteProcess {
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError> {
let process_id = params.process_id.clone();
let client = self.client.get().await?;
let session = client.register_session(&process_id).await?;
if let Err(err) = client.exec(params).await {
let session = self.client.register_process_session(&process_id).await?;
let connection = self.client.connection().await?;
if let Err(err) = connection.exec(params).await {
session.unregister().await;
return Err(err);
}

View File

@@ -541,7 +541,7 @@ async fn assert_exec_process_preserves_queued_events_before_subscribe(
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
// Serialize tests that launch a real exec-server process through the full CLI.
#[serial_test::serial(remote_exec_server)]
async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
async fn remote_exec_process_surfaces_reconnect_failure_after_server_shutdown() -> Result<()> {
let mut context = create_process_context(/*use_remote*/ true).await?;
let session = context
.backend
@@ -562,7 +562,6 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
.await?;
let process = Arc::clone(&session.process);
let mut events = process.subscribe_events();
let process_for_pending_read = Arc::clone(&process);
let pending_read = tokio::spawn(async move {
process_for_pending_read
@@ -579,36 +578,33 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
.expect("remote context should include exec-server harness");
server.shutdown().await?;
let event = timeout(Duration::from_secs(2), events.recv()).await??;
let ExecProcessEvent::Failed(event_message) = event else {
anyhow::bail!("expected process failure event, got {event:?}");
};
let pending_error = timeout(Duration::from_secs(2), pending_read)
.await
.context("timed out waiting for pending read after reconnect failure")??
.expect_err("pending read should fail after reconnect fails");
assert!(
event_message.starts_with("exec-server transport disconnected"),
"unexpected failure event: {event_message}"
pending_error
.to_string()
.starts_with("failed to connect to exec-server websocket"),
"unexpected pending read error: {pending_error}"
);
let pending_response = timeout(Duration::from_secs(2), pending_read).await???;
let pending_message = pending_response
.failure
.expect("pending read should surface disconnect as a failure");
let read_error = timeout(
Duration::from_secs(2),
process.read(
/*after_seq*/ None,
/*max_bytes*/ None,
/*wait_ms*/ Some(0),
),
)
.await
.context("timed out waiting for read after reconnect failure")?
.expect_err("read after reconnect failure should fail");
assert!(
pending_message.starts_with("exec-server transport disconnected"),
"unexpected pending failure message: {pending_message}"
);
let mut wake_rx = process.subscribe_wake();
let response = read_process_until_change(process, &mut wake_rx, /*after_seq*/ None).await?;
let message = response
.failure
.expect("disconnect should surface as a failure");
assert!(
message.starts_with("exec-server transport disconnected"),
"unexpected failure message: {message}"
);
assert!(
response.closed,
"disconnect should close the process session"
read_error
.to_string()
.starts_with("failed to connect to exec-server websocket"),
"unexpected read error: {read_error}"
);
let write_result = timeout(
@@ -621,7 +617,7 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
assert!(
write_error
.to_string()
.starts_with("exec-server transport disconnected"),
.starts_with("failed to connect to exec-server websocket"),
"unexpected write error: {write_error}"
);