mirror of
https://github.com/openai/codex.git
synced 2026-04-23 22:24:57 +00:00
feat: move exec-server ownership (#16344)
This introduces session-scoped ownership for exec-server so ws disconnects no longer immediately kill running remote exec processes, and it prepares the protocol for reconnect-based resume. - add session_id / resume_session_id to the exec-server initialize handshake - move process ownership under a shared session registry - detach sessions on websocket disconnect and expire them after a TTL instead of killing processes immediately (we will resume based on this) - allow a new connection to resume an existing session and take over notifications/ownership - I use UUID to make them not predictable as we don't have auth for now - make detached-session expiry authoritative at resume time so teardown wins at the TTL boundary - reject long-poll process/read calls that get resumed out from under an older attachment --------- Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -2124,6 +2124,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tracing",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -40,6 +40,7 @@ tokio = { workspace = true, features = [
|
||||
] }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
uuid = { workspace = true, features = ["v4"] }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
||||
@@ -71,6 +71,7 @@ impl Default for ExecServerClientConnectOptions {
|
||||
Self {
|
||||
client_name: "codex-core".to_string(),
|
||||
initialize_timeout: INITIALIZE_TIMEOUT,
|
||||
resume_session_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -80,6 +81,7 @@ impl From<RemoteExecServerConnectArgs> for ExecServerClientConnectOptions {
|
||||
Self {
|
||||
client_name: value.client_name,
|
||||
initialize_timeout: value.initialize_timeout,
|
||||
resume_session_id: value.resume_session_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -91,6 +93,7 @@ impl RemoteExecServerConnectArgs {
|
||||
client_name,
|
||||
connect_timeout: CONNECT_TIMEOUT,
|
||||
initialize_timeout: INITIALIZE_TIMEOUT,
|
||||
resume_session_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -118,6 +121,7 @@ struct Inner {
|
||||
// need serialization so concurrent register/remove operations do not
|
||||
// overwrite each other's copy-on-write updates.
|
||||
sessions_write_lock: Mutex<()>,
|
||||
session_id: std::sync::RwLock<Option<String>>,
|
||||
reader_task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
@@ -190,14 +194,29 @@ impl ExecServerClient {
|
||||
let ExecServerClientConnectOptions {
|
||||
client_name,
|
||||
initialize_timeout,
|
||||
resume_session_id,
|
||||
} = options;
|
||||
|
||||
timeout(initialize_timeout, async {
|
||||
let response = self
|
||||
let response: InitializeResponse = self
|
||||
.inner
|
||||
.client
|
||||
.call(INITIALIZE_METHOD, &InitializeParams { client_name })
|
||||
.call(
|
||||
INITIALIZE_METHOD,
|
||||
&InitializeParams {
|
||||
client_name,
|
||||
resume_session_id,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
{
|
||||
let mut session_id = self
|
||||
.inner
|
||||
.session_id
|
||||
.write()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*session_id = Some(response.session_id.clone());
|
||||
}
|
||||
self.notify_initialized().await?;
|
||||
Ok(response)
|
||||
})
|
||||
@@ -350,6 +369,14 @@ impl ExecServerClient {
|
||||
self.inner.remove_session(process_id).await;
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> Option<String> {
|
||||
self.inner
|
||||
.session_id
|
||||
.read()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone()
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
connection: JsonRpcConnection,
|
||||
options: ExecServerClientConnectOptions,
|
||||
@@ -388,6 +415,7 @@ impl ExecServerClient {
|
||||
client: rpc_client,
|
||||
sessions: ArcSwap::from_pointee(HashMap::new()),
|
||||
sessions_write_lock: Mutex::new(()),
|
||||
session_id: std::sync::RwLock::new(None),
|
||||
reader_task,
|
||||
}
|
||||
});
|
||||
@@ -693,8 +721,10 @@ mod tests {
|
||||
&mut server_writer,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: request.id,
|
||||
result: serde_json::to_value(InitializeResponse {})
|
||||
.expect("initialize response should serialize"),
|
||||
result: serde_json::to_value(InitializeResponse {
|
||||
session_id: "session-1".to_string(),
|
||||
})
|
||||
.expect("initialize response should serialize"),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::time::Duration;
|
||||
pub struct ExecServerClientConnectOptions {
|
||||
pub client_name: String,
|
||||
pub initialize_timeout: Duration,
|
||||
pub resume_session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// WebSocket connection arguments for a remote exec-server.
|
||||
@@ -14,4 +15,5 @@ pub struct RemoteExecServerConnectArgs {
|
||||
pub client_name: String,
|
||||
pub connect_timeout: Duration,
|
||||
pub initialize_timeout: Duration,
|
||||
pub resume_session_id: Option<String>,
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use futures::StreamExt;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
@@ -28,6 +29,7 @@ pub(crate) enum JsonRpcConnectionEvent {
|
||||
pub(crate) struct JsonRpcConnection {
|
||||
outgoing_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
incoming_rx: mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
task_handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
@@ -40,9 +42,11 @@ impl JsonRpcConnection {
|
||||
{
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
let disconnected_tx_for_reader = disconnected_tx.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let mut lines = BufReader::new(reader).lines();
|
||||
loop {
|
||||
@@ -73,12 +77,18 @@ impl JsonRpcConnection {
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to read JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
@@ -96,6 +106,7 @@ impl JsonRpcConnection {
|
||||
if let Err(err) = write_jsonrpc_line_message(&mut writer, &message).await {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write JSON-RPC message to {connection_label}: {err}"
|
||||
)),
|
||||
@@ -109,6 +120,7 @@ impl JsonRpcConnection {
|
||||
Self {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
}
|
||||
}
|
||||
@@ -119,10 +131,12 @@ impl JsonRpcConnection {
|
||||
{
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
let (mut websocket_writer, mut websocket_reader) = stream.split();
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
let disconnected_tx_for_reader = disconnected_tx.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match websocket_reader.next().await {
|
||||
@@ -171,7 +185,12 @@ impl JsonRpcConnection {
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => {
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {}
|
||||
@@ -179,6 +198,7 @@ impl JsonRpcConnection {
|
||||
Some(Err(err)) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to read websocket JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
@@ -187,7 +207,12 @@ impl JsonRpcConnection {
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await;
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -202,6 +227,7 @@ impl JsonRpcConnection {
|
||||
{
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write websocket JSON-RPC message to {connection_label}: {err}"
|
||||
)),
|
||||
@@ -213,6 +239,7 @@ impl JsonRpcConnection {
|
||||
Err(err) => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to serialize JSON-RPC message for {connection_label}: {err}"
|
||||
)),
|
||||
@@ -227,6 +254,7 @@ impl JsonRpcConnection {
|
||||
Self {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
}
|
||||
}
|
||||
@@ -236,16 +264,24 @@ impl JsonRpcConnection {
|
||||
) -> (
|
||||
mpsc::Sender<JSONRPCMessage>,
|
||||
mpsc::Receiver<JsonRpcConnectionEvent>,
|
||||
watch::Receiver<bool>,
|
||||
Vec<tokio::task::JoinHandle<()>>,
|
||||
) {
|
||||
(self.outgoing_tx, self.incoming_rx, self.task_handles)
|
||||
(
|
||||
self.outgoing_tx,
|
||||
self.incoming_rx,
|
||||
self.disconnected_rx,
|
||||
self.task_handles,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_disconnected(
|
||||
incoming_tx: &mpsc::Sender<JsonRpcConnectionEvent>,
|
||||
disconnected_tx: &watch::Sender<bool>,
|
||||
reason: Option<String>,
|
||||
) {
|
||||
let _ = disconnected_tx.send(true);
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Disconnected { reason })
|
||||
.await;
|
||||
|
||||
@@ -105,18 +105,10 @@ pub struct Environment {
|
||||
|
||||
impl Default for Environment {
|
||||
fn default() -> Self {
|
||||
let local_process = LocalProcess::default();
|
||||
if let Err(err) = local_process.initialize() {
|
||||
panic!("default local process initialization should succeed: {err:?}");
|
||||
}
|
||||
if let Err(err) = local_process.initialized() {
|
||||
panic!("default local process should accept initialized notification: {err}");
|
||||
}
|
||||
|
||||
Self {
|
||||
exec_server_url: None,
|
||||
remote_exec_server_client: None,
|
||||
exec_backend: Arc::new(local_process),
|
||||
exec_backend: Arc::new(LocalProcess::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -146,6 +138,7 @@ impl Environment {
|
||||
client_name: "codex-environment".to_string(),
|
||||
connect_timeout: std::time::Duration::from_secs(5),
|
||||
initialize_timeout: std::time::Duration::from_secs(5),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await?,
|
||||
)
|
||||
@@ -153,24 +146,12 @@ impl Environment {
|
||||
None
|
||||
};
|
||||
|
||||
let exec_backend: Arc<dyn ExecBackend> = match remote_exec_server_client.clone() {
|
||||
Some(client) => Arc::new(RemoteProcess::new(client)),
|
||||
None if exec_server_url.is_some() => {
|
||||
return Err(ExecServerError::Protocol(
|
||||
"remote mode should have an exec-server client".to_string(),
|
||||
));
|
||||
}
|
||||
None => {
|
||||
let local_process = LocalProcess::default();
|
||||
local_process
|
||||
.initialize()
|
||||
.map_err(|err| ExecServerError::Protocol(err.message))?;
|
||||
local_process
|
||||
.initialized()
|
||||
.map_err(ExecServerError::Protocol)?;
|
||||
Arc::new(local_process)
|
||||
}
|
||||
};
|
||||
let exec_backend: Arc<dyn ExecBackend> =
|
||||
if let Some(client) = remote_exec_server_client.clone() {
|
||||
Arc::new(RemoteProcess::new(client))
|
||||
} else {
|
||||
Arc::new(LocalProcess::default())
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
exec_server_url,
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
@@ -26,7 +24,6 @@ use crate::protocol::ExecOutputDeltaNotification;
|
||||
use crate::protocol::ExecOutputStream;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::ExecResponse;
|
||||
use crate::protocol::InitializeResponse;
|
||||
use crate::protocol::ProcessOutputChunk;
|
||||
use crate::protocol::ReadParams;
|
||||
use crate::protocol::ReadResponse;
|
||||
@@ -74,10 +71,8 @@ enum ProcessEntry {
|
||||
}
|
||||
|
||||
struct Inner {
|
||||
notifications: RpcNotificationSender,
|
||||
notifications: std::sync::RwLock<Option<RpcNotificationSender>>,
|
||||
processes: Mutex<HashMap<ProcessId, ProcessEntry>>,
|
||||
initialize_requested: AtomicBool,
|
||||
initialized: AtomicBool,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -104,10 +99,8 @@ impl LocalProcess {
|
||||
pub(crate) fn new(notifications: RpcNotificationSender) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(Inner {
|
||||
notifications,
|
||||
notifications: std::sync::RwLock::new(Some(notifications)),
|
||||
processes: Mutex::new(HashMap::new()),
|
||||
initialize_requested: AtomicBool::new(false),
|
||||
initialized: AtomicBool::new(false),
|
||||
}),
|
||||
}
|
||||
}
|
||||
@@ -128,45 +121,19 @@ impl LocalProcess {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn initialize(&self) -> Result<InitializeResponse, JSONRPCErrorError> {
|
||||
if self.inner.initialize_requested.swap(true, Ordering::SeqCst) {
|
||||
return Err(invalid_request(
|
||||
"initialize may only be sent once per connection".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(InitializeResponse {})
|
||||
}
|
||||
|
||||
pub(crate) fn initialized(&self) -> Result<(), String> {
|
||||
if !self.inner.initialize_requested.load(Ordering::SeqCst) {
|
||||
return Err("received `initialized` notification before `initialize`".into());
|
||||
}
|
||||
self.inner.initialized.store(true, Ordering::SeqCst);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn require_initialized_for(
|
||||
&self,
|
||||
method_family: &str,
|
||||
) -> Result<(), JSONRPCErrorError> {
|
||||
if !self.inner.initialize_requested.load(Ordering::SeqCst) {
|
||||
return Err(invalid_request(format!(
|
||||
"client must call initialize before using {method_family} methods"
|
||||
)));
|
||||
}
|
||||
if !self.inner.initialized.load(Ordering::SeqCst) {
|
||||
return Err(invalid_request(format!(
|
||||
"client must send initialized before using {method_family} methods"
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
pub(crate) fn set_notification_sender(&self, notifications: Option<RpcNotificationSender>) {
|
||||
let mut notification_sender = self
|
||||
.inner
|
||||
.notifications
|
||||
.write()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*notification_sender = notifications;
|
||||
}
|
||||
|
||||
async fn start_process(
|
||||
&self,
|
||||
params: ExecParams,
|
||||
) -> Result<(ExecResponse, watch::Sender<u64>), JSONRPCErrorError> {
|
||||
self.require_initialized_for("exec")?;
|
||||
let process_id = params.process_id.clone();
|
||||
let (program, args) = params
|
||||
.argv
|
||||
@@ -277,7 +244,6 @@ impl LocalProcess {
|
||||
&self,
|
||||
params: ReadParams,
|
||||
) -> Result<ReadResponse, JSONRPCErrorError> {
|
||||
self.require_initialized_for("exec")?;
|
||||
let _process_id = params.process_id.clone();
|
||||
let after_seq = params.after_seq.unwrap_or(0);
|
||||
let max_bytes = params.max_bytes.unwrap_or(usize::MAX);
|
||||
@@ -354,7 +320,6 @@ impl LocalProcess {
|
||||
&self,
|
||||
params: WriteParams,
|
||||
) -> Result<WriteResponse, JSONRPCErrorError> {
|
||||
self.require_initialized_for("exec")?;
|
||||
let _process_id = params.process_id.clone();
|
||||
let _input_bytes = params.chunk.0.len();
|
||||
let writer_tx = {
|
||||
@@ -391,7 +356,6 @@ impl LocalProcess {
|
||||
&self,
|
||||
params: TerminateParams,
|
||||
) -> Result<TerminateResponse, JSONRPCErrorError> {
|
||||
self.require_initialized_for("exec")?;
|
||||
let _process_id = params.process_id.clone();
|
||||
let running = {
|
||||
let process_map = self.inner.processes.lock().await;
|
||||
@@ -546,13 +510,10 @@ async fn stream_output(
|
||||
}
|
||||
};
|
||||
output_notify.notify_waiters();
|
||||
if inner
|
||||
.notifications
|
||||
.notify(crate::protocol::EXEC_OUTPUT_DELTA_METHOD, ¬ification)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
if let Some(notifications) = notification_sender(&inner) {
|
||||
let _ = notifications
|
||||
.notify(crate::protocol::EXEC_OUTPUT_DELTA_METHOD, ¬ification)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -584,13 +545,11 @@ async fn watch_exit(
|
||||
};
|
||||
output_notify.notify_waiters();
|
||||
if let Some(notification) = notification
|
||||
&& inner
|
||||
.notifications
|
||||
.notify(crate::protocol::EXEC_EXITED_METHOD, ¬ification)
|
||||
.await
|
||||
.is_err()
|
||||
&& let Some(notifications) = notification_sender(&inner)
|
||||
{
|
||||
return;
|
||||
let _ = notifications
|
||||
.notify(crate::protocol::EXEC_EXITED_METHOD, ¬ification)
|
||||
.await;
|
||||
}
|
||||
|
||||
maybe_emit_closed(process_id.clone(), Arc::clone(&inner)).await;
|
||||
@@ -645,10 +604,17 @@ async fn maybe_emit_closed(process_id: ProcessId, inner: Arc<Inner>) {
|
||||
return;
|
||||
};
|
||||
|
||||
if inner
|
||||
.notifications
|
||||
.notify(EXEC_CLOSED_METHOD, ¬ification)
|
||||
.await
|
||||
.is_err()
|
||||
{}
|
||||
if let Some(notifications) = notification_sender(&inner) {
|
||||
let _ = notifications
|
||||
.notify(EXEC_CLOSED_METHOD, ¬ification)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
fn notification_sender(inner: &Inner) -> Option<RpcNotificationSender> {
|
||||
inner
|
||||
.notifications
|
||||
.read()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone()
|
||||
}
|
||||
|
||||
@@ -46,11 +46,15 @@ impl From<Vec<u8>> for ByteChunk {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeParams {
|
||||
pub client_name: String,
|
||||
#[serde(default)]
|
||||
pub resume_session_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeResponse {}
|
||||
pub struct InitializeResponse {
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
||||
@@ -179,7 +179,8 @@ pub(crate) struct RpcClient {
|
||||
|
||||
impl RpcClient {
|
||||
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
|
||||
let (write_tx, mut incoming_rx, transport_tasks) = connection.into_parts();
|
||||
let (write_tx, mut incoming_rx, _disconnected_rx, transport_tasks) =
|
||||
connection.into_parts();
|
||||
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
|
||||
let (event_tx, event_rx) = mpsc::channel(128);
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ mod handler;
|
||||
mod process_handler;
|
||||
mod processor;
|
||||
mod registry;
|
||||
mod session_registry;
|
||||
mod transport;
|
||||
|
||||
pub(crate) use handler::ExecServerHandler;
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
|
||||
use crate::protocol::ExecParams;
|
||||
@@ -16,6 +21,7 @@ use crate::protocol::FsRemoveParams;
|
||||
use crate::protocol::FsRemoveResponse;
|
||||
use crate::protocol::FsWriteFileParams;
|
||||
use crate::protocol::FsWriteFileResponse;
|
||||
use crate::protocol::InitializeParams;
|
||||
use crate::protocol::InitializeResponse;
|
||||
use crate::protocol::ReadParams;
|
||||
use crate::protocol::ReadResponse;
|
||||
@@ -24,65 +30,126 @@ use crate::protocol::TerminateResponse;
|
||||
use crate::protocol::WriteParams;
|
||||
use crate::protocol::WriteResponse;
|
||||
use crate::rpc::RpcNotificationSender;
|
||||
use crate::rpc::invalid_request;
|
||||
use crate::server::file_system_handler::FileSystemHandler;
|
||||
use crate::server::process_handler::ProcessHandler;
|
||||
use crate::server::session_registry::SessionHandle;
|
||||
use crate::server::session_registry::SessionRegistry;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ExecServerHandler {
|
||||
process: ProcessHandler,
|
||||
session_registry: Arc<SessionRegistry>,
|
||||
notifications: RpcNotificationSender,
|
||||
session: StdMutex<Option<SessionHandle>>,
|
||||
file_system: FileSystemHandler,
|
||||
initialize_requested: AtomicBool,
|
||||
initialized: AtomicBool,
|
||||
}
|
||||
|
||||
impl ExecServerHandler {
|
||||
pub(crate) fn new(notifications: RpcNotificationSender) -> Self {
|
||||
pub(crate) fn new(
|
||||
session_registry: Arc<SessionRegistry>,
|
||||
notifications: RpcNotificationSender,
|
||||
) -> Self {
|
||||
Self {
|
||||
process: ProcessHandler::new(notifications),
|
||||
session_registry,
|
||||
notifications,
|
||||
session: StdMutex::new(None),
|
||||
file_system: FileSystemHandler::default(),
|
||||
initialize_requested: AtomicBool::new(false),
|
||||
initialized: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn shutdown(&self) {
|
||||
self.process.shutdown().await;
|
||||
if let Some(session) = self.session() {
|
||||
session.detach().await;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn initialize(&self) -> Result<InitializeResponse, JSONRPCErrorError> {
|
||||
self.process.initialize()
|
||||
pub(crate) fn is_session_attached(&self) -> bool {
|
||||
self.session()
|
||||
.is_none_or(|session| session.is_session_attached())
|
||||
}
|
||||
|
||||
pub(crate) async fn initialize(
|
||||
&self,
|
||||
params: InitializeParams,
|
||||
) -> Result<InitializeResponse, JSONRPCErrorError> {
|
||||
if self.initialize_requested.swap(true, Ordering::SeqCst) {
|
||||
return Err(invalid_request(
|
||||
"initialize may only be sent once per connection".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let session = match self
|
||||
.session_registry
|
||||
.attach(params.resume_session_id.clone(), self.notifications.clone())
|
||||
.await
|
||||
{
|
||||
Ok(session) => session,
|
||||
Err(error) => {
|
||||
self.initialize_requested.store(false, Ordering::SeqCst);
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
let session_id = session.session_id().to_string();
|
||||
tracing::debug!(
|
||||
session_id,
|
||||
connection_id = %session.connection_id(),
|
||||
"exec-server session attached"
|
||||
);
|
||||
*self
|
||||
.session
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(session);
|
||||
Ok(InitializeResponse { session_id })
|
||||
}
|
||||
|
||||
pub(crate) fn initialized(&self) -> Result<(), String> {
|
||||
self.process.initialized()
|
||||
if !self.initialize_requested.load(Ordering::SeqCst) {
|
||||
return Err("received `initialized` notification before `initialize`".into());
|
||||
}
|
||||
self.require_session_attached()
|
||||
.map_err(|error| error.message)?;
|
||||
self.initialized.store(true, Ordering::SeqCst);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn exec(&self, params: ExecParams) -> Result<ExecResponse, JSONRPCErrorError> {
|
||||
self.process.exec(params).await
|
||||
let session = self.require_initialized_for("exec")?;
|
||||
session.process().exec(params).await
|
||||
}
|
||||
|
||||
pub(crate) async fn exec_read(
|
||||
&self,
|
||||
params: ReadParams,
|
||||
) -> Result<ReadResponse, JSONRPCErrorError> {
|
||||
self.process.exec_read(params).await
|
||||
let session = self.require_initialized_for("exec")?;
|
||||
let response = session.process().exec_read(params).await?;
|
||||
self.require_session_attached()?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub(crate) async fn exec_write(
|
||||
&self,
|
||||
params: WriteParams,
|
||||
) -> Result<WriteResponse, JSONRPCErrorError> {
|
||||
self.process.exec_write(params).await
|
||||
let session = self.require_initialized_for("exec")?;
|
||||
session.process().exec_write(params).await
|
||||
}
|
||||
|
||||
pub(crate) async fn terminate(
|
||||
&self,
|
||||
params: TerminateParams,
|
||||
) -> Result<TerminateResponse, JSONRPCErrorError> {
|
||||
self.process.terminate(params).await
|
||||
let session = self.require_initialized_for("exec")?;
|
||||
session.process().terminate(params).await
|
||||
}
|
||||
|
||||
pub(crate) async fn fs_read_file(
|
||||
&self,
|
||||
params: FsReadFileParams,
|
||||
) -> Result<FsReadFileResponse, JSONRPCErrorError> {
|
||||
self.process.require_initialized_for("filesystem")?;
|
||||
self.require_initialized_for("filesystem")?;
|
||||
self.file_system.read_file(params).await
|
||||
}
|
||||
|
||||
@@ -90,7 +157,7 @@ impl ExecServerHandler {
|
||||
&self,
|
||||
params: FsWriteFileParams,
|
||||
) -> Result<FsWriteFileResponse, JSONRPCErrorError> {
|
||||
self.process.require_initialized_for("filesystem")?;
|
||||
self.require_initialized_for("filesystem")?;
|
||||
self.file_system.write_file(params).await
|
||||
}
|
||||
|
||||
@@ -98,7 +165,7 @@ impl ExecServerHandler {
|
||||
&self,
|
||||
params: FsCreateDirectoryParams,
|
||||
) -> Result<FsCreateDirectoryResponse, JSONRPCErrorError> {
|
||||
self.process.require_initialized_for("filesystem")?;
|
||||
self.require_initialized_for("filesystem")?;
|
||||
self.file_system.create_directory(params).await
|
||||
}
|
||||
|
||||
@@ -106,7 +173,7 @@ impl ExecServerHandler {
|
||||
&self,
|
||||
params: FsGetMetadataParams,
|
||||
) -> Result<FsGetMetadataResponse, JSONRPCErrorError> {
|
||||
self.process.require_initialized_for("filesystem")?;
|
||||
self.require_initialized_for("filesystem")?;
|
||||
self.file_system.get_metadata(params).await
|
||||
}
|
||||
|
||||
@@ -114,7 +181,7 @@ impl ExecServerHandler {
|
||||
&self,
|
||||
params: FsReadDirectoryParams,
|
||||
) -> Result<FsReadDirectoryResponse, JSONRPCErrorError> {
|
||||
self.process.require_initialized_for("filesystem")?;
|
||||
self.require_initialized_for("filesystem")?;
|
||||
self.file_system.read_directory(params).await
|
||||
}
|
||||
|
||||
@@ -122,7 +189,7 @@ impl ExecServerHandler {
|
||||
&self,
|
||||
params: FsRemoveParams,
|
||||
) -> Result<FsRemoveResponse, JSONRPCErrorError> {
|
||||
self.process.require_initialized_for("filesystem")?;
|
||||
self.require_initialized_for("filesystem")?;
|
||||
self.file_system.remove(params).await
|
||||
}
|
||||
|
||||
@@ -130,9 +197,49 @@ impl ExecServerHandler {
|
||||
&self,
|
||||
params: FsCopyParams,
|
||||
) -> Result<FsCopyResponse, JSONRPCErrorError> {
|
||||
self.process.require_initialized_for("filesystem")?;
|
||||
self.require_initialized_for("filesystem")?;
|
||||
self.file_system.copy(params).await
|
||||
}
|
||||
|
||||
fn require_initialized_for(
|
||||
&self,
|
||||
method_family: &str,
|
||||
) -> Result<SessionHandle, JSONRPCErrorError> {
|
||||
if !self.initialize_requested.load(Ordering::SeqCst) {
|
||||
return Err(invalid_request(format!(
|
||||
"client must call initialize before using {method_family} methods"
|
||||
)));
|
||||
}
|
||||
let session = self.require_session_attached()?;
|
||||
if !self.initialized.load(Ordering::SeqCst) {
|
||||
return Err(invalid_request(format!(
|
||||
"client must send initialized before using {method_family} methods"
|
||||
)));
|
||||
}
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
fn require_session_attached(&self) -> Result<SessionHandle, JSONRPCErrorError> {
|
||||
let Some(session) = self.session() else {
|
||||
return Err(invalid_request(
|
||||
"client must call initialize before using methods".to_string(),
|
||||
));
|
||||
};
|
||||
if session.is_session_attached() {
|
||||
return Ok(session);
|
||||
}
|
||||
|
||||
Err(invalid_request(
|
||||
"session has been resumed by another connection".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn session(&self) -> Option<SessionHandle> {
|
||||
self.session
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -4,43 +4,81 @@ use std::time::Duration;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::ExecServerHandler;
|
||||
use crate::ProcessId;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::InitializeResponse;
|
||||
use crate::protocol::InitializeParams;
|
||||
use crate::protocol::ReadParams;
|
||||
use crate::protocol::ReadResponse;
|
||||
use crate::protocol::TerminateParams;
|
||||
use crate::protocol::TerminateResponse;
|
||||
use crate::rpc::RpcNotificationSender;
|
||||
use crate::server::session_registry::SessionRegistry;
|
||||
|
||||
fn exec_params(process_id: &str) -> ExecParams {
|
||||
let mut env = HashMap::new();
|
||||
if let Some(path) = std::env::var_os("PATH") {
|
||||
env.insert("PATH".to_string(), path.to_string_lossy().into_owned());
|
||||
}
|
||||
exec_params_with_argv(process_id, sleep_argv())
|
||||
}
|
||||
|
||||
fn exec_params_with_argv(process_id: &str, argv: Vec<String>) -> ExecParams {
|
||||
ExecParams {
|
||||
process_id: ProcessId::from(process_id),
|
||||
argv: vec![
|
||||
"bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"sleep 0.1".to_string(),
|
||||
],
|
||||
argv,
|
||||
cwd: std::env::current_dir().expect("cwd"),
|
||||
env,
|
||||
env: inherited_path_env(),
|
||||
tty: false,
|
||||
arg0: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn inherited_path_env() -> HashMap<String, String> {
|
||||
let mut env = HashMap::new();
|
||||
if let Some(path) = std::env::var_os("PATH") {
|
||||
env.insert("PATH".to_string(), path.to_string_lossy().into_owned());
|
||||
}
|
||||
env
|
||||
}
|
||||
|
||||
fn sleep_argv() -> Vec<String> {
|
||||
shell_argv("sleep 0.1", "ping -n 2 127.0.0.1 >NUL")
|
||||
}
|
||||
|
||||
fn shell_argv(unix_script: &str, windows_script: &str) -> Vec<String> {
|
||||
if cfg!(windows) {
|
||||
vec![
|
||||
windows_command_processor(),
|
||||
"/C".to_string(),
|
||||
windows_script.to_string(),
|
||||
]
|
||||
} else {
|
||||
vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
unix_script.to_string(),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
fn windows_command_processor() -> String {
|
||||
std::env::var("COMSPEC").unwrap_or_else(|_| "cmd.exe".to_string())
|
||||
}
|
||||
|
||||
async fn initialized_handler() -> Arc<ExecServerHandler> {
|
||||
let (outgoing_tx, _outgoing_rx) = mpsc::channel(16);
|
||||
let handler = Arc::new(ExecServerHandler::new(RpcNotificationSender::new(
|
||||
outgoing_tx,
|
||||
)));
|
||||
assert_eq!(
|
||||
handler.initialize().expect("initialize"),
|
||||
InitializeResponse {}
|
||||
);
|
||||
let registry = SessionRegistry::new();
|
||||
let handler = Arc::new(ExecServerHandler::new(
|
||||
registry,
|
||||
RpcNotificationSender::new(outgoing_tx),
|
||||
));
|
||||
let initialize_response = handler
|
||||
.initialize(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
.expect("initialize");
|
||||
Uuid::parse_str(&initialize_response.session_id).expect("session id should be a UUID");
|
||||
handler.initialized().expect("initialized");
|
||||
handler
|
||||
}
|
||||
@@ -101,3 +139,197 @@ async fn terminate_reports_false_after_process_exit() {
|
||||
|
||||
handler.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn long_poll_read_fails_after_session_resume() {
|
||||
let (first_tx, _first_rx) = mpsc::channel(16);
|
||||
let registry = SessionRegistry::new();
|
||||
let first_handler = Arc::new(ExecServerHandler::new(
|
||||
Arc::clone(®istry),
|
||||
RpcNotificationSender::new(first_tx),
|
||||
));
|
||||
let initialize_response = first_handler
|
||||
.initialize(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
.expect("initialize");
|
||||
first_handler.initialized().expect("initialized");
|
||||
|
||||
first_handler
|
||||
.exec(exec_params_with_argv(
|
||||
"proc-long-poll",
|
||||
shell_argv(
|
||||
"sleep 0.1; printf resumed",
|
||||
"ping -n 2 127.0.0.1 >NUL && echo resumed",
|
||||
),
|
||||
))
|
||||
.await
|
||||
.expect("start process");
|
||||
|
||||
let first_read_handler = Arc::clone(&first_handler);
|
||||
let read_task = tokio::spawn(async move {
|
||||
first_read_handler
|
||||
.exec_read(ReadParams {
|
||||
process_id: ProcessId::from("proc-long-poll"),
|
||||
after_seq: None,
|
||||
max_bytes: None,
|
||||
wait_ms: Some(500),
|
||||
})
|
||||
.await
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
first_handler.shutdown().await;
|
||||
|
||||
let (second_tx, _second_rx) = mpsc::channel(16);
|
||||
let second_handler = Arc::new(ExecServerHandler::new(
|
||||
registry,
|
||||
RpcNotificationSender::new(second_tx),
|
||||
));
|
||||
second_handler
|
||||
.initialize(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: Some(initialize_response.session_id),
|
||||
})
|
||||
.await
|
||||
.expect("initialize second connection");
|
||||
second_handler
|
||||
.initialized()
|
||||
.expect("initialized second connection");
|
||||
|
||||
let err = read_task
|
||||
.await
|
||||
.expect("read task should join")
|
||||
.expect_err("evicted long-poll read should fail");
|
||||
assert_eq!(err.code, -32600);
|
||||
assert_eq!(
|
||||
err.message,
|
||||
"session has been resumed by another connection"
|
||||
);
|
||||
|
||||
second_handler.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn active_session_resume_is_rejected() {
|
||||
let (first_tx, _first_rx) = mpsc::channel(16);
|
||||
let registry = SessionRegistry::new();
|
||||
let first_handler = Arc::new(ExecServerHandler::new(
|
||||
Arc::clone(®istry),
|
||||
RpcNotificationSender::new(first_tx),
|
||||
));
|
||||
let initialize_response = first_handler
|
||||
.initialize(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
.expect("initialize");
|
||||
|
||||
let (second_tx, _second_rx) = mpsc::channel(16);
|
||||
let second_handler = Arc::new(ExecServerHandler::new(
|
||||
registry,
|
||||
RpcNotificationSender::new(second_tx),
|
||||
));
|
||||
let err = second_handler
|
||||
.initialize(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: Some(initialize_response.session_id.clone()),
|
||||
})
|
||||
.await
|
||||
.expect_err("active session resume should fail");
|
||||
|
||||
assert_eq!(err.code, -32600);
|
||||
assert_eq!(
|
||||
err.message,
|
||||
format!(
|
||||
"session {} is already attached to another connection",
|
||||
initialize_response.session_id
|
||||
)
|
||||
);
|
||||
|
||||
first_handler.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn output_and_exit_are_retained_after_notification_receiver_closes() {
|
||||
let (outgoing_tx, outgoing_rx) = mpsc::channel(16);
|
||||
let handler = Arc::new(ExecServerHandler::new(
|
||||
SessionRegistry::new(),
|
||||
RpcNotificationSender::new(outgoing_tx),
|
||||
));
|
||||
handler
|
||||
.initialize(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})
|
||||
.await
|
||||
.expect("initialize");
|
||||
handler.initialized().expect("initialized");
|
||||
|
||||
let process_id = ProcessId::from("proc-notification-fail");
|
||||
handler
|
||||
.exec(exec_params_with_argv(
|
||||
process_id.as_str(),
|
||||
shell_argv(
|
||||
"sleep 0.05; printf 'first\\n'; sleep 0.05; printf 'second\\n'",
|
||||
"echo first && ping -n 2 127.0.0.1 >NUL && echo second",
|
||||
),
|
||||
))
|
||||
.await
|
||||
.expect("start process");
|
||||
|
||||
drop(outgoing_rx);
|
||||
|
||||
let (output, exit_code) = read_process_until_closed(&handler, process_id.clone()).await;
|
||||
assert_eq!(output.replace("\r\n", "\n"), "first\nsecond\n");
|
||||
assert_eq!(exit_code, Some(0));
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
handler
|
||||
.exec(exec_params(process_id.as_str()))
|
||||
.await
|
||||
.expect("process id should be reusable after exit retention");
|
||||
|
||||
handler.shutdown().await;
|
||||
}
|
||||
|
||||
async fn read_process_until_closed(
|
||||
handler: &ExecServerHandler,
|
||||
process_id: ProcessId,
|
||||
) -> (String, Option<i32>) {
|
||||
let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
|
||||
let mut output = String::new();
|
||||
let mut exit_code = None;
|
||||
let mut after_seq = None;
|
||||
|
||||
loop {
|
||||
let response: ReadResponse = handler
|
||||
.exec_read(ReadParams {
|
||||
process_id: process_id.clone(),
|
||||
after_seq,
|
||||
max_bytes: None,
|
||||
wait_ms: Some(500),
|
||||
})
|
||||
.await
|
||||
.expect("read process");
|
||||
|
||||
for chunk in response.chunks {
|
||||
output.push_str(&String::from_utf8_lossy(&chunk.chunk.into_inner()));
|
||||
after_seq = Some(chunk.seq);
|
||||
}
|
||||
if response.exited {
|
||||
exit_code = response.exit_code;
|
||||
}
|
||||
if response.closed {
|
||||
return (output, exit_code);
|
||||
}
|
||||
after_seq = response.next_seq.checked_sub(1).or(after_seq);
|
||||
assert!(
|
||||
tokio::time::Instant::now() < deadline,
|
||||
"process should close within 2s"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use crate::local_process::LocalProcess;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::ExecResponse;
|
||||
use crate::protocol::InitializeResponse;
|
||||
use crate::protocol::ReadParams;
|
||||
use crate::protocol::ReadResponse;
|
||||
use crate::protocol::TerminateParams;
|
||||
@@ -28,19 +27,8 @@ impl ProcessHandler {
|
||||
self.process.shutdown().await;
|
||||
}
|
||||
|
||||
pub(crate) fn initialize(&self) -> Result<InitializeResponse, JSONRPCErrorError> {
|
||||
self.process.initialize()
|
||||
}
|
||||
|
||||
pub(crate) fn initialized(&self) -> Result<(), String> {
|
||||
self.process.initialized()
|
||||
}
|
||||
|
||||
pub(crate) fn require_initialized_for(
|
||||
&self,
|
||||
method_family: &str,
|
||||
) -> Result<(), JSONRPCErrorError> {
|
||||
self.process.require_initialized_for(method_family)
|
||||
pub(crate) fn set_notification_sender(&self, notifications: Option<RpcNotificationSender>) {
|
||||
self.process.set_notification_sender(notifications);
|
||||
}
|
||||
|
||||
pub(crate) async fn exec(&self, params: ExecParams) -> Result<ExecResponse, JSONRPCErrorError> {
|
||||
|
||||
@@ -14,14 +14,33 @@ use crate::rpc::invalid_request;
|
||||
use crate::rpc::method_not_found;
|
||||
use crate::server::ExecServerHandler;
|
||||
use crate::server::registry::build_router;
|
||||
use crate::server::session_registry::SessionRegistry;
|
||||
|
||||
pub(crate) async fn run_connection(connection: JsonRpcConnection) {
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ConnectionProcessor {
|
||||
session_registry: Arc<SessionRegistry>,
|
||||
}
|
||||
|
||||
impl ConnectionProcessor {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
session_registry: SessionRegistry::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn run_connection(&self, connection: JsonRpcConnection) {
|
||||
run_connection(connection, Arc::clone(&self.session_registry)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_connection(connection: JsonRpcConnection, session_registry: Arc<SessionRegistry>) {
|
||||
let router = Arc::new(build_router());
|
||||
let (json_outgoing_tx, mut incoming_rx, connection_tasks) = connection.into_parts();
|
||||
let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) =
|
||||
connection.into_parts();
|
||||
let (outgoing_tx, mut outgoing_rx) =
|
||||
mpsc::channel::<RpcServerOutboundMessage>(CHANNEL_CAPACITY);
|
||||
let notifications = RpcNotificationSender::new(outgoing_tx.clone());
|
||||
let handler = Arc::new(ExecServerHandler::new(notifications));
|
||||
let handler = Arc::new(ExecServerHandler::new(session_registry, notifications));
|
||||
|
||||
let outbound_task = tokio::spawn(async move {
|
||||
while let Some(message) = outgoing_rx.recv().await {
|
||||
@@ -40,6 +59,10 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) {
|
||||
|
||||
// Process inbound events sequentially to preserve initialize/initialized ordering.
|
||||
while let Some(event) = incoming_rx.recv().await {
|
||||
if !handler.is_session_attached() {
|
||||
debug!("exec-server connection evicted after session resume");
|
||||
break;
|
||||
}
|
||||
match event {
|
||||
JsonRpcConnectionEvent::MalformedMessage { reason } => {
|
||||
warn!("ignoring malformed exec-server message: {reason}");
|
||||
@@ -57,7 +80,13 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) {
|
||||
JsonRpcConnectionEvent::Message(message) => match message {
|
||||
codex_app_server_protocol::JSONRPCMessage::Request(request) => {
|
||||
if let Some(route) = router.request_route(request.method.as_str()) {
|
||||
let message = route(handler.clone(), request).await;
|
||||
let message = tokio::select! {
|
||||
message = route(Arc::clone(&handler), request) => message,
|
||||
_ = disconnected_rx.changed() => {
|
||||
debug!("exec-server transport disconnected while handling request");
|
||||
break;
|
||||
}
|
||||
};
|
||||
if outgoing_tx.send(message).await.is_err() {
|
||||
break;
|
||||
}
|
||||
@@ -84,7 +113,16 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) {
|
||||
);
|
||||
break;
|
||||
};
|
||||
if let Err(err) = route(handler.clone(), notification).await {
|
||||
let result = tokio::select! {
|
||||
result = route(Arc::clone(&handler), notification) => result,
|
||||
_ = disconnected_rx.changed() => {
|
||||
debug!(
|
||||
"exec-server transport disconnected while handling notification"
|
||||
);
|
||||
break;
|
||||
}
|
||||
};
|
||||
if let Err(err) = result {
|
||||
warn!("closing exec-server connection after protocol error: {err}");
|
||||
break;
|
||||
}
|
||||
@@ -114,6 +152,7 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) {
|
||||
}
|
||||
|
||||
handler.shutdown().await;
|
||||
drop(handler);
|
||||
drop(outgoing_tx);
|
||||
for task in connection_tasks {
|
||||
task.abort();
|
||||
@@ -121,3 +160,230 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) {
|
||||
}
|
||||
let _ = outbound_task.await;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio::io::Lines;
|
||||
use tokio::io::duplex;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::run_connection;
|
||||
use crate::ProcessId;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::protocol::EXEC_METHOD;
|
||||
use crate::protocol::EXEC_READ_METHOD;
|
||||
use crate::protocol::EXEC_TERMINATE_METHOD;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::ExecResponse;
|
||||
use crate::protocol::INITIALIZE_METHOD;
|
||||
use crate::protocol::INITIALIZED_METHOD;
|
||||
use crate::protocol::InitializeParams;
|
||||
use crate::protocol::InitializeResponse;
|
||||
use crate::protocol::ReadParams;
|
||||
use crate::protocol::TerminateParams;
|
||||
use crate::protocol::TerminateResponse;
|
||||
use crate::server::session_registry::SessionRegistry;
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_disconnect_detaches_session_during_in_flight_read() {
|
||||
let registry = SessionRegistry::new();
|
||||
let (mut first_writer, mut first_lines, first_task) =
|
||||
spawn_test_connection(Arc::clone(®istry), "first");
|
||||
|
||||
send_request(
|
||||
&mut first_writer,
|
||||
/*id*/ 1,
|
||||
INITIALIZE_METHOD,
|
||||
&InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let initialize_response: InitializeResponse =
|
||||
read_response(&mut first_lines, /*expected_id*/ 1).await;
|
||||
send_notification(&mut first_writer, INITIALIZED_METHOD, &()).await;
|
||||
|
||||
let process_id = ProcessId::from("proc-long-poll");
|
||||
send_request(
|
||||
&mut first_writer,
|
||||
/*id*/ 2,
|
||||
EXEC_METHOD,
|
||||
&exec_params(process_id.clone()),
|
||||
)
|
||||
.await;
|
||||
let _: ExecResponse = read_response(&mut first_lines, /*expected_id*/ 2).await;
|
||||
|
||||
send_request(
|
||||
&mut first_writer,
|
||||
/*id*/ 3,
|
||||
EXEC_READ_METHOD,
|
||||
&ReadParams {
|
||||
process_id: process_id.clone(),
|
||||
after_seq: None,
|
||||
max_bytes: None,
|
||||
wait_ms: Some(5_000),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
drop(first_writer);
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
|
||||
let (mut second_writer, mut second_lines, second_task) =
|
||||
spawn_test_connection(Arc::clone(®istry), "second");
|
||||
send_request(
|
||||
&mut second_writer,
|
||||
/*id*/ 1,
|
||||
INITIALIZE_METHOD,
|
||||
&InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: Some(initialize_response.session_id.clone()),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
let second_initialize_response = timeout(
|
||||
Duration::from_secs(1),
|
||||
read_response::<InitializeResponse>(&mut second_lines, /*expected_id*/ 1),
|
||||
)
|
||||
.await
|
||||
.expect("resume initialize should not wait for the old read to finish");
|
||||
assert_eq!(
|
||||
second_initialize_response.session_id,
|
||||
initialize_response.session_id
|
||||
);
|
||||
timeout(Duration::from_secs(1), first_task)
|
||||
.await
|
||||
.expect("first processor should exit")
|
||||
.expect("first processor should join");
|
||||
send_notification(&mut second_writer, INITIALIZED_METHOD, &()).await;
|
||||
|
||||
send_request(
|
||||
&mut second_writer,
|
||||
/*id*/ 2,
|
||||
EXEC_TERMINATE_METHOD,
|
||||
&TerminateParams { process_id },
|
||||
)
|
||||
.await;
|
||||
let _: TerminateResponse = read_response(&mut second_lines, /*expected_id*/ 2).await;
|
||||
|
||||
drop(second_writer);
|
||||
drop(second_lines);
|
||||
timeout(Duration::from_secs(1), second_task)
|
||||
.await
|
||||
.expect("second processor should exit")
|
||||
.expect("second processor should join");
|
||||
}
|
||||
|
||||
fn spawn_test_connection(
|
||||
registry: Arc<SessionRegistry>,
|
||||
label: &str,
|
||||
) -> (DuplexStream, Lines<BufReader<DuplexStream>>, JoinHandle<()>) {
|
||||
let (client_writer, server_reader) = duplex(1 << 20);
|
||||
let (server_writer, client_reader) = duplex(1 << 20);
|
||||
let connection =
|
||||
JsonRpcConnection::from_stdio(server_reader, server_writer, label.to_string());
|
||||
let task = tokio::spawn(run_connection(connection, registry));
|
||||
(client_writer, BufReader::new(client_reader).lines(), task)
|
||||
}
|
||||
|
||||
async fn send_request<P: Serialize>(
|
||||
writer: &mut DuplexStream,
|
||||
id: i64,
|
||||
method: &str,
|
||||
params: &P,
|
||||
) {
|
||||
write_message(
|
||||
writer,
|
||||
&JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(id),
|
||||
method: method.to_string(),
|
||||
params: Some(serde_json::to_value(params).expect("serialize params")),
|
||||
trace: None,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn send_notification<P: Serialize>(writer: &mut DuplexStream, method: &str, params: &P) {
|
||||
write_message(
|
||||
writer,
|
||||
&JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: method.to_string(),
|
||||
params: Some(serde_json::to_value(params).expect("serialize params")),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn write_message(writer: &mut DuplexStream, message: &JSONRPCMessage) {
|
||||
let encoded = serde_json::to_vec(message).expect("serialize JSON-RPC message");
|
||||
writer.write_all(&encoded).await.expect("write request");
|
||||
writer.write_all(b"\n").await.expect("write newline");
|
||||
}
|
||||
|
||||
async fn read_response<T: DeserializeOwned>(
|
||||
lines: &mut Lines<BufReader<DuplexStream>>,
|
||||
expected_id: i64,
|
||||
) -> T {
|
||||
let line = lines
|
||||
.next_line()
|
||||
.await
|
||||
.expect("read response")
|
||||
.expect("response line");
|
||||
match serde_json::from_str::<JSONRPCMessage>(&line).expect("decode JSON-RPC response") {
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, result }) => {
|
||||
assert_eq!(id, RequestId::Integer(expected_id));
|
||||
serde_json::from_value(result).expect("decode response result")
|
||||
}
|
||||
JSONRPCMessage::Error(error) => panic!("unexpected JSON-RPC error: {error:?}"),
|
||||
other => panic!("expected JSON-RPC response, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn exec_params(process_id: ProcessId) -> ExecParams {
|
||||
let mut env = HashMap::new();
|
||||
if let Some(path) = std::env::var_os("PATH") {
|
||||
env.insert("PATH".to_string(), path.to_string_lossy().into_owned());
|
||||
}
|
||||
ExecParams {
|
||||
process_id,
|
||||
argv: sleep_then_print_argv(),
|
||||
cwd: std::env::current_dir().expect("cwd"),
|
||||
env,
|
||||
tty: false,
|
||||
arg0: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn sleep_then_print_argv() -> Vec<String> {
|
||||
if cfg!(windows) {
|
||||
vec![
|
||||
std::env::var("COMSPEC").unwrap_or_else(|_| "cmd.exe".to_string()),
|
||||
"/C".to_string(),
|
||||
"ping -n 3 127.0.0.1 >NUL && echo late".to_string(),
|
||||
]
|
||||
} else {
|
||||
vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
"sleep 1; printf late".to_string(),
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,18 +30,18 @@ use crate::server::ExecServerHandler;
|
||||
|
||||
pub(crate) fn build_router() -> RpcRouter<ExecServerHandler> {
|
||||
let mut router = RpcRouter::new();
|
||||
router.request(
|
||||
INITIALIZE_METHOD,
|
||||
|handler: Arc<ExecServerHandler>, _params: InitializeParams| async move {
|
||||
handler.initialize()
|
||||
},
|
||||
);
|
||||
router.notification(
|
||||
INITIALIZED_METHOD,
|
||||
|handler: Arc<ExecServerHandler>, _params: serde_json::Value| async move {
|
||||
handler.initialized()
|
||||
},
|
||||
);
|
||||
router.request(
|
||||
INITIALIZE_METHOD,
|
||||
|handler: Arc<ExecServerHandler>, params: InitializeParams| async move {
|
||||
handler.initialize(params).await
|
||||
},
|
||||
);
|
||||
router.request(
|
||||
EXEC_METHOD,
|
||||
|handler: Arc<ExecServerHandler>, params: ExecParams| async move { handler.exec(params).await },
|
||||
|
||||
259
codex-rs/exec-server/src/server/session_registry.rs
Normal file
259
codex-rs/exec-server/src/server/session_registry.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::rpc::RpcNotificationSender;
|
||||
use crate::rpc::invalid_request;
|
||||
use crate::server::process_handler::ProcessHandler;
|
||||
|
||||
#[cfg(test)]
|
||||
const DETACHED_SESSION_TTL: Duration = Duration::from_millis(200);
|
||||
#[cfg(not(test))]
|
||||
const DETACHED_SESSION_TTL: Duration = Duration::from_secs(10);
|
||||
|
||||
pub(crate) struct SessionRegistry {
|
||||
sessions: Mutex<HashMap<String, Arc<SessionEntry>>>,
|
||||
}
|
||||
|
||||
struct SessionEntry {
|
||||
session_id: String,
|
||||
process: ProcessHandler,
|
||||
attachment: StdMutex<AttachmentState>,
|
||||
}
|
||||
|
||||
struct AttachmentState {
|
||||
current_connection_id: Option<ConnectionId>,
|
||||
detached_connection_id: Option<ConnectionId>,
|
||||
detached_expires_at: Option<tokio::time::Instant>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
struct ConnectionId(Uuid);
|
||||
|
||||
impl std::fmt::Display for ConnectionId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SessionHandle {
|
||||
registry: Arc<SessionRegistry>,
|
||||
entry: Arc<SessionEntry>,
|
||||
connection_id: ConnectionId,
|
||||
}
|
||||
|
||||
impl SessionRegistry {
|
||||
pub(crate) fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn attach(
|
||||
self: &Arc<Self>,
|
||||
resume_session_id: Option<String>,
|
||||
notifications: RpcNotificationSender,
|
||||
) -> Result<SessionHandle, JSONRPCErrorError> {
|
||||
enum AttachOutcome {
|
||||
Attached(Arc<SessionEntry>),
|
||||
Expired {
|
||||
session_id: String,
|
||||
entry: Arc<SessionEntry>,
|
||||
},
|
||||
}
|
||||
|
||||
let connection_id = ConnectionId(Uuid::new_v4());
|
||||
let outcome = {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
if let Some(session_id) = resume_session_id {
|
||||
let entry = sessions
|
||||
.get(&session_id)
|
||||
.cloned()
|
||||
.ok_or_else(|| invalid_request(format!("unknown session id {session_id}")))?;
|
||||
if entry.is_expired(tokio::time::Instant::now()) {
|
||||
let entry = sessions.remove(&session_id).ok_or_else(|| {
|
||||
invalid_request(format!("unknown session id {session_id}"))
|
||||
})?;
|
||||
Ok(AttachOutcome::Expired { session_id, entry })
|
||||
} else if entry.has_active_connection() {
|
||||
Err(invalid_request(format!(
|
||||
"session {session_id} is already attached to another connection"
|
||||
)))
|
||||
} else {
|
||||
entry.process.set_notification_sender(Some(notifications));
|
||||
entry.attach(connection_id);
|
||||
Ok(AttachOutcome::Attached(entry))
|
||||
}
|
||||
} else {
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
let entry = Arc::new(SessionEntry::new(
|
||||
session_id.clone(),
|
||||
ProcessHandler::new(notifications),
|
||||
connection_id,
|
||||
));
|
||||
sessions.insert(session_id, Arc::clone(&entry));
|
||||
Ok(AttachOutcome::Attached(entry))
|
||||
}
|
||||
};
|
||||
let entry = match outcome? {
|
||||
AttachOutcome::Attached(entry) => entry,
|
||||
AttachOutcome::Expired { session_id, entry } => {
|
||||
entry.process.shutdown().await;
|
||||
return Err(invalid_request(format!("unknown session id {session_id}")));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(SessionHandle {
|
||||
registry: Arc::clone(self),
|
||||
entry,
|
||||
connection_id,
|
||||
})
|
||||
}
|
||||
|
||||
async fn expire_if_detached(&self, session_id: String, connection_id: ConnectionId) {
|
||||
tokio::time::sleep(DETACHED_SESSION_TTL).await;
|
||||
|
||||
let removed = {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
let Some(entry) = sessions.get(&session_id) else {
|
||||
return;
|
||||
};
|
||||
if !entry.is_detached_connection_expired(connection_id, tokio::time::Instant::now()) {
|
||||
return;
|
||||
}
|
||||
sessions.remove(&session_id)
|
||||
};
|
||||
|
||||
if let Some(entry) = removed {
|
||||
entry.process.shutdown().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SessionRegistry {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SessionEntry {
|
||||
fn new(session_id: String, process: ProcessHandler, connection_id: ConnectionId) -> Self {
|
||||
Self {
|
||||
session_id,
|
||||
process,
|
||||
attachment: StdMutex::new(AttachmentState {
|
||||
current_connection_id: Some(connection_id),
|
||||
detached_connection_id: None,
|
||||
detached_expires_at: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn attach(&self, connection_id: ConnectionId) {
|
||||
let mut attachment = self
|
||||
.attachment
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
attachment.current_connection_id = Some(connection_id);
|
||||
attachment.detached_connection_id = None;
|
||||
attachment.detached_expires_at = None;
|
||||
}
|
||||
|
||||
fn detach(&self, connection_id: ConnectionId) -> bool {
|
||||
let mut attachment = self
|
||||
.attachment
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
if attachment.current_connection_id != Some(connection_id) {
|
||||
return false;
|
||||
}
|
||||
|
||||
attachment.current_connection_id = None;
|
||||
attachment.detached_connection_id = Some(connection_id);
|
||||
attachment.detached_expires_at = Some(tokio::time::Instant::now() + DETACHED_SESSION_TTL);
|
||||
true
|
||||
}
|
||||
|
||||
fn has_active_connection(&self) -> bool {
|
||||
self.attachment
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.current_connection_id
|
||||
.is_some()
|
||||
}
|
||||
|
||||
fn is_attached_to(&self, connection_id: ConnectionId) -> bool {
|
||||
self.attachment
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.current_connection_id
|
||||
== Some(connection_id)
|
||||
}
|
||||
|
||||
fn is_expired(&self, now: tokio::time::Instant) -> bool {
|
||||
self.attachment
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.detached_expires_at
|
||||
.is_some_and(|deadline| now >= deadline)
|
||||
}
|
||||
|
||||
fn is_detached_connection_expired(
|
||||
&self,
|
||||
connection_id: ConnectionId,
|
||||
now: tokio::time::Instant,
|
||||
) -> bool {
|
||||
let attachment = self
|
||||
.attachment
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
attachment.current_connection_id.is_none()
|
||||
&& attachment.detached_connection_id == Some(connection_id)
|
||||
&& attachment
|
||||
.detached_expires_at
|
||||
.is_some_and(|deadline| now >= deadline)
|
||||
}
|
||||
}
|
||||
|
||||
impl SessionHandle {
|
||||
pub(crate) fn session_id(&self) -> &str {
|
||||
&self.entry.session_id
|
||||
}
|
||||
|
||||
pub(crate) fn connection_id(&self) -> String {
|
||||
self.connection_id.to_string()
|
||||
}
|
||||
|
||||
pub(crate) fn is_session_attached(&self) -> bool {
|
||||
self.entry.is_attached_to(self.connection_id)
|
||||
}
|
||||
|
||||
pub(crate) fn process(&self) -> &ProcessHandler {
|
||||
&self.entry.process
|
||||
}
|
||||
|
||||
pub(crate) async fn detach(&self) {
|
||||
if !self.entry.detach(self.connection_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
self.entry
|
||||
.process
|
||||
.set_notification_sender(/*notifications*/ None);
|
||||
|
||||
let registry = Arc::clone(&self.registry);
|
||||
let session_id = self.entry.session_id.clone();
|
||||
let connection_id = self.connection_id;
|
||||
tokio::spawn(async move {
|
||||
registry.expire_if_detached(session_id, connection_id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@ use tokio_tungstenite::accept_async;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::server::processor::run_connection;
|
||||
use crate::server::processor::ConnectionProcessor;
|
||||
|
||||
pub const DEFAULT_LISTEN_URL: &str = "ws://127.0.0.1:0";
|
||||
|
||||
@@ -58,19 +58,22 @@ async fn run_websocket_listener(
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let local_addr = listener.local_addr()?;
|
||||
let processor = ConnectionProcessor::new();
|
||||
tracing::info!("codex-exec-server listening on ws://{local_addr}");
|
||||
println!("ws://{local_addr}");
|
||||
|
||||
loop {
|
||||
let (stream, peer_addr) = listener.accept().await?;
|
||||
let processor = processor.clone();
|
||||
tokio::spawn(async move {
|
||||
match accept_async(stream).await {
|
||||
Ok(websocket) => {
|
||||
run_connection(JsonRpcConnection::from_websocket(
|
||||
websocket,
|
||||
format!("exec-server websocket {peer_addr}"),
|
||||
))
|
||||
.await;
|
||||
processor
|
||||
.run_connection(JsonRpcConnection::from_websocket(
|
||||
websocket,
|
||||
format!("exec-server websocket {peer_addr}"),
|
||||
))
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
|
||||
@@ -64,6 +64,17 @@ impl ExecServerHarness {
|
||||
&self.websocket_url
|
||||
}
|
||||
|
||||
pub(crate) async fn disconnect_websocket(&mut self) -> anyhow::Result<()> {
|
||||
self.websocket.close(None).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn reconnect_websocket(&mut self) -> anyhow::Result<()> {
|
||||
let (websocket, _) = connect_websocket_when_ready(&self.websocket_url).await?;
|
||||
self.websocket = websocket;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn send_request(
|
||||
&mut self,
|
||||
method: &str,
|
||||
|
||||
@@ -8,6 +8,7 @@ use codex_exec_server::InitializeParams;
|
||||
use codex_exec_server::InitializeResponse;
|
||||
use common::exec_server::exec_server;
|
||||
use pretty_assertions::assert_eq;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_server_accepts_initialize() -> anyhow::Result<()> {
|
||||
@@ -17,6 +18,7 @@ async fn exec_server_accepts_initialize() -> anyhow::Result<()> {
|
||||
"initialize",
|
||||
serde_json::to_value(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
@@ -27,7 +29,7 @@ async fn exec_server_accepts_initialize() -> anyhow::Result<()> {
|
||||
};
|
||||
assert_eq!(id, initialize_id);
|
||||
let initialize_response: InitializeResponse = serde_json::from_value(result)?;
|
||||
assert_eq!(initialize_response, InitializeResponse {});
|
||||
Uuid::parse_str(&initialize_response.session_id)?;
|
||||
|
||||
server.shutdown().await?;
|
||||
Ok(())
|
||||
|
||||
@@ -6,7 +6,10 @@ use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_exec_server::ExecResponse;
|
||||
use codex_exec_server::InitializeParams;
|
||||
use codex_exec_server::InitializeResponse;
|
||||
use codex_exec_server::ProcessId;
|
||||
use codex_exec_server::ReadResponse;
|
||||
use codex_exec_server::TerminateResponse;
|
||||
use common::exec_server::exec_server;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
@@ -18,6 +21,7 @@ async fn exec_server_starts_process_over_websocket() -> anyhow::Result<()> {
|
||||
"initialize",
|
||||
serde_json::to_value(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
@@ -70,3 +74,137 @@ async fn exec_server_starts_process_over_websocket() -> anyhow::Result<()> {
|
||||
server.shutdown().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_server_resumes_detached_session_without_killing_processes() -> anyhow::Result<()> {
|
||||
let mut server = exec_server().await?;
|
||||
let initialize_id = server
|
||||
.send_request(
|
||||
"initialize",
|
||||
serde_json::to_value(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
let response = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &initialize_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else {
|
||||
panic!("expected initialize response");
|
||||
};
|
||||
let initialize_response: InitializeResponse = serde_json::from_value(result)?;
|
||||
|
||||
server
|
||||
.send_notification("initialized", serde_json::json!({}))
|
||||
.await?;
|
||||
|
||||
let process_start_id = server
|
||||
.send_request(
|
||||
"process/start",
|
||||
serde_json::json!({
|
||||
"processId": "proc-resume",
|
||||
"argv": ["/bin/sh", "-c", "sleep 5"],
|
||||
"cwd": std::env::current_dir()?,
|
||||
"env": {},
|
||||
"tty": false,
|
||||
"arg0": null
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let _ = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &process_start_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
server.disconnect_websocket().await?;
|
||||
server.reconnect_websocket().await?;
|
||||
|
||||
let resume_initialize_id = server
|
||||
.send_request(
|
||||
"initialize",
|
||||
serde_json::to_value(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: Some(initialize_response.session_id.clone()),
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
let response = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &resume_initialize_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else {
|
||||
panic!("expected resume initialize response");
|
||||
};
|
||||
let resumed_response: InitializeResponse = serde_json::from_value(result)?;
|
||||
assert_eq!(resumed_response, initialize_response);
|
||||
|
||||
server
|
||||
.send_notification("initialized", serde_json::json!({}))
|
||||
.await?;
|
||||
|
||||
let process_read_id = server
|
||||
.send_request(
|
||||
"process/read",
|
||||
serde_json::json!({
|
||||
"processId": "proc-resume",
|
||||
"afterSeq": null,
|
||||
"maxBytes": null,
|
||||
"waitMs": 0
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let response = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &process_read_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else {
|
||||
panic!("expected process/read response");
|
||||
};
|
||||
let process_read_response: ReadResponse = serde_json::from_value(result)?;
|
||||
assert!(process_read_response.failure.is_none());
|
||||
assert!(!process_read_response.exited);
|
||||
assert!(!process_read_response.closed);
|
||||
|
||||
let terminate_id = server
|
||||
.send_request(
|
||||
"process/terminate",
|
||||
serde_json::json!({
|
||||
"processId": "proc-resume"
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let response = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &terminate_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else {
|
||||
panic!("expected process/terminate response");
|
||||
};
|
||||
let terminate_response: TerminateResponse = serde_json::from_value(result)?;
|
||||
assert_eq!(terminate_response, TerminateResponse { running: true });
|
||||
|
||||
server.shutdown().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ use codex_exec_server::InitializeParams;
|
||||
use codex_exec_server::InitializeResponse;
|
||||
use common::exec_server::exec_server;
|
||||
use pretty_assertions::assert_eq;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_server_reports_malformed_websocket_json_and_keeps_running() -> anyhow::Result<()> {
|
||||
@@ -36,6 +37,7 @@ async fn exec_server_reports_malformed_websocket_json_and_keeps_running() -> any
|
||||
"initialize",
|
||||
serde_json::to_value(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
@@ -53,7 +55,7 @@ async fn exec_server_reports_malformed_websocket_json_and_keeps_running() -> any
|
||||
};
|
||||
assert_eq!(id, initialize_id);
|
||||
let initialize_response: InitializeResponse = serde_json::from_value(result)?;
|
||||
assert_eq!(initialize_response, InitializeResponse {});
|
||||
Uuid::parse_str(&initialize_response.session_id)?;
|
||||
|
||||
server.shutdown().await?;
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user