Simplify subscribers

This commit is contained in:
jif-oai
2026-03-25 09:51:07 +00:00
parent 9e91c882fd
commit 3037b80396
7 changed files with 91 additions and 104 deletions

View File

@@ -20,7 +20,7 @@ use codex_app_server_protocol::FsWriteFileResponse;
use codex_app_server_protocol::JSONRPCNotification;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_tungstenite::connect_async;
use tracing::debug;
@@ -97,7 +97,14 @@ impl RemoteExecServerConnectArgs {
struct Inner {
client: RpcClient,
sessions: ArcSwap<HashMap<String, broadcast::Sender<ExecSessionEvent>>>,
// The remote transport delivers one shared notification stream for every
// process on the connection. Keep a local process_id -> sender registry so
// we can demux those connection-global notifications into the single
// process-scoped event channel returned by ExecBackend::start().
sessions: ArcSwap<HashMap<String, mpsc::UnboundedSender<ExecSessionEvent>>>,
// ArcSwap makes reads cheap on the hot notification path, but writes still
// need serialization so concurrent register/remove operations do not
// overwrite each other's copy-on-write updates.
sessions_write_lock: Mutex<()>,
reader_task: tokio::task::JoinHandle<()>,
}
@@ -312,14 +319,8 @@ impl ExecServerClient {
pub(crate) async fn register_session(
&self,
process_id: &str,
) -> Result<
(
broadcast::Sender<ExecSessionEvent>,
broadcast::Receiver<ExecSessionEvent>,
),
ExecServerError,
> {
let (events_tx, events_rx) = broadcast::channel(256);
) -> Result<mpsc::UnboundedReceiver<ExecSessionEvent>, ExecServerError> {
let (events_tx, events_rx) = mpsc::unbounded_channel();
let _sessions_write_guard = self.inner.sessions_write_lock.lock().await;
let sessions = self.inner.sessions.load();
if sessions.contains_key(process_id) {
@@ -328,9 +329,9 @@ impl ExecServerClient {
)));
}
let mut next_sessions = sessions.as_ref().clone();
next_sessions.insert(process_id.to_string(), events_tx.clone());
next_sessions.insert(process_id.to_string(), events_tx);
self.inner.sessions.store(Arc::new(next_sessions));
Ok((events_tx, events_rx))
Ok(events_rx)
}
pub(crate) async fn unregister_session(&self, process_id: &str) {
@@ -416,6 +417,8 @@ async fn handle_server_notification(
EXEC_OUTPUT_DELTA_METHOD => {
let params: ExecOutputDeltaNotification =
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
// Remote exec-server notifications are connection-global, so route
// each event to the single local receiver that owns this process.
let events_tx = inner.sessions.load().get(&params.process_id).cloned();
if let Some(events_tx) = events_tx {
let _ = events_tx.send(ExecSessionEvent::Output {
@@ -444,6 +447,8 @@ async fn handle_server_notification(
let sessions = inner.sessions.load();
let events_tx = sessions.get(&params.process_id).cloned();
if events_tx.is_some() {
// Closed is the terminal lifecycle event for this process,
// so drop the routing entry before forwarding it.
let mut next_sessions = sessions.as_ref().clone();
next_sessions.remove(&params.process_id);
inner.sessions.store(Arc::new(next_sessions));

View File

@@ -142,6 +142,6 @@ mod tests {
.await
.expect("start process");
assert_eq!(response.process_id().as_str(), "default-env-proc");
assert_eq!(response.process.process_id().as_str(), "default-env-proc");
}
}

View File

@@ -43,6 +43,7 @@ pub use process::ExecBackend;
pub use process::ExecProcess;
pub use process::ExecSessionEvent;
pub use process::ProcessId;
pub use process::StartedExecProcess;
pub use protocol::ExecClosedNotification;
pub use protocol::ExecExitedNotification;
pub use protocol::ExecOutputDeltaNotification;

View File

@@ -1,7 +1,6 @@
use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::time::Duration;
@@ -12,7 +11,6 @@ use codex_utils_pty::ExecCommandSession;
use codex_utils_pty::TerminalSize;
use tokio::sync::Mutex;
use tokio::sync::Notify;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tracing::warn;
@@ -21,6 +19,7 @@ use crate::ExecProcess;
use crate::ExecServerError;
use crate::ExecSessionEvent;
use crate::ProcessId;
use crate::StartedExecProcess;
use crate::protocol::EXEC_CLOSED_METHOD;
use crate::protocol::ExecClosedNotification;
use crate::protocol::ExecExitedNotification;
@@ -43,7 +42,6 @@ use crate::rpc::invalid_params;
use crate::rpc::invalid_request;
const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024;
const EVENT_CHANNEL_CAPACITY: usize = 256;
const NOTIFICATION_CHANNEL_CAPACITY: usize = 256;
#[cfg(test)]
const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25);
@@ -65,7 +63,7 @@ struct RunningProcess {
next_seq: u64,
exit_code: Option<i32>,
output_notify: Arc<Notify>,
session_events_tx: broadcast::Sender<ExecSessionEvent>,
session_events_tx: mpsc::UnboundedSender<ExecSessionEvent>,
open_streams: usize,
closed: bool,
}
@@ -89,8 +87,6 @@ pub(crate) struct LocalProcess {
struct LocalExecProcess {
process_id: ProcessId,
events_tx: broadcast::Sender<ExecSessionEvent>,
initial_events_rx: StdMutex<Option<broadcast::Receiver<ExecSessionEvent>>>,
backend: LocalProcess,
}
@@ -174,14 +170,7 @@ impl LocalProcess {
async fn start_process(
&self,
params: ExecParams,
) -> Result<
(
ExecResponse,
broadcast::Sender<ExecSessionEvent>,
broadcast::Receiver<ExecSessionEvent>,
),
JSONRPCErrorError,
> {
) -> Result<(ExecResponse, mpsc::UnboundedReceiver<ExecSessionEvent>), JSONRPCErrorError> {
self.require_initialized_for("exec")?;
let process_id = params.process_id.clone();
warn!(
@@ -244,7 +233,7 @@ impl LocalProcess {
};
let output_notify = Arc::new(Notify::new());
let (session_events_tx, session_events_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
let (session_events_tx, session_events_rx) = mpsc::unbounded_channel();
{
let mut process_map = self.inner.processes.lock().await;
process_map.insert(
@@ -298,17 +287,13 @@ impl LocalProcess {
tty = params.tty,
"exec-server started process"
);
Ok((
ExecResponse { process_id },
session_events_tx,
session_events_rx,
))
Ok((ExecResponse { process_id }, session_events_rx))
}
pub(crate) async fn exec(&self, params: ExecParams) -> Result<ExecResponse, JSONRPCErrorError> {
self.start_process(params)
.await
.map(|(response, _, _)| response)
.map(|(response, _)| response)
}
pub(crate) async fn exec_read(
@@ -469,17 +454,18 @@ impl LocalProcess {
#[async_trait]
impl ExecBackend for LocalProcess {
async fn start(&self, params: ExecParams) -> Result<Arc<dyn ExecProcess>, ExecServerError> {
let (response, events_tx, events_rx) = self
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError> {
let (response, events) = self
.start_process(params)
.await
.map_err(map_handler_error)?;
Ok(Arc::new(LocalExecProcess {
process_id: response.process_id.into(),
events_tx,
initial_events_rx: StdMutex::new(Some(events_rx)),
backend: self.clone(),
}))
Ok(StartedExecProcess {
process: Arc::new(LocalExecProcess {
process_id: response.process_id.into(),
backend: self.clone(),
}),
events,
})
}
}
@@ -489,16 +475,6 @@ impl ExecProcess for LocalExecProcess {
&self.process_id
}
fn subscribe(&self) -> broadcast::Receiver<ExecSessionEvent> {
let mut initial_events_rx = self
.initial_events_rx
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
initial_events_rx
.take()
.unwrap_or_else(|| self.events_tx.subscribe())
}
async fn write(&self, chunk: Vec<u8>) -> Result<(), ExecServerError> {
self.backend.write(&self.process_id, chunk).await
}

View File

@@ -3,7 +3,7 @@ use std::ops::Deref;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use crate::ExecServerError;
use crate::protocol::ExecOutputStream;
@@ -28,6 +28,11 @@ pub enum ExecSessionEvent {
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ProcessId(String);
pub struct StartedExecProcess {
pub process: Arc<dyn ExecProcess>,
pub events: mpsc::UnboundedReceiver<ExecSessionEvent>,
}
impl ProcessId {
pub fn as_str(&self) -> &str {
&self.0
@@ -68,8 +73,6 @@ impl From<String> for ProcessId {
pub trait ExecProcess: Send + Sync {
fn process_id(&self) -> &ProcessId;
fn subscribe(&self) -> broadcast::Receiver<ExecSessionEvent>;
async fn write(&self, chunk: Vec<u8>) -> Result<(), ExecServerError>;
async fn terminate(&self) -> Result<(), ExecServerError>;
@@ -77,5 +80,5 @@ pub trait ExecProcess: Send + Sync {
#[async_trait]
pub trait ExecBackend: Send + Sync {
async fn start(&self, params: ExecParams) -> Result<Arc<dyn ExecProcess>, ExecServerError>;
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError>;
}

View File

@@ -1,15 +1,13 @@
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use async_trait::async_trait;
use tokio::sync::broadcast;
use crate::ExecBackend;
use crate::ExecProcess;
use crate::ExecServerClient;
use crate::ExecServerError;
use crate::ExecSessionEvent;
use crate::ProcessId;
use crate::StartedExecProcess;
use crate::protocol::ExecParams;
#[derive(Clone)]
@@ -19,8 +17,6 @@ pub(crate) struct RemoteProcess {
struct RemoteExecProcess {
process_id: ProcessId,
events_tx: broadcast::Sender<ExecSessionEvent>,
initial_events_rx: StdMutex<Option<broadcast::Receiver<ExecSessionEvent>>>,
backend: RemoteProcess,
}
@@ -48,20 +44,21 @@ impl RemoteProcess {
#[async_trait]
impl ExecBackend for RemoteProcess {
async fn start(&self, params: ExecParams) -> Result<Arc<dyn ExecProcess>, ExecServerError> {
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError> {
let process_id = params.process_id.clone();
let (events_tx, events_rx) = self.client.register_session(&process_id).await?;
let events = self.client.register_session(&process_id).await?;
if let Err(err) = self.client.exec(params).await {
self.client.unregister_session(&process_id).await;
return Err(err);
}
Ok(Arc::new(RemoteExecProcess {
process_id: process_id.into(),
events_tx,
initial_events_rx: StdMutex::new(Some(events_rx)),
backend: self.clone(),
}))
Ok(StartedExecProcess {
process: Arc::new(RemoteExecProcess {
process_id: process_id.into(),
backend: self.clone(),
}),
events,
})
}
}
@@ -71,16 +68,6 @@ impl ExecProcess for RemoteExecProcess {
&self.process_id
}
fn subscribe(&self) -> broadcast::Receiver<ExecSessionEvent> {
let mut initial_events_rx = self
.initial_events_rx
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
initial_events_rx
.take()
.unwrap_or_else(|| self.events_tx.subscribe())
}
async fn write(&self, chunk: Vec<u8>) -> Result<(), ExecServerError> {
self.backend.write(&self.process_id, chunk).await
}

View File

@@ -8,9 +8,12 @@ use anyhow::Result;
use codex_exec_server::Environment;
use codex_exec_server::ExecBackend;
use codex_exec_server::ExecParams;
use codex_exec_server::ExecProcess;
use codex_exec_server::ExecSessionEvent;
use codex_exec_server::StartedExecProcess;
use pretty_assertions::assert_eq;
use test_case::test_case;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio::time::timeout;
@@ -52,17 +55,20 @@ async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
arg0: None,
})
.await?;
assert_eq!(session.process_id().as_str(), "proc-1");
let mut events = session.subscribe();
assert_eq!(session.process.process_id().as_str(), "proc-1");
let mut events = session.events;
let mut exit_code = None;
loop {
match timeout(Duration::from_secs(2), events.recv()).await?? {
ExecSessionEvent::Exited {
exit_code: code, ..
} => exit_code = Some(code),
ExecSessionEvent::Closed { .. } => break,
ExecSessionEvent::Output { .. } => {}
match timeout(Duration::from_secs(2), events.recv()).await? {
Some(event) => match event {
ExecSessionEvent::Exited {
exit_code: code, ..
} => exit_code = Some(code),
ExecSessionEvent::Closed { .. } => break,
ExecSessionEvent::Output { .. } => {}
},
None => anyhow::bail!("event stream closed before Closed event"),
}
}
@@ -71,24 +77,30 @@ async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
}
async fn collect_process_output_from_events(
session: Arc<dyn codex_exec_server::ExecProcess>,
session: Arc<dyn ExecProcess>,
mut events: mpsc::UnboundedReceiver<ExecSessionEvent>,
) -> Result<(String, i32, bool)> {
let mut events = session.subscribe();
let mut output = String::new();
let mut exit_code = None;
loop {
match timeout(Duration::from_secs(2), events.recv()).await?? {
ExecSessionEvent::Output { chunk, .. } => {
output.push_str(&String::from_utf8_lossy(&chunk));
}
ExecSessionEvent::Exited {
exit_code: code, ..
} => exit_code = Some(code),
ExecSessionEvent::Closed { .. } => {
break;
match timeout(Duration::from_secs(2), events.recv()).await? {
Some(event) => match event {
ExecSessionEvent::Output { chunk, .. } => {
output.push_str(&String::from_utf8_lossy(&chunk));
}
ExecSessionEvent::Exited {
exit_code: code, ..
} => exit_code = Some(code),
ExecSessionEvent::Closed { .. } => {
break;
}
},
None => {
anyhow::bail!("event stream closed before Closed event");
}
}
}
drop(session);
Ok((output, exit_code.unwrap_or(-1), true))
}
@@ -110,9 +122,10 @@ async fn assert_exec_process_streams_output(use_remote: bool) -> Result<()> {
arg0: None,
})
.await?;
assert_eq!(session.process_id().as_str(), process_id);
assert_eq!(session.process.process_id().as_str(), process_id);
let (output, exit_code, closed) = collect_process_output_from_events(session).await?;
let StartedExecProcess { process, events } = session;
let (output, exit_code, closed) = collect_process_output_from_events(process, events).await?;
assert_eq!(output, "session output\n");
assert_eq!(exit_code, 0);
assert!(closed);
@@ -137,11 +150,12 @@ async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> {
arg0: None,
})
.await?;
assert_eq!(session.process_id().as_str(), process_id);
assert_eq!(session.process.process_id().as_str(), process_id);
tokio::time::sleep(Duration::from_millis(200)).await;
session.write(b"hello\n".to_vec()).await?;
let (output, exit_code, closed) = collect_process_output_from_events(session).await?;
session.process.write(b"hello\n".to_vec()).await?;
let StartedExecProcess { process, events } = session;
let (output, exit_code, closed) = collect_process_output_from_events(process, events).await?;
assert!(
output.contains("from-stdin:hello"),
@@ -174,7 +188,8 @@ async fn assert_exec_process_preserves_queued_events_before_subscribe(
tokio::time::sleep(Duration::from_millis(200)).await;
let (output, exit_code, closed) = collect_process_output_from_events(session).await?;
let StartedExecProcess { process, events } = session;
let (output, exit_code, closed) = collect_process_output_from_events(process, events).await?;
assert_eq!(output, "queued output\n");
assert_eq!(exit_code, 0);
assert!(closed);