mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
Simplify subscribers
This commit is contained in:
@@ -20,7 +20,7 @@ use codex_app_server_protocol::FsWriteFileResponse;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tracing::debug;
|
||||
@@ -97,7 +97,14 @@ impl RemoteExecServerConnectArgs {
|
||||
|
||||
struct Inner {
|
||||
client: RpcClient,
|
||||
sessions: ArcSwap<HashMap<String, broadcast::Sender<ExecSessionEvent>>>,
|
||||
// The remote transport delivers one shared notification stream for every
|
||||
// process on the connection. Keep a local process_id -> sender registry so
|
||||
// we can demux those connection-global notifications into the single
|
||||
// process-scoped event channel returned by ExecBackend::start().
|
||||
sessions: ArcSwap<HashMap<String, mpsc::UnboundedSender<ExecSessionEvent>>>,
|
||||
// ArcSwap makes reads cheap on the hot notification path, but writes still
|
||||
// need serialization so concurrent register/remove operations do not
|
||||
// overwrite each other's copy-on-write updates.
|
||||
sessions_write_lock: Mutex<()>,
|
||||
reader_task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
@@ -312,14 +319,8 @@ impl ExecServerClient {
|
||||
pub(crate) async fn register_session(
|
||||
&self,
|
||||
process_id: &str,
|
||||
) -> Result<
|
||||
(
|
||||
broadcast::Sender<ExecSessionEvent>,
|
||||
broadcast::Receiver<ExecSessionEvent>,
|
||||
),
|
||||
ExecServerError,
|
||||
> {
|
||||
let (events_tx, events_rx) = broadcast::channel(256);
|
||||
) -> Result<mpsc::UnboundedReceiver<ExecSessionEvent>, ExecServerError> {
|
||||
let (events_tx, events_rx) = mpsc::unbounded_channel();
|
||||
let _sessions_write_guard = self.inner.sessions_write_lock.lock().await;
|
||||
let sessions = self.inner.sessions.load();
|
||||
if sessions.contains_key(process_id) {
|
||||
@@ -328,9 +329,9 @@ impl ExecServerClient {
|
||||
)));
|
||||
}
|
||||
let mut next_sessions = sessions.as_ref().clone();
|
||||
next_sessions.insert(process_id.to_string(), events_tx.clone());
|
||||
next_sessions.insert(process_id.to_string(), events_tx);
|
||||
self.inner.sessions.store(Arc::new(next_sessions));
|
||||
Ok((events_tx, events_rx))
|
||||
Ok(events_rx)
|
||||
}
|
||||
|
||||
pub(crate) async fn unregister_session(&self, process_id: &str) {
|
||||
@@ -416,6 +417,8 @@ async fn handle_server_notification(
|
||||
EXEC_OUTPUT_DELTA_METHOD => {
|
||||
let params: ExecOutputDeltaNotification =
|
||||
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
|
||||
// Remote exec-server notifications are connection-global, so route
|
||||
// each event to the single local receiver that owns this process.
|
||||
let events_tx = inner.sessions.load().get(¶ms.process_id).cloned();
|
||||
if let Some(events_tx) = events_tx {
|
||||
let _ = events_tx.send(ExecSessionEvent::Output {
|
||||
@@ -444,6 +447,8 @@ async fn handle_server_notification(
|
||||
let sessions = inner.sessions.load();
|
||||
let events_tx = sessions.get(¶ms.process_id).cloned();
|
||||
if events_tx.is_some() {
|
||||
// Closed is the terminal lifecycle event for this process,
|
||||
// so drop the routing entry before forwarding it.
|
||||
let mut next_sessions = sessions.as_ref().clone();
|
||||
next_sessions.remove(¶ms.process_id);
|
||||
inner.sessions.store(Arc::new(next_sessions));
|
||||
|
||||
@@ -142,6 +142,6 @@ mod tests {
|
||||
.await
|
||||
.expect("start process");
|
||||
|
||||
assert_eq!(response.process_id().as_str(), "default-env-proc");
|
||||
assert_eq!(response.process.process_id().as_str(), "default-env-proc");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ pub use process::ExecBackend;
|
||||
pub use process::ExecProcess;
|
||||
pub use process::ExecSessionEvent;
|
||||
pub use process::ProcessId;
|
||||
pub use process::StartedExecProcess;
|
||||
pub use protocol::ExecClosedNotification;
|
||||
pub use protocol::ExecExitedNotification;
|
||||
pub use protocol::ExecOutputDeltaNotification;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
@@ -12,7 +11,6 @@ use codex_utils_pty::ExecCommandSession;
|
||||
use codex_utils_pty::TerminalSize;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::warn;
|
||||
|
||||
@@ -21,6 +19,7 @@ use crate::ExecProcess;
|
||||
use crate::ExecServerError;
|
||||
use crate::ExecSessionEvent;
|
||||
use crate::ProcessId;
|
||||
use crate::StartedExecProcess;
|
||||
use crate::protocol::EXEC_CLOSED_METHOD;
|
||||
use crate::protocol::ExecClosedNotification;
|
||||
use crate::protocol::ExecExitedNotification;
|
||||
@@ -43,7 +42,6 @@ use crate::rpc::invalid_params;
|
||||
use crate::rpc::invalid_request;
|
||||
|
||||
const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024;
|
||||
const EVENT_CHANNEL_CAPACITY: usize = 256;
|
||||
const NOTIFICATION_CHANNEL_CAPACITY: usize = 256;
|
||||
#[cfg(test)]
|
||||
const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25);
|
||||
@@ -65,7 +63,7 @@ struct RunningProcess {
|
||||
next_seq: u64,
|
||||
exit_code: Option<i32>,
|
||||
output_notify: Arc<Notify>,
|
||||
session_events_tx: broadcast::Sender<ExecSessionEvent>,
|
||||
session_events_tx: mpsc::UnboundedSender<ExecSessionEvent>,
|
||||
open_streams: usize,
|
||||
closed: bool,
|
||||
}
|
||||
@@ -89,8 +87,6 @@ pub(crate) struct LocalProcess {
|
||||
|
||||
struct LocalExecProcess {
|
||||
process_id: ProcessId,
|
||||
events_tx: broadcast::Sender<ExecSessionEvent>,
|
||||
initial_events_rx: StdMutex<Option<broadcast::Receiver<ExecSessionEvent>>>,
|
||||
backend: LocalProcess,
|
||||
}
|
||||
|
||||
@@ -174,14 +170,7 @@ impl LocalProcess {
|
||||
async fn start_process(
|
||||
&self,
|
||||
params: ExecParams,
|
||||
) -> Result<
|
||||
(
|
||||
ExecResponse,
|
||||
broadcast::Sender<ExecSessionEvent>,
|
||||
broadcast::Receiver<ExecSessionEvent>,
|
||||
),
|
||||
JSONRPCErrorError,
|
||||
> {
|
||||
) -> Result<(ExecResponse, mpsc::UnboundedReceiver<ExecSessionEvent>), JSONRPCErrorError> {
|
||||
self.require_initialized_for("exec")?;
|
||||
let process_id = params.process_id.clone();
|
||||
warn!(
|
||||
@@ -244,7 +233,7 @@ impl LocalProcess {
|
||||
};
|
||||
|
||||
let output_notify = Arc::new(Notify::new());
|
||||
let (session_events_tx, session_events_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
|
||||
let (session_events_tx, session_events_rx) = mpsc::unbounded_channel();
|
||||
{
|
||||
let mut process_map = self.inner.processes.lock().await;
|
||||
process_map.insert(
|
||||
@@ -298,17 +287,13 @@ impl LocalProcess {
|
||||
tty = params.tty,
|
||||
"exec-server started process"
|
||||
);
|
||||
Ok((
|
||||
ExecResponse { process_id },
|
||||
session_events_tx,
|
||||
session_events_rx,
|
||||
))
|
||||
Ok((ExecResponse { process_id }, session_events_rx))
|
||||
}
|
||||
|
||||
pub(crate) async fn exec(&self, params: ExecParams) -> Result<ExecResponse, JSONRPCErrorError> {
|
||||
self.start_process(params)
|
||||
.await
|
||||
.map(|(response, _, _)| response)
|
||||
.map(|(response, _)| response)
|
||||
}
|
||||
|
||||
pub(crate) async fn exec_read(
|
||||
@@ -469,17 +454,18 @@ impl LocalProcess {
|
||||
|
||||
#[async_trait]
|
||||
impl ExecBackend for LocalProcess {
|
||||
async fn start(&self, params: ExecParams) -> Result<Arc<dyn ExecProcess>, ExecServerError> {
|
||||
let (response, events_tx, events_rx) = self
|
||||
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError> {
|
||||
let (response, events) = self
|
||||
.start_process(params)
|
||||
.await
|
||||
.map_err(map_handler_error)?;
|
||||
Ok(Arc::new(LocalExecProcess {
|
||||
process_id: response.process_id.into(),
|
||||
events_tx,
|
||||
initial_events_rx: StdMutex::new(Some(events_rx)),
|
||||
backend: self.clone(),
|
||||
}))
|
||||
Ok(StartedExecProcess {
|
||||
process: Arc::new(LocalExecProcess {
|
||||
process_id: response.process_id.into(),
|
||||
backend: self.clone(),
|
||||
}),
|
||||
events,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -489,16 +475,6 @@ impl ExecProcess for LocalExecProcess {
|
||||
&self.process_id
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> broadcast::Receiver<ExecSessionEvent> {
|
||||
let mut initial_events_rx = self
|
||||
.initial_events_rx
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
initial_events_rx
|
||||
.take()
|
||||
.unwrap_or_else(|| self.events_tx.subscribe())
|
||||
}
|
||||
|
||||
async fn write(&self, chunk: Vec<u8>) -> Result<(), ExecServerError> {
|
||||
self.backend.write(&self.process_id, chunk).await
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::ExecServerError;
|
||||
use crate::protocol::ExecOutputStream;
|
||||
@@ -28,6 +28,11 @@ pub enum ExecSessionEvent {
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct ProcessId(String);
|
||||
|
||||
pub struct StartedExecProcess {
|
||||
pub process: Arc<dyn ExecProcess>,
|
||||
pub events: mpsc::UnboundedReceiver<ExecSessionEvent>,
|
||||
}
|
||||
|
||||
impl ProcessId {
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
@@ -68,8 +73,6 @@ impl From<String> for ProcessId {
|
||||
pub trait ExecProcess: Send + Sync {
|
||||
fn process_id(&self) -> &ProcessId;
|
||||
|
||||
fn subscribe(&self) -> broadcast::Receiver<ExecSessionEvent>;
|
||||
|
||||
async fn write(&self, chunk: Vec<u8>) -> Result<(), ExecServerError>;
|
||||
|
||||
async fn terminate(&self) -> Result<(), ExecServerError>;
|
||||
@@ -77,5 +80,5 @@ pub trait ExecProcess: Send + Sync {
|
||||
|
||||
#[async_trait]
|
||||
pub trait ExecBackend: Send + Sync {
|
||||
async fn start(&self, params: ExecParams) -> Result<Arc<dyn ExecProcess>, ExecServerError>;
|
||||
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError>;
|
||||
}
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::ExecBackend;
|
||||
use crate::ExecProcess;
|
||||
use crate::ExecServerClient;
|
||||
use crate::ExecServerError;
|
||||
use crate::ExecSessionEvent;
|
||||
use crate::ProcessId;
|
||||
use crate::StartedExecProcess;
|
||||
use crate::protocol::ExecParams;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -19,8 +17,6 @@ pub(crate) struct RemoteProcess {
|
||||
|
||||
struct RemoteExecProcess {
|
||||
process_id: ProcessId,
|
||||
events_tx: broadcast::Sender<ExecSessionEvent>,
|
||||
initial_events_rx: StdMutex<Option<broadcast::Receiver<ExecSessionEvent>>>,
|
||||
backend: RemoteProcess,
|
||||
}
|
||||
|
||||
@@ -48,20 +44,21 @@ impl RemoteProcess {
|
||||
|
||||
#[async_trait]
|
||||
impl ExecBackend for RemoteProcess {
|
||||
async fn start(&self, params: ExecParams) -> Result<Arc<dyn ExecProcess>, ExecServerError> {
|
||||
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError> {
|
||||
let process_id = params.process_id.clone();
|
||||
let (events_tx, events_rx) = self.client.register_session(&process_id).await?;
|
||||
let events = self.client.register_session(&process_id).await?;
|
||||
if let Err(err) = self.client.exec(params).await {
|
||||
self.client.unregister_session(&process_id).await;
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
Ok(Arc::new(RemoteExecProcess {
|
||||
process_id: process_id.into(),
|
||||
events_tx,
|
||||
initial_events_rx: StdMutex::new(Some(events_rx)),
|
||||
backend: self.clone(),
|
||||
}))
|
||||
Ok(StartedExecProcess {
|
||||
process: Arc::new(RemoteExecProcess {
|
||||
process_id: process_id.into(),
|
||||
backend: self.clone(),
|
||||
}),
|
||||
events,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,16 +68,6 @@ impl ExecProcess for RemoteExecProcess {
|
||||
&self.process_id
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> broadcast::Receiver<ExecSessionEvent> {
|
||||
let mut initial_events_rx = self
|
||||
.initial_events_rx
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
initial_events_rx
|
||||
.take()
|
||||
.unwrap_or_else(|| self.events_tx.subscribe())
|
||||
}
|
||||
|
||||
async fn write(&self, chunk: Vec<u8>) -> Result<(), ExecServerError> {
|
||||
self.backend.write(&self.process_id, chunk).await
|
||||
}
|
||||
|
||||
@@ -8,9 +8,12 @@ use anyhow::Result;
|
||||
use codex_exec_server::Environment;
|
||||
use codex_exec_server::ExecBackend;
|
||||
use codex_exec_server::ExecParams;
|
||||
use codex_exec_server::ExecProcess;
|
||||
use codex_exec_server::ExecSessionEvent;
|
||||
use codex_exec_server::StartedExecProcess;
|
||||
use pretty_assertions::assert_eq;
|
||||
use test_case::test_case;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
@@ -52,17 +55,20 @@ async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
assert_eq!(session.process_id().as_str(), "proc-1");
|
||||
let mut events = session.subscribe();
|
||||
assert_eq!(session.process.process_id().as_str(), "proc-1");
|
||||
let mut events = session.events;
|
||||
|
||||
let mut exit_code = None;
|
||||
loop {
|
||||
match timeout(Duration::from_secs(2), events.recv()).await?? {
|
||||
ExecSessionEvent::Exited {
|
||||
exit_code: code, ..
|
||||
} => exit_code = Some(code),
|
||||
ExecSessionEvent::Closed { .. } => break,
|
||||
ExecSessionEvent::Output { .. } => {}
|
||||
match timeout(Duration::from_secs(2), events.recv()).await? {
|
||||
Some(event) => match event {
|
||||
ExecSessionEvent::Exited {
|
||||
exit_code: code, ..
|
||||
} => exit_code = Some(code),
|
||||
ExecSessionEvent::Closed { .. } => break,
|
||||
ExecSessionEvent::Output { .. } => {}
|
||||
},
|
||||
None => anyhow::bail!("event stream closed before Closed event"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,24 +77,30 @@ async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
|
||||
}
|
||||
|
||||
async fn collect_process_output_from_events(
|
||||
session: Arc<dyn codex_exec_server::ExecProcess>,
|
||||
session: Arc<dyn ExecProcess>,
|
||||
mut events: mpsc::UnboundedReceiver<ExecSessionEvent>,
|
||||
) -> Result<(String, i32, bool)> {
|
||||
let mut events = session.subscribe();
|
||||
let mut output = String::new();
|
||||
let mut exit_code = None;
|
||||
loop {
|
||||
match timeout(Duration::from_secs(2), events.recv()).await?? {
|
||||
ExecSessionEvent::Output { chunk, .. } => {
|
||||
output.push_str(&String::from_utf8_lossy(&chunk));
|
||||
}
|
||||
ExecSessionEvent::Exited {
|
||||
exit_code: code, ..
|
||||
} => exit_code = Some(code),
|
||||
ExecSessionEvent::Closed { .. } => {
|
||||
break;
|
||||
match timeout(Duration::from_secs(2), events.recv()).await? {
|
||||
Some(event) => match event {
|
||||
ExecSessionEvent::Output { chunk, .. } => {
|
||||
output.push_str(&String::from_utf8_lossy(&chunk));
|
||||
}
|
||||
ExecSessionEvent::Exited {
|
||||
exit_code: code, ..
|
||||
} => exit_code = Some(code),
|
||||
ExecSessionEvent::Closed { .. } => {
|
||||
break;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
anyhow::bail!("event stream closed before Closed event");
|
||||
}
|
||||
}
|
||||
}
|
||||
drop(session);
|
||||
Ok((output, exit_code.unwrap_or(-1), true))
|
||||
}
|
||||
|
||||
@@ -110,9 +122,10 @@ async fn assert_exec_process_streams_output(use_remote: bool) -> Result<()> {
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
assert_eq!(session.process_id().as_str(), process_id);
|
||||
assert_eq!(session.process.process_id().as_str(), process_id);
|
||||
|
||||
let (output, exit_code, closed) = collect_process_output_from_events(session).await?;
|
||||
let StartedExecProcess { process, events } = session;
|
||||
let (output, exit_code, closed) = collect_process_output_from_events(process, events).await?;
|
||||
assert_eq!(output, "session output\n");
|
||||
assert_eq!(exit_code, 0);
|
||||
assert!(closed);
|
||||
@@ -137,11 +150,12 @@ async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> {
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
assert_eq!(session.process_id().as_str(), process_id);
|
||||
assert_eq!(session.process.process_id().as_str(), process_id);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
session.write(b"hello\n".to_vec()).await?;
|
||||
let (output, exit_code, closed) = collect_process_output_from_events(session).await?;
|
||||
session.process.write(b"hello\n".to_vec()).await?;
|
||||
let StartedExecProcess { process, events } = session;
|
||||
let (output, exit_code, closed) = collect_process_output_from_events(process, events).await?;
|
||||
|
||||
assert!(
|
||||
output.contains("from-stdin:hello"),
|
||||
@@ -174,7 +188,8 @@ async fn assert_exec_process_preserves_queued_events_before_subscribe(
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
|
||||
let (output, exit_code, closed) = collect_process_output_from_events(session).await?;
|
||||
let StartedExecProcess { process, events } = session;
|
||||
let (output, exit_code, closed) = collect_process_output_from_events(process, events).await?;
|
||||
assert_eq!(output, "queued output\n");
|
||||
assert_eq!(exit_code, 0);
|
||||
assert!(closed);
|
||||
|
||||
Reference in New Issue
Block a user