mirror of
https://github.com/openai/codex.git
synced 2026-05-24 04:54:52 +00:00
Add exec-server websocket reconnect foundation
This commit is contained in:
@@ -12,7 +12,7 @@ use futures::FutureExt;
|
||||
use futures::future::BoxFuture;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::OnceCell;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
|
||||
@@ -25,6 +25,7 @@ use crate::client_api::ExecServerTransportParams;
|
||||
use crate::client_api::HttpClient;
|
||||
use crate::client_api::RemoteExecServerConnectArgs;
|
||||
use crate::client_api::StdioExecServerConnectArgs;
|
||||
use crate::client_transport::ENVIRONMENT_CLIENT_NAME;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::process::ExecProcessEvent;
|
||||
use crate::process::ExecProcessEventLog;
|
||||
@@ -128,11 +129,12 @@ impl RemoteExecServerConnectArgs {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct SessionState {
|
||||
pub(crate) struct ProcessSession {
|
||||
wake_tx: watch::Sender<u64>,
|
||||
events: ExecProcessEventLog,
|
||||
ordered_events: StdMutex<OrderedSessionEvents>,
|
||||
failure: Mutex<Option<String>>,
|
||||
disconnect_behavior: ProcessSessionDisconnectBehavior,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -145,30 +147,46 @@ struct OrderedSessionEvents {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct Session {
|
||||
client: ExecServerClient,
|
||||
pub(crate) struct ProcessSessionHandle {
|
||||
control: ProcessSessionControl,
|
||||
process_id: ProcessId,
|
||||
state: Arc<SessionState>,
|
||||
session: Arc<ProcessSession>,
|
||||
}
|
||||
|
||||
// Where a process session handle sends its read/write/terminate/unregister
// calls: a single fixed connection, or the logical remote client.
#[derive(Clone)]
|
||||
enum ProcessSessionControl {
|
||||
#[cfg(test)]
|
||||
// Direct connections are used by one-shot callers and focused client tests.
|
||||
Connection(ExecServerConnection),
|
||||
// Remote environments use the logical client so process sessions survive
|
||||
// connection replacement across reconnect.
|
||||
RemoteClient(RemoteExecServerClient),
|
||||
}
|
||||
|
||||
// What happens to a process session's local state when the connection that
// routes its notifications is lost.
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
enum ProcessSessionDisconnectBehavior {
|
||||
// A failure is recorded on the session when the transport closes.
Fail,
|
||||
// The session is left untouched on disconnect so it can be bound to a
// replacement connection.
Preserve,
|
||||
}
|
||||
|
||||
struct Inner {
|
||||
client: RpcClient,
|
||||
rpc_client: RpcClient,
|
||||
// The remote transport delivers one shared notification stream for every
|
||||
// process on the connection. Keep a local process_id -> session registry so
|
||||
// we can turn those connection-global notifications into process wakeups
|
||||
// process on the connection. Keep a local process_id -> session route map
|
||||
// so we can turn those connection-global notifications into process wakeups
|
||||
// without making notifications the source of truth for output delivery.
|
||||
sessions: ArcSwap<HashMap<ProcessId, Arc<SessionState>>>,
|
||||
process_session_routes: ArcSwap<HashMap<ProcessId, Arc<ProcessSession>>>,
|
||||
// ArcSwap makes reads cheap on the hot notification path, but writes still
|
||||
// need serialization so concurrent register/remove operations do not
|
||||
// overwrite each other's copy-on-write updates.
|
||||
sessions_write_lock: Mutex<()>,
|
||||
process_session_routes_write_lock: Mutex<()>,
|
||||
// Once the transport closes, every executor operation should fail quickly
|
||||
// with the same canonical message. This client never reconnects, so the
|
||||
// latch only moves from unset to set once.
|
||||
// with the same canonical message. This connection never reconnects, so
|
||||
// the latch only moves from unset to set once.
|
||||
disconnected: OnceLock<String>,
|
||||
// Streaming HTTP responses are keyed by a client-generated request id
|
||||
// because they share the same connection-global notification channel as
|
||||
// process output. Keep the routing table local to the client so higher
|
||||
// process output. Keep the routing table local to the connection so higher
|
||||
// layers can consume body chunks like a normal byte stream.
|
||||
http_body_streams: ArcSwap<HashMap<String, mpsc::Sender<HttpRequestBodyDeltaNotification>>>,
|
||||
http_body_stream_failures: ArcSwap<HashMap<String, String>>,
|
||||
@@ -185,43 +203,273 @@ impl Drop for Inner {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ExecServerClient {
|
||||
pub struct ExecServerConnection {
|
||||
inner: Arc<Inner>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct LazyRemoteExecServerClient {
|
||||
pub(crate) struct RemoteExecServerClient {
|
||||
transport_params: ExecServerTransportParams,
|
||||
client: Arc<OnceCell<ExecServerClient>>,
|
||||
session: Arc<StdMutex<RemoteExecServerSession>>,
|
||||
}
|
||||
|
||||
impl LazyRemoteExecServerClient {
|
||||
// Shared state for one logical remote exec-server client. The logical client
|
||||
// owns resumable session identity and durable process session state; individual
|
||||
// ExecServerConnection values only bind that state to one live transport.
|
||||
struct RemoteExecServerSession {
|
||||
// Most recently established connection. It may already be disconnected;
// next_connection_action checks disconnected_error() before reusing it.
connection: Option<ExecServerConnection>,
|
||||
// Set while one caller is connecting. Other callers wait on this Notify
// instead of starting a second attempt; finish_connection_attempt clears it.
connection_attempt: Option<Arc<Notify>>,
|
||||
// Session id reported by the last successful connection. It is passed as
// the resume id on the next connect.
logical_session_id: Option<String>,
|
||||
// Set when a resume attempt fails terminally. Once set, every later
// connection request returns this error.
terminal_error: Option<TerminalReconnectError>,
|
||||
// Process sessions registered on the logical client. Each new connection
// gets a route registered for every entry.
process_sessions: HashMap<ProcessId, Arc<ProcessSession>>,
|
||||
}
|
||||
|
||||
// Decision made by next_connection_action while holding the session lock;
// the caller acts on it after the lock is released.
enum RemoteExecServerConnectionAction {
|
||||
// A live connection already exists; use it.
Ready(ExecServerConnection),
|
||||
// Another caller is connecting; await this future, then re-evaluate.
Wait(BoxFuture<'static, ()>),
|
||||
// This caller owns the connection attempt and must connect.
Connect {
|
||||
// Notifier that waiting callers are parked on.
connection_attempt: Arc<Notify>,
|
||||
// Logical session id to resume, if a previous connection reported one.
resume_session_id: Option<String>,
|
||||
// Snapshot of registered process sessions to re-route on the new connection.
process_sessions: Vec<(ProcessId, Arc<ProcessSession>)>,
|
||||
},
|
||||
}
|
||||
|
||||
// Server error captured from a failed resume attempt. It is stored on the
// logical session and replayed to every later connection request.
#[derive(Clone)]
|
||||
struct TerminalReconnectError {
|
||||
// Error code copied from the server error.
code: i64,
|
||||
// Error message copied from the server error.
message: String,
|
||||
}
|
||||
|
||||
impl RemoteExecServerClient {
|
||||
pub(crate) fn new(transport_params: ExecServerTransportParams) -> Self {
|
||||
Self {
|
||||
transport_params,
|
||||
client: Arc::new(OnceCell::new()),
|
||||
session: Arc::new(StdMutex::new(RemoteExecServerSession {
|
||||
connection: None,
|
||||
connection_attempt: None,
|
||||
logical_session_id: None,
|
||||
terminal_error: None,
|
||||
process_sessions: HashMap::new(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn get(&self) -> Result<ExecServerClient, ExecServerError> {
|
||||
self.client
|
||||
// TODO: Add reconnect/disconnect handling here instead of reusing
|
||||
// the first successfully initialized connection forever.
|
||||
.get_or_try_init(|| {
|
||||
let transport_params = self.transport_params.clone();
|
||||
async move { ExecServerClient::connect_for_transport(transport_params).await }
|
||||
})
|
||||
pub(crate) async fn connection(&self) -> Result<ExecServerConnection, ExecServerError> {
|
||||
loop {
|
||||
match self.next_connection_action()? {
|
||||
RemoteExecServerConnectionAction::Ready(connection) => return Ok(connection),
|
||||
RemoteExecServerConnectionAction::Wait(connection_attempt) => {
|
||||
connection_attempt.await;
|
||||
}
|
||||
RemoteExecServerConnectionAction::Connect {
|
||||
connection_attempt,
|
||||
resume_session_id,
|
||||
process_sessions,
|
||||
} => {
|
||||
let connection = self
|
||||
.connect_and_rebind(resume_session_id.clone(), process_sessions)
|
||||
.await;
|
||||
return self.finish_connection_attempt(
|
||||
connection_attempt,
|
||||
resume_session_id,
|
||||
connection,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn register_process_session(
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
) -> Result<ProcessSessionHandle, ExecServerError> {
|
||||
let process_session = Arc::new(ProcessSession::new(
|
||||
ProcessSessionDisconnectBehavior::Preserve,
|
||||
));
|
||||
{
|
||||
let mut session = self.lock_session();
|
||||
if session.process_sessions.contains_key(process_id) {
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
"session already registered for process {process_id}"
|
||||
)));
|
||||
}
|
||||
session
|
||||
.process_sessions
|
||||
.insert(process_id.clone(), Arc::clone(&process_session));
|
||||
}
|
||||
|
||||
let connection = self.connection().await?;
|
||||
if let Err(err) = connection
|
||||
.register_process_session_route(process_id, Arc::clone(&process_session))
|
||||
.await
|
||||
.cloned()
|
||||
{
|
||||
self.unregister_process_session(process_id).await;
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
Ok(ProcessSessionHandle {
|
||||
control: ProcessSessionControl::RemoteClient(self.clone()),
|
||||
process_id: process_id.clone(),
|
||||
session: process_session,
|
||||
})
|
||||
}
|
||||
|
||||
async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
|
||||
let connection = self.connection().await?;
|
||||
match connection.read(params.clone()).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) if is_transport_closed_error(&err) && self.supports_reconnect() => {
|
||||
self.connection().await?.read(params).await
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn write(
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
chunk: Vec<u8>,
|
||||
) -> Result<WriteResponse, ExecServerError> {
|
||||
self.connection().await?.write(process_id, chunk).await
|
||||
}
|
||||
|
||||
async fn terminate(
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
) -> Result<TerminateResponse, ExecServerError> {
|
||||
self.connection().await?.terminate(process_id).await
|
||||
}
|
||||
|
||||
/// Forgets the process session on the logical client and, if a connection
/// is recorded, removes its route from that connection as well.
async fn unregister_process_session(&self, process_id: &ProcessId) {
    // Keep the std mutex guard inside this block so it is released before
    // the await below.
    let recorded_connection = {
        let mut state = self.lock_session();
        state.process_sessions.remove(process_id);
        state.connection.clone()
    };

    let Some(recorded_connection) = recorded_connection else {
        return;
    };
    recorded_connection
        .unregister_process_session(process_id)
        .await;
}
|
||||
|
||||
fn next_connection_action(&self) -> Result<RemoteExecServerConnectionAction, ExecServerError> {
|
||||
let mut session = self.lock_session();
|
||||
if let Some(error) = &session.terminal_error {
|
||||
return Err(error.to_exec_server_error());
|
||||
}
|
||||
|
||||
if let Some(connection) = &session.connection {
|
||||
if let Some(error) = connection.disconnected_error() {
|
||||
if !self.supports_reconnect() {
|
||||
return Err(error);
|
||||
}
|
||||
} else {
|
||||
return Ok(RemoteExecServerConnectionAction::Ready(connection.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(connection_attempt) = &session.connection_attempt {
|
||||
let connection_attempt = Arc::clone(connection_attempt).notified_owned();
|
||||
return Ok(RemoteExecServerConnectionAction::Wait(
|
||||
connection_attempt.boxed(),
|
||||
));
|
||||
}
|
||||
|
||||
let connection_attempt = Arc::new(Notify::new());
|
||||
let resume_session_id = session.logical_session_id.clone();
|
||||
let process_sessions = session
|
||||
.process_sessions
|
||||
.iter()
|
||||
.map(|(process_id, process_session)| (process_id.clone(), Arc::clone(process_session)))
|
||||
.collect();
|
||||
session.connection_attempt = Some(Arc::clone(&connection_attempt));
|
||||
Ok(RemoteExecServerConnectionAction::Connect {
|
||||
connection_attempt,
|
||||
resume_session_id,
|
||||
process_sessions,
|
||||
})
|
||||
}
|
||||
|
||||
async fn connect_and_rebind(
|
||||
&self,
|
||||
resume_session_id: Option<String>,
|
||||
process_sessions: Vec<(ProcessId, Arc<ProcessSession>)>,
|
||||
) -> Result<ExecServerConnection, ExecServerError> {
|
||||
let connection = self.connect(resume_session_id).await?;
|
||||
for (process_id, process_session) in process_sessions {
|
||||
connection
|
||||
.register_process_session_route(&process_id, process_session)
|
||||
.await?;
|
||||
}
|
||||
Ok(connection)
|
||||
}
|
||||
|
||||
fn finish_connection_attempt(
|
||||
&self,
|
||||
connection_attempt: Arc<Notify>,
|
||||
resume_session_id: Option<String>,
|
||||
connection: Result<ExecServerConnection, ExecServerError>,
|
||||
) -> Result<ExecServerConnection, ExecServerError> {
|
||||
let mut session = self.lock_session();
|
||||
if let Err(err) = &connection {
|
||||
if resume_session_id.is_some()
|
||||
&& let Some(terminal_error) = TerminalReconnectError::from_error(err)
|
||||
{
|
||||
session.terminal_error = Some(terminal_error);
|
||||
}
|
||||
} else if let Ok(connection) = &connection {
|
||||
session.logical_session_id = connection.session_id();
|
||||
session.connection = Some(connection.clone());
|
||||
}
|
||||
session.connection_attempt = None;
|
||||
connection_attempt.notify_waiters();
|
||||
connection
|
||||
}
|
||||
|
||||
/// Locks the shared session state. A poisoned mutex is recovered rather
/// than propagated.
fn lock_session(&self) -> std::sync::MutexGuard<'_, RemoteExecServerSession> {
    match self.session.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
|
||||
|
||||
async fn connect(
|
||||
&self,
|
||||
resume_session_id: Option<String>,
|
||||
) -> Result<ExecServerConnection, ExecServerError> {
|
||||
match &self.transport_params {
|
||||
ExecServerTransportParams::WebSocketUrl {
|
||||
websocket_url,
|
||||
connect_timeout,
|
||||
initialize_timeout,
|
||||
} => {
|
||||
ExecServerConnection::connect_websocket(RemoteExecServerConnectArgs {
|
||||
websocket_url: websocket_url.clone(),
|
||||
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
|
||||
connect_timeout: *connect_timeout,
|
||||
initialize_timeout: *initialize_timeout,
|
||||
resume_session_id,
|
||||
})
|
||||
.await
|
||||
}
|
||||
ExecServerTransportParams::StdioCommand { .. } => {
|
||||
ExecServerConnection::connect_for_transport(self.transport_params.clone()).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_reconnect(&self) -> bool {
|
||||
matches!(
|
||||
&self.transport_params,
|
||||
ExecServerTransportParams::WebSocketUrl { .. }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient for LazyRemoteExecServerClient {
|
||||
impl HttpClient for RemoteExecServerClient {
|
||||
fn http_request(
|
||||
&self,
|
||||
params: crate::HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<crate::HttpRequestResponse, ExecServerError>> {
|
||||
async move { self.get().await?.http_request(params).await }.boxed()
|
||||
async move { self.connection().await?.http_request(params).await }.boxed()
|
||||
}
|
||||
|
||||
fn http_request_stream(
|
||||
@@ -231,7 +479,7 @@ impl HttpClient for LazyRemoteExecServerClient {
|
||||
'_,
|
||||
Result<(crate::HttpRequestResponse, crate::HttpResponseBodyStream), ExecServerError>,
|
||||
> {
|
||||
async move { self.get().await?.http_request_stream(params).await }.boxed()
|
||||
async move { self.connection().await?.http_request_stream(params).await }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,7 +523,7 @@ pub enum ExecServerError {
|
||||
ExecutorRegistryRequest(#[from] reqwest::Error),
|
||||
}
|
||||
|
||||
impl ExecServerClient {
|
||||
impl ExecServerConnection {
|
||||
pub async fn initialize(
|
||||
&self,
|
||||
options: ExecServerClientConnectOptions,
|
||||
@@ -289,7 +537,7 @@ impl ExecServerClient {
|
||||
timeout(initialize_timeout, async {
|
||||
let response: InitializeResponse = self
|
||||
.inner
|
||||
.client
|
||||
.rpc_client
|
||||
.call(
|
||||
INITIALIZE_METHOD,
|
||||
&InitializeParams {
|
||||
@@ -397,23 +645,33 @@ impl ExecServerClient {
|
||||
self.call(FS_COPY_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
pub(crate) async fn register_session(
|
||||
#[cfg(test)]
|
||||
pub(crate) async fn register_process_session(
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
) -> Result<Session, ExecServerError> {
|
||||
let state = Arc::new(SessionState::new());
|
||||
self.inner
|
||||
.insert_session(process_id, Arc::clone(&state))
|
||||
) -> Result<ProcessSessionHandle, ExecServerError> {
|
||||
let session = Arc::new(ProcessSession::new(ProcessSessionDisconnectBehavior::Fail));
|
||||
self.register_process_session_route(process_id, Arc::clone(&session))
|
||||
.await?;
|
||||
Ok(Session {
|
||||
client: self.clone(),
|
||||
Ok(ProcessSessionHandle {
|
||||
control: ProcessSessionControl::Connection(self.clone()),
|
||||
process_id: process_id.clone(),
|
||||
state,
|
||||
session,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn unregister_session(&self, process_id: &ProcessId) {
|
||||
self.inner.remove_session(process_id).await;
|
||||
async fn register_process_session_route(
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
session: Arc<ProcessSession>,
|
||||
) -> Result<(), ExecServerError> {
|
||||
self.inner
|
||||
.insert_process_session_route(process_id, session)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn unregister_process_session(&self, process_id: &ProcessId) {
|
||||
self.inner.remove_process_session_route(process_id).await;
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> Option<String> {
|
||||
@@ -424,6 +682,10 @@ impl ExecServerClient {
|
||||
.clone()
|
||||
}
|
||||
|
||||
fn disconnected_error(&self) -> Option<ExecServerError> {
|
||||
self.inner.disconnected_error()
|
||||
}
|
||||
|
||||
pub(crate) async fn connect(
|
||||
connection: JsonRpcConnection,
|
||||
options: ExecServerClientConnectOptions,
|
||||
@@ -462,9 +724,9 @@ impl ExecServerClient {
|
||||
});
|
||||
|
||||
Inner {
|
||||
client: rpc_client,
|
||||
sessions: ArcSwap::from_pointee(HashMap::new()),
|
||||
sessions_write_lock: Mutex::new(()),
|
||||
rpc_client,
|
||||
process_session_routes: ArcSwap::from_pointee(HashMap::new()),
|
||||
process_session_routes_write_lock: Mutex::new(()),
|
||||
disconnected: OnceLock::new(),
|
||||
http_body_streams: ArcSwap::from_pointee(HashMap::new()),
|
||||
http_body_stream_failures: ArcSwap::from_pointee(HashMap::new()),
|
||||
@@ -475,14 +737,14 @@ impl ExecServerClient {
|
||||
}
|
||||
});
|
||||
|
||||
let client = Self { inner };
|
||||
client.initialize(options).await?;
|
||||
Ok(client)
|
||||
let connection = Self { inner };
|
||||
connection.initialize(options).await?;
|
||||
Ok(connection)
|
||||
}
|
||||
|
||||
async fn notify_initialized(&self) -> Result<(), ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.rpc_client
|
||||
.notify(INITIALIZED_METHOD, &serde_json::json!({}))
|
||||
.await
|
||||
.map_err(ExecServerError::Json)
|
||||
@@ -500,13 +762,13 @@ impl ExecServerClient {
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
match self.inner.client.call(method, params).await {
|
||||
match self.inner.rpc_client.call(method, params).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(error) => {
|
||||
let error = ExecServerError::from(error);
|
||||
if is_transport_closed_error(&error) {
|
||||
// A call can race with disconnect after the preflight
|
||||
// check. Only the reader task drains sessions so queued
|
||||
// check. Only the reader task drains routes so queued
|
||||
// process notifications stay ordered before disconnect.
|
||||
let message = disconnected_message(/*reason*/ None);
|
||||
let message = record_disconnected(&self.inner, message);
|
||||
@@ -532,8 +794,8 @@ impl From<RpcCallError> for ExecServerError {
|
||||
}
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
fn new() -> Self {
|
||||
impl ProcessSession {
|
||||
fn new(disconnect_behavior: ProcessSessionDisconnectBehavior) -> Self {
|
||||
let (wake_tx, _wake_rx) = watch::channel(0);
|
||||
Self {
|
||||
wake_tx,
|
||||
@@ -543,6 +805,7 @@ impl SessionState {
|
||||
),
|
||||
ordered_events: StdMutex::new(OrderedSessionEvents::default()),
|
||||
failure: Mutex::new(None),
|
||||
disconnect_behavior,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -638,17 +901,17 @@ impl SessionState {
|
||||
}
|
||||
}
|
||||
|
||||
impl Session {
|
||||
impl ProcessSessionHandle {
|
||||
pub(crate) fn process_id(&self) -> &ProcessId {
|
||||
&self.process_id
|
||||
}
|
||||
|
||||
pub(crate) fn subscribe_wake(&self) -> watch::Receiver<u64> {
|
||||
self.state.subscribe()
|
||||
self.session.subscribe()
|
||||
}
|
||||
|
||||
pub(crate) fn subscribe_events(&self) -> ExecProcessEventReceiver {
|
||||
self.state.subscribe_events()
|
||||
self.session.subscribe_events()
|
||||
}
|
||||
|
||||
pub(crate) async fn read(
|
||||
@@ -657,41 +920,104 @@ impl Session {
|
||||
max_bytes: Option<usize>,
|
||||
wait_ms: Option<u64>,
|
||||
) -> Result<ReadResponse, ExecServerError> {
|
||||
if let Some(response) = self.state.failed_response().await {
|
||||
if let Some(response) = self.session.failed_response().await {
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
match self
|
||||
.client
|
||||
.read(ReadParams {
|
||||
process_id: self.process_id.clone(),
|
||||
after_seq,
|
||||
max_bytes,
|
||||
wait_ms,
|
||||
})
|
||||
.await
|
||||
{
|
||||
let params = ReadParams {
|
||||
process_id: self.process_id.clone(),
|
||||
after_seq,
|
||||
max_bytes,
|
||||
wait_ms,
|
||||
};
|
||||
match self.control.read(params).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) if is_transport_closed_error(&err) => {
|
||||
Err(err)
|
||||
if is_transport_closed_error(&err)
|
||||
&& self.session.disconnect_behavior
|
||||
== ProcessSessionDisconnectBehavior::Fail =>
|
||||
{
|
||||
let message = disconnected_message(/*reason*/ None);
|
||||
self.state.set_failure(message.clone()).await;
|
||||
Ok(self.state.synthesized_failure(message))
|
||||
self.session.set_failure(message.clone()).await;
|
||||
Ok(self.session.synthesized_failure(message))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn write(&self, chunk: Vec<u8>) -> Result<WriteResponse, ExecServerError> {
|
||||
self.client.write(&self.process_id, chunk).await
|
||||
self.control.write(&self.process_id, chunk).await
|
||||
}
|
||||
|
||||
pub(crate) async fn terminate(&self) -> Result<(), ExecServerError> {
|
||||
self.client.terminate(&self.process_id).await?;
|
||||
self.control.terminate(&self.process_id).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn unregister(&self) {
|
||||
self.client.unregister_session(&self.process_id).await;
|
||||
self.control
|
||||
.unregister_process_session(&self.process_id)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
impl ProcessSessionControl {
|
||||
async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
|
||||
match self {
|
||||
#[cfg(test)]
|
||||
Self::Connection(connection) => connection.read(params).await,
|
||||
Self::RemoteClient(client) => client.read(params).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn write(
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
chunk: Vec<u8>,
|
||||
) -> Result<WriteResponse, ExecServerError> {
|
||||
match self {
|
||||
#[cfg(test)]
|
||||
Self::Connection(connection) => connection.write(process_id, chunk).await,
|
||||
Self::RemoteClient(client) => client.write(process_id, chunk).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn terminate(
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
) -> Result<TerminateResponse, ExecServerError> {
|
||||
match self {
|
||||
#[cfg(test)]
|
||||
Self::Connection(connection) => connection.terminate(process_id).await,
|
||||
Self::RemoteClient(client) => client.terminate(process_id).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn unregister_process_session(&self, process_id: &ProcessId) {
|
||||
match self {
|
||||
#[cfg(test)]
|
||||
Self::Connection(connection) => connection.unregister_process_session(process_id).await,
|
||||
Self::RemoteClient(client) => client.unregister_process_session(process_id).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TerminalReconnectError {
|
||||
fn from_error(error: &ExecServerError) -> Option<Self> {
|
||||
match error {
|
||||
ExecServerError::Server { code, message } if *code == -32600 => Some(Self {
|
||||
code: *code,
|
||||
message: message.clone(),
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_exec_server_error(&self) -> ExecServerError {
|
||||
ExecServerError::Server {
|
||||
code: self.code,
|
||||
message: self.message.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -710,51 +1036,57 @@ impl Inner {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
|
||||
self.sessions.load().get(process_id).cloned()
|
||||
fn get_process_session_route(&self, process_id: &ProcessId) -> Option<Arc<ProcessSession>> {
|
||||
self.process_session_routes.load().get(process_id).cloned()
|
||||
}
|
||||
|
||||
async fn insert_session(
|
||||
async fn insert_process_session_route(
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
session: Arc<SessionState>,
|
||||
session: Arc<ProcessSession>,
|
||||
) -> Result<(), ExecServerError> {
|
||||
let _sessions_write_guard = self.sessions_write_lock.lock().await;
|
||||
let _routes_write_guard = self.process_session_routes_write_lock.lock().await;
|
||||
// Do not register a process session that can never receive executor
|
||||
// notifications. Without this check, remote MCP startup could create a
|
||||
// dead session and wait for process output that will never arrive.
|
||||
if let Some(error) = self.disconnected_error() {
|
||||
return Err(error);
|
||||
}
|
||||
let sessions = self.sessions.load();
|
||||
if sessions.contains_key(process_id) {
|
||||
let routes = self.process_session_routes.load();
|
||||
if let Some(existing_session) = routes.get(process_id) {
|
||||
if Arc::ptr_eq(existing_session, &session) {
|
||||
return Ok(());
|
||||
}
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
"session already registered for process {process_id}"
|
||||
)));
|
||||
}
|
||||
let mut next_sessions = sessions.as_ref().clone();
|
||||
next_sessions.insert(process_id.clone(), session);
|
||||
self.sessions.store(Arc::new(next_sessions));
|
||||
let mut next_routes = routes.as_ref().clone();
|
||||
next_routes.insert(process_id.clone(), session);
|
||||
self.process_session_routes.store(Arc::new(next_routes));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
|
||||
let _sessions_write_guard = self.sessions_write_lock.lock().await;
|
||||
let sessions = self.sessions.load();
|
||||
let session = sessions.get(process_id).cloned();
|
||||
async fn remove_process_session_route(
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
) -> Option<Arc<ProcessSession>> {
|
||||
let _routes_write_guard = self.process_session_routes_write_lock.lock().await;
|
||||
let routes = self.process_session_routes.load();
|
||||
let session = routes.get(process_id).cloned();
|
||||
session.as_ref()?;
|
||||
let mut next_sessions = sessions.as_ref().clone();
|
||||
next_sessions.remove(process_id);
|
||||
self.sessions.store(Arc::new(next_sessions));
|
||||
let mut next_routes = routes.as_ref().clone();
|
||||
next_routes.remove(process_id);
|
||||
self.process_session_routes.store(Arc::new(next_routes));
|
||||
session
|
||||
}
|
||||
|
||||
async fn take_all_sessions(&self) -> HashMap<ProcessId, Arc<SessionState>> {
|
||||
let _sessions_write_guard = self.sessions_write_lock.lock().await;
|
||||
let sessions = self.sessions.load();
|
||||
let drained_sessions = sessions.as_ref().clone();
|
||||
self.sessions.store(Arc::new(HashMap::new()));
|
||||
drained_sessions
|
||||
async fn take_all_process_session_routes(&self) -> HashMap<ProcessId, Arc<ProcessSession>> {
|
||||
let _routes_write_guard = self.process_session_routes_write_lock.lock().await;
|
||||
let routes = self.process_session_routes.load();
|
||||
let drained_routes = routes.as_ref().clone();
|
||||
self.process_session_routes.store(Arc::new(HashMap::new()));
|
||||
drained_routes
|
||||
}
|
||||
}
|
||||
|
||||
@@ -779,9 +1111,9 @@ fn is_transport_closed_error(error: &ExecServerError) -> bool {
|
||||
}
|
||||
|
||||
fn record_disconnected(inner: &Arc<Inner>, message: String) -> String {
|
||||
// The first observer records the canonical disconnect reason. Session
|
||||
// draining stays with the reader task so it can preserve notification
|
||||
// ordering before publishing the terminal failure.
|
||||
// The first observer records the canonical disconnect reason. Process
|
||||
// session route draining stays with the reader task so it can preserve
|
||||
// notification ordering before publishing the terminal failure.
|
||||
if let Some(message) = inner.set_disconnected(message.clone()) {
|
||||
message
|
||||
} else {
|
||||
@@ -789,20 +1121,22 @@ fn record_disconnected(inner: &Arc<Inner>, message: String) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
async fn fail_all_sessions(inner: &Arc<Inner>, message: String) {
|
||||
let sessions = inner.take_all_sessions().await;
|
||||
async fn fail_all_process_sessions(inner: &Arc<Inner>, message: String) {
|
||||
let routes = inner.take_all_process_session_routes().await;
|
||||
|
||||
for (_, session) in sessions {
|
||||
// Sessions synthesize a closed read response and emit a pushed Failed
|
||||
// event. That covers both polling consumers and streaming consumers
|
||||
// such as executor-backed MCP stdio.
|
||||
session.set_failure(message.clone()).await;
|
||||
for (_, session) in routes {
|
||||
// One-shot sessions synthesize a closed read response and emit a
|
||||
// pushed Failed event. Reconnecting remote sessions keep their local
|
||||
// event state so a reattached client can bind them again.
|
||||
if session.disconnect_behavior == ProcessSessionDisconnectBehavior::Fail {
|
||||
session.set_failure(message.clone()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fails all in-flight work that depends on the shared JSON-RPC transport.
|
||||
async fn fail_all_in_flight_work(inner: &Arc<Inner>, message: String) {
|
||||
fail_all_sessions(inner, message.clone()).await;
|
||||
fail_all_process_sessions(inner, message.clone()).await;
|
||||
inner.fail_all_http_body_streams(message).await;
|
||||
}
|
||||
|
||||
@@ -814,7 +1148,7 @@ async fn handle_server_notification(
|
||||
EXEC_OUTPUT_DELTA_METHOD => {
|
||||
let params: ExecOutputDeltaNotification =
|
||||
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
|
||||
if let Some(session) = inner.get_session(¶ms.process_id) {
|
||||
if let Some(session) = inner.get_process_session_route(¶ms.process_id) {
|
||||
session.note_change(params.seq);
|
||||
let published_closed =
|
||||
session.publish_ordered_event(ExecProcessEvent::Output(ProcessOutputChunk {
|
||||
@@ -823,28 +1157,28 @@ async fn handle_server_notification(
|
||||
chunk: params.chunk,
|
||||
}));
|
||||
if published_closed {
|
||||
inner.remove_session(¶ms.process_id).await;
|
||||
inner.remove_process_session_route(¶ms.process_id).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
EXEC_EXITED_METHOD => {
|
||||
let params: ExecExitedNotification =
|
||||
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
|
||||
if let Some(session) = inner.get_session(¶ms.process_id) {
|
||||
if let Some(session) = inner.get_process_session_route(¶ms.process_id) {
|
||||
session.note_change(params.seq);
|
||||
let published_closed = session.publish_ordered_event(ExecProcessEvent::Exited {
|
||||
seq: params.seq,
|
||||
exit_code: params.exit_code,
|
||||
});
|
||||
if published_closed {
|
||||
inner.remove_session(¶ms.process_id).await;
|
||||
inner.remove_process_session_route(¶ms.process_id).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
EXEC_CLOSED_METHOD => {
|
||||
let params: ExecClosedNotification =
|
||||
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
|
||||
if let Some(session) = inner.get_session(¶ms.process_id) {
|
||||
if let Some(session) = inner.get_process_session_route(¶ms.process_id) {
|
||||
session.note_change(params.seq);
|
||||
// Closed is terminal, but it can arrive before tail output or
|
||||
// exited. Keep routing this process until the ordered publisher
|
||||
@@ -852,7 +1186,7 @@ async fn handle_server_notification(
|
||||
let published_closed =
|
||||
session.publish_ordered_event(ExecProcessEvent::Closed { seq: params.seq });
|
||||
if published_closed {
|
||||
inner.remove_session(¶ms.process_id).await;
|
||||
inner.remove_process_session_route(¶ms.process_id).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -891,8 +1225,8 @@ mod tests {
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::ExecServerClient;
|
||||
use super::ExecServerClientConnectOptions;
|
||||
use super::ExecServerConnection;
|
||||
use crate::ProcessId;
|
||||
#[cfg(not(windows))]
|
||||
use crate::client_api::DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT;
|
||||
@@ -940,7 +1274,7 @@ mod tests {
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn connect_stdio_command_initializes_json_rpc_client() {
|
||||
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
let client = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
command: StdioExecServerCommand {
|
||||
program: "sh".to_string(),
|
||||
args: vec![
|
||||
@@ -963,7 +1297,7 @@ mod tests {
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn connect_for_transport_initializes_stdio_command() {
|
||||
let client = ExecServerClient::connect_for_transport(
|
||||
let client = ExecServerConnection::connect_for_transport(
|
||||
ExecServerTransportParams::StdioCommand {
|
||||
command: StdioExecServerCommand {
|
||||
program: "sh".to_string(),
|
||||
@@ -986,7 +1320,7 @@ mod tests {
|
||||
#[cfg(windows)]
|
||||
#[tokio::test]
|
||||
async fn connect_stdio_command_initializes_json_rpc_client_on_windows() {
|
||||
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
let client = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
command: StdioExecServerCommand {
|
||||
program: "powershell".to_string(),
|
||||
args: vec![
|
||||
@@ -1024,7 +1358,7 @@ mod tests {
|
||||
shell_quote(child_pid_file.as_path()),
|
||||
);
|
||||
|
||||
let client = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
let client = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
command: StdioExecServerCommand {
|
||||
program: "sh".to_string(),
|
||||
args: vec!["-c".to_string(), stdio_script],
|
||||
@@ -1067,7 +1401,7 @@ mod tests {
|
||||
shell_quote(pid_file.as_path()),
|
||||
);
|
||||
|
||||
let result = ExecServerClient::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
let result = ExecServerConnection::connect_stdio_command(StdioExecServerConnectArgs {
|
||||
command: StdioExecServerCommand {
|
||||
program: "sh".to_string(),
|
||||
args: vec!["-c".to_string(), stdio_script],
|
||||
@@ -1161,7 +1495,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let client = ExecServerClient::connect(
|
||||
let client = ExecServerConnection::connect(
|
||||
JsonRpcConnection::from_stdio(
|
||||
client_stdout,
|
||||
client_stdin,
|
||||
@@ -1174,7 +1508,7 @@ mod tests {
|
||||
|
||||
let process_id = ProcessId::from("reordered");
|
||||
let session = client
|
||||
.register_session(&process_id)
|
||||
.register_process_session(&process_id)
|
||||
.await
|
||||
.expect("session should register");
|
||||
let mut events = session.subscribe_events();
|
||||
@@ -1303,7 +1637,7 @@ mod tests {
|
||||
drop(server_writer);
|
||||
});
|
||||
|
||||
let client = ExecServerClient::connect(
|
||||
let client = ExecServerConnection::connect(
|
||||
JsonRpcConnection::from_stdio(
|
||||
client_stdout,
|
||||
client_stdin,
|
||||
@@ -1316,7 +1650,7 @@ mod tests {
|
||||
|
||||
let process_id = ProcessId::from("disconnect");
|
||||
let session = client
|
||||
.register_session(&process_id)
|
||||
.register_process_session(&process_id)
|
||||
.await
|
||||
.expect("session should register");
|
||||
let mut events = session.subscribe_events();
|
||||
@@ -1344,7 +1678,9 @@ mod tests {
|
||||
);
|
||||
assert!(response.closed);
|
||||
|
||||
let new_session = client.register_session(&ProcessId::from("new")).await;
|
||||
let new_session = client
|
||||
.register_process_session(&ProcessId::from("new"))
|
||||
.await;
|
||||
assert!(matches!(
|
||||
new_session,
|
||||
Err(super::ExecServerError::Disconnected(_))
|
||||
@@ -1390,7 +1726,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let client = ExecServerClient::connect(
|
||||
let client = ExecServerConnection::connect(
|
||||
JsonRpcConnection::from_stdio(
|
||||
client_stdout,
|
||||
client_stdin,
|
||||
@@ -1404,11 +1740,11 @@ mod tests {
|
||||
let noisy_process_id = ProcessId::from("noisy");
|
||||
let quiet_process_id = ProcessId::from("quiet");
|
||||
let _noisy_session = client
|
||||
.register_session(&noisy_process_id)
|
||||
.register_process_session(&noisy_process_id)
|
||||
.await
|
||||
.expect("noisy session should register");
|
||||
let quiet_session = client
|
||||
.register_session(&quiet_process_id)
|
||||
.register_process_session(&quiet_process_id)
|
||||
.await
|
||||
.expect("quiet session should register");
|
||||
let mut quiet_wake_rx = quiet_session.subscribe_wake();
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! This module is the facade for the environment-owned [`crate::HttpClient`]
|
||||
//! capability:
|
||||
//! - [`ReqwestHttpClient`] executes requests directly with `reqwest`
|
||||
//! - [`ExecServerClient`] forwards requests over the JSON-RPC transport
|
||||
//! - [`ExecServerConnection`] forwards requests over the JSON-RPC transport
|
||||
//! - [`HttpResponseBodyStream`] presents buffered local bodies and streamed
|
||||
//! remote `http/request/bodyDelta` notifications through one byte-stream API
|
||||
//!
|
||||
@@ -11,7 +11,7 @@
|
||||
//! - orchestrator process: holds an `Arc<dyn HttpClient>` and chooses local or
|
||||
//! remote execution
|
||||
//! - remote runtime: serves the `http/request` RPC and runs the concrete local
|
||||
//! HTTP request there when the orchestrator uses [`ExecServerClient`]
|
||||
//! HTTP request there when the orchestrator uses [`ExecServerConnection`]
|
||||
|
||||
#[path = "reqwest_http_client.rs"]
|
||||
mod reqwest_http_client;
|
||||
|
||||
@@ -14,7 +14,7 @@ use tokio::sync::mpsc;
|
||||
use super::HttpResponseBodyStream;
|
||||
use super::response_body_stream::HttpBodyStreamRegistration;
|
||||
use crate::HttpClient;
|
||||
use crate::client::ExecServerClient;
|
||||
use crate::client::ExecServerConnection;
|
||||
use crate::client::ExecServerError;
|
||||
use crate::protocol::HTTP_REQUEST_METHOD;
|
||||
use crate::protocol::HttpRequestParams;
|
||||
@@ -23,7 +23,7 @@ use crate::protocol::HttpRequestResponse;
|
||||
/// Maximum queued body frames per streamed HTTP response.
|
||||
const HTTP_BODY_DELTA_CHANNEL_CAPACITY: usize = 256;
|
||||
|
||||
impl ExecServerClient {
|
||||
impl ExecServerConnection {
|
||||
/// Performs an HTTP request and buffers the response body.
|
||||
pub async fn http_request(
|
||||
&self,
|
||||
@@ -67,14 +67,14 @@ impl ExecServerClient {
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient for ExecServerClient {
|
||||
impl HttpClient for ExecServerConnection {
|
||||
/// Orchestrator-side adapter that forwards buffered HTTP requests to the
|
||||
/// remote runtime over the shared JSON-RPC connection.
|
||||
fn http_request(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
|
||||
async move { ExecServerClient::http_request(self, params).await }.boxed()
|
||||
async move { ExecServerConnection::http_request(self, params).await }.boxed()
|
||||
}
|
||||
|
||||
/// Orchestrator-side adapter that forwards streamed HTTP requests to the
|
||||
@@ -83,6 +83,6 @@ impl HttpClient for ExecServerClient {
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
|
||||
async move { ExecServerClient::http_request_stream(self, params).await }.boxed()
|
||||
async move { ExecServerConnection::http_request_stream(self, params).await }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use tracing::warn;
|
||||
|
||||
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
|
||||
|
||||
use crate::ExecServerClient;
|
||||
use crate::ExecServerConnection;
|
||||
use crate::ExecServerError;
|
||||
use crate::client_api::RemoteExecServerConnectArgs;
|
||||
use crate::client_api::StdioExecServerCommand;
|
||||
@@ -17,9 +17,9 @@ use crate::client_api::StdioExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::relay::harness_connection_from_websocket;
|
||||
|
||||
const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
|
||||
pub(crate) const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
|
||||
|
||||
impl ExecServerClient {
|
||||
impl ExecServerConnection {
|
||||
pub(crate) async fn connect_for_transport(
|
||||
transport_params: crate::client_api::ExecServerTransportParams,
|
||||
) -> Result<Self, ExecServerError> {
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::ExecServerError;
|
||||
use crate::ExecServerRuntimePaths;
|
||||
use crate::ExecutorFileSystem;
|
||||
use crate::HttpClient;
|
||||
use crate::client::LazyRemoteExecServerClient;
|
||||
use crate::client::RemoteExecServerClient;
|
||||
use crate::client::http_client::ReqwestHttpClient;
|
||||
use crate::client_api::ExecServerTransportParams;
|
||||
use crate::environment_provider::DefaultEnvironmentProvider;
|
||||
@@ -403,7 +403,7 @@ impl Environment {
|
||||
} => Some(exec_server_url.clone()),
|
||||
ExecServerTransportParams::StdioCommand { .. } => None,
|
||||
};
|
||||
let client = LazyRemoteExecServerClient::new(remote_transport.clone());
|
||||
let client = RemoteExecServerClient::new(remote_transport.clone());
|
||||
let exec_backend: Arc<dyn ExecBackend> = Arc::new(RemoteProcess::new(client.clone()));
|
||||
let filesystem: Arc<dyn ExecutorFileSystem> =
|
||||
Arc::new(RemoteFileSystem::new(client.clone()));
|
||||
|
||||
@@ -23,8 +23,9 @@ mod runtime_paths;
|
||||
mod sandboxed_file_system;
|
||||
mod server;
|
||||
|
||||
pub use client::ExecServerClient;
|
||||
pub use client::ExecServerConnection;
|
||||
pub use client::ExecServerError;
|
||||
pub type ExecServerClient = ExecServerConnection;
|
||||
pub use client::http_client::HttpResponseBodyStream;
|
||||
pub use client::http_client::ReqwestHttpClient;
|
||||
pub use client_api::ExecServerClientConnectOptions;
|
||||
|
||||
@@ -14,7 +14,7 @@ use crate::FileSystemResult;
|
||||
use crate::FileSystemSandboxContext;
|
||||
use crate::ReadDirectoryEntry;
|
||||
use crate::RemoveOptions;
|
||||
use crate::client::LazyRemoteExecServerClient;
|
||||
use crate::client::RemoteExecServerClient;
|
||||
use crate::protocol::FsCopyParams;
|
||||
use crate::protocol::FsCreateDirectoryParams;
|
||||
use crate::protocol::FsGetMetadataParams;
|
||||
@@ -28,11 +28,11 @@ const NOT_FOUND_ERROR_CODE: i64 = -32004;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RemoteFileSystem {
|
||||
client: LazyRemoteExecServerClient,
|
||||
client: RemoteExecServerClient,
|
||||
}
|
||||
|
||||
impl RemoteFileSystem {
|
||||
pub(crate) fn new(client: LazyRemoteExecServerClient) -> Self {
|
||||
pub(crate) fn new(client: RemoteExecServerClient) -> Self {
|
||||
trace!("remote fs new");
|
||||
Self { client }
|
||||
}
|
||||
@@ -46,8 +46,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
|
||||
sandbox: Option<&FileSystemSandboxContext>,
|
||||
) -> FileSystemResult<Vec<u8>> {
|
||||
trace!("remote fs read_file");
|
||||
let client = self.client.get().await.map_err(map_remote_error)?;
|
||||
let response = client
|
||||
let connection = self.client.connection().await.map_err(map_remote_error)?;
|
||||
let response = connection
|
||||
.fs_read_file(FsReadFileParams {
|
||||
path: path.clone(),
|
||||
sandbox: remote_sandbox_context(sandbox),
|
||||
@@ -69,8 +69,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
|
||||
sandbox: Option<&FileSystemSandboxContext>,
|
||||
) -> FileSystemResult<()> {
|
||||
trace!("remote fs write_file");
|
||||
let client = self.client.get().await.map_err(map_remote_error)?;
|
||||
client
|
||||
let connection = self.client.connection().await.map_err(map_remote_error)?;
|
||||
connection
|
||||
.fs_write_file(FsWriteFileParams {
|
||||
path: path.clone(),
|
||||
data_base64: STANDARD.encode(contents),
|
||||
@@ -88,8 +88,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
|
||||
sandbox: Option<&FileSystemSandboxContext>,
|
||||
) -> FileSystemResult<()> {
|
||||
trace!("remote fs create_directory");
|
||||
let client = self.client.get().await.map_err(map_remote_error)?;
|
||||
client
|
||||
let connection = self.client.connection().await.map_err(map_remote_error)?;
|
||||
connection
|
||||
.fs_create_directory(FsCreateDirectoryParams {
|
||||
path: path.clone(),
|
||||
recursive: Some(options.recursive),
|
||||
@@ -106,8 +106,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
|
||||
sandbox: Option<&FileSystemSandboxContext>,
|
||||
) -> FileSystemResult<FileMetadata> {
|
||||
trace!("remote fs get_metadata");
|
||||
let client = self.client.get().await.map_err(map_remote_error)?;
|
||||
let response = client
|
||||
let connection = self.client.connection().await.map_err(map_remote_error)?;
|
||||
let response = connection
|
||||
.fs_get_metadata(FsGetMetadataParams {
|
||||
path: path.clone(),
|
||||
sandbox: remote_sandbox_context(sandbox),
|
||||
@@ -129,8 +129,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
|
||||
sandbox: Option<&FileSystemSandboxContext>,
|
||||
) -> FileSystemResult<Vec<ReadDirectoryEntry>> {
|
||||
trace!("remote fs read_directory");
|
||||
let client = self.client.get().await.map_err(map_remote_error)?;
|
||||
let response = client
|
||||
let connection = self.client.connection().await.map_err(map_remote_error)?;
|
||||
let response = connection
|
||||
.fs_read_directory(FsReadDirectoryParams {
|
||||
path: path.clone(),
|
||||
sandbox: remote_sandbox_context(sandbox),
|
||||
@@ -155,8 +155,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
|
||||
sandbox: Option<&FileSystemSandboxContext>,
|
||||
) -> FileSystemResult<()> {
|
||||
trace!("remote fs remove");
|
||||
let client = self.client.get().await.map_err(map_remote_error)?;
|
||||
client
|
||||
let connection = self.client.connection().await.map_err(map_remote_error)?;
|
||||
connection
|
||||
.fs_remove(FsRemoveParams {
|
||||
path: path.clone(),
|
||||
recursive: Some(options.recursive),
|
||||
@@ -176,8 +176,8 @@ impl ExecutorFileSystem for RemoteFileSystem {
|
||||
sandbox: Option<&FileSystemSandboxContext>,
|
||||
) -> FileSystemResult<()> {
|
||||
trace!("remote fs copy");
|
||||
let client = self.client.get().await.map_err(map_remote_error)?;
|
||||
client
|
||||
let connection = self.client.connection().await.map_err(map_remote_error)?;
|
||||
connection
|
||||
.fs_copy(FsCopyParams {
|
||||
source_path: source_path.clone(),
|
||||
destination_path: destination_path.clone(),
|
||||
|
||||
@@ -9,23 +9,23 @@ use crate::ExecProcess;
|
||||
use crate::ExecProcessEventReceiver;
|
||||
use crate::ExecServerError;
|
||||
use crate::StartedExecProcess;
|
||||
use crate::client::LazyRemoteExecServerClient;
|
||||
use crate::client::Session;
|
||||
use crate::client::ProcessSessionHandle;
|
||||
use crate::client::RemoteExecServerClient;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::ReadResponse;
|
||||
use crate::protocol::WriteResponse;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RemoteProcess {
|
||||
client: LazyRemoteExecServerClient,
|
||||
client: RemoteExecServerClient,
|
||||
}
|
||||
|
||||
struct RemoteExecProcess {
|
||||
session: Session,
|
||||
session: ProcessSessionHandle,
|
||||
}
|
||||
|
||||
impl RemoteProcess {
|
||||
pub(crate) fn new(client: LazyRemoteExecServerClient) -> Self {
|
||||
pub(crate) fn new(client: RemoteExecServerClient) -> Self {
|
||||
trace!("remote process new");
|
||||
Self { client }
|
||||
}
|
||||
@@ -35,9 +35,9 @@ impl RemoteProcess {
|
||||
impl ExecBackend for RemoteProcess {
|
||||
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError> {
|
||||
let process_id = params.process_id.clone();
|
||||
let client = self.client.get().await?;
|
||||
let session = client.register_session(&process_id).await?;
|
||||
if let Err(err) = client.exec(params).await {
|
||||
let session = self.client.register_process_session(&process_id).await?;
|
||||
let connection = self.client.connection().await?;
|
||||
if let Err(err) = connection.exec(params).await {
|
||||
session.unregister().await;
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
@@ -541,7 +541,7 @@ async fn assert_exec_process_preserves_queued_events_before_subscribe(
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
// Serialize tests that launch a real exec-server process through the full CLI.
|
||||
#[serial_test::serial(remote_exec_server)]
|
||||
async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
|
||||
async fn remote_exec_process_surfaces_reconnect_failure_after_server_shutdown() -> Result<()> {
|
||||
let mut context = create_process_context(/*use_remote*/ true).await?;
|
||||
let session = context
|
||||
.backend
|
||||
@@ -562,7 +562,6 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
|
||||
.await?;
|
||||
|
||||
let process = Arc::clone(&session.process);
|
||||
let mut events = process.subscribe_events();
|
||||
let process_for_pending_read = Arc::clone(&process);
|
||||
let pending_read = tokio::spawn(async move {
|
||||
process_for_pending_read
|
||||
@@ -579,36 +578,33 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
|
||||
.expect("remote context should include exec-server harness");
|
||||
server.shutdown().await?;
|
||||
|
||||
let event = timeout(Duration::from_secs(2), events.recv()).await??;
|
||||
let ExecProcessEvent::Failed(event_message) = event else {
|
||||
anyhow::bail!("expected process failure event, got {event:?}");
|
||||
};
|
||||
let pending_error = timeout(Duration::from_secs(2), pending_read)
|
||||
.await
|
||||
.context("timed out waiting for pending read after reconnect failure")??
|
||||
.expect_err("pending read should fail after reconnect fails");
|
||||
assert!(
|
||||
event_message.starts_with("exec-server transport disconnected"),
|
||||
"unexpected failure event: {event_message}"
|
||||
pending_error
|
||||
.to_string()
|
||||
.starts_with("failed to connect to exec-server websocket"),
|
||||
"unexpected pending read error: {pending_error}"
|
||||
);
|
||||
|
||||
let pending_response = timeout(Duration::from_secs(2), pending_read).await???;
|
||||
let pending_message = pending_response
|
||||
.failure
|
||||
.expect("pending read should surface disconnect as a failure");
|
||||
let read_error = timeout(
|
||||
Duration::from_secs(2),
|
||||
process.read(
|
||||
/*after_seq*/ None,
|
||||
/*max_bytes*/ None,
|
||||
/*wait_ms*/ Some(0),
|
||||
),
|
||||
)
|
||||
.await
|
||||
.context("timed out waiting for read after reconnect failure")?
|
||||
.expect_err("read after reconnect failure should fail");
|
||||
assert!(
|
||||
pending_message.starts_with("exec-server transport disconnected"),
|
||||
"unexpected pending failure message: {pending_message}"
|
||||
);
|
||||
|
||||
let mut wake_rx = process.subscribe_wake();
|
||||
let response = read_process_until_change(process, &mut wake_rx, /*after_seq*/ None).await?;
|
||||
let message = response
|
||||
.failure
|
||||
.expect("disconnect should surface as a failure");
|
||||
assert!(
|
||||
message.starts_with("exec-server transport disconnected"),
|
||||
"unexpected failure message: {message}"
|
||||
);
|
||||
assert!(
|
||||
response.closed,
|
||||
"disconnect should close the process session"
|
||||
read_error
|
||||
.to_string()
|
||||
.starts_with("failed to connect to exec-server websocket"),
|
||||
"unexpected read error: {read_error}"
|
||||
);
|
||||
|
||||
let write_result = timeout(
|
||||
@@ -621,7 +617,7 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
|
||||
assert!(
|
||||
write_error
|
||||
.to_string()
|
||||
.starts_with("exec-server transport disconnected"),
|
||||
.starts_with("failed to connect to exec-server websocket"),
|
||||
"unexpected write error: {write_error}"
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user