mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
nits in exec server
This commit is contained in:
@@ -310,7 +310,13 @@ impl ExecServerClient {
|
||||
pub(crate) async fn register_session(
|
||||
&self,
|
||||
process_id: &str,
|
||||
) -> Result<broadcast::Receiver<ExecSessionEvent>, ExecServerError> {
|
||||
) -> Result<
|
||||
(
|
||||
broadcast::Sender<ExecSessionEvent>,
|
||||
broadcast::Receiver<ExecSessionEvent>,
|
||||
),
|
||||
ExecServerError,
|
||||
> {
|
||||
let (events_tx, events_rx) = broadcast::channel(256);
|
||||
let mut sessions = self.inner.sessions.lock().await;
|
||||
if sessions.contains_key(process_id) {
|
||||
@@ -318,8 +324,8 @@ impl ExecServerClient {
|
||||
"session already registered for process {process_id}"
|
||||
)));
|
||||
}
|
||||
sessions.insert(process_id.to_string(), events_tx);
|
||||
Ok(events_rx)
|
||||
sessions.insert(process_id.to_string(), events_tx.clone());
|
||||
Ok((events_tx, events_rx))
|
||||
}
|
||||
|
||||
pub(crate) async fn unregister_session(&self, process_id: &str) {
|
||||
|
||||
@@ -89,7 +89,8 @@ pub(crate) struct LocalProcess {
|
||||
|
||||
struct LocalExecProcess {
|
||||
process_id: ProcessId,
|
||||
events: StdMutex<broadcast::Receiver<ExecSessionEvent>>,
|
||||
events_tx: broadcast::Sender<ExecSessionEvent>,
|
||||
initial_events_rx: StdMutex<Option<broadcast::Receiver<ExecSessionEvent>>>,
|
||||
backend: LocalProcess,
|
||||
}
|
||||
|
||||
@@ -173,7 +174,14 @@ impl LocalProcess {
|
||||
async fn start_process(
|
||||
&self,
|
||||
params: ExecParams,
|
||||
) -> Result<(ExecResponse, broadcast::Receiver<ExecSessionEvent>), JSONRPCErrorError> {
|
||||
) -> Result<
|
||||
(
|
||||
ExecResponse,
|
||||
broadcast::Sender<ExecSessionEvent>,
|
||||
broadcast::Receiver<ExecSessionEvent>,
|
||||
),
|
||||
JSONRPCErrorError,
|
||||
> {
|
||||
self.require_initialized_for("exec")?;
|
||||
let process_id = params.process_id.clone();
|
||||
warn!(
|
||||
@@ -249,7 +257,7 @@ impl LocalProcess {
|
||||
next_seq: 1,
|
||||
exit_code: None,
|
||||
output_notify: Arc::clone(&output_notify),
|
||||
session_events_tx,
|
||||
session_events_tx: session_events_tx.clone(),
|
||||
open_streams: 2,
|
||||
closed: false,
|
||||
})),
|
||||
@@ -290,13 +298,17 @@ impl LocalProcess {
|
||||
tty = params.tty,
|
||||
"exec-server started process"
|
||||
);
|
||||
Ok((ExecResponse { process_id }, session_events_rx))
|
||||
Ok((
|
||||
ExecResponse { process_id },
|
||||
session_events_tx,
|
||||
session_events_rx,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) async fn exec(&self, params: ExecParams) -> Result<ExecResponse, JSONRPCErrorError> {
|
||||
self.start_process(params)
|
||||
.await
|
||||
.map(|(response, _)| response)
|
||||
.map(|(response, _, _)| response)
|
||||
}
|
||||
|
||||
pub(crate) async fn exec_read(
|
||||
@@ -458,13 +470,14 @@ impl LocalProcess {
|
||||
#[async_trait]
|
||||
impl ExecBackend for LocalProcess {
|
||||
async fn start(&self, params: ExecParams) -> Result<Arc<dyn ExecProcess>, ExecServerError> {
|
||||
let (response, events) = self
|
||||
let (response, events_tx, events_rx) = self
|
||||
.start_process(params)
|
||||
.await
|
||||
.map_err(map_handler_error)?;
|
||||
Ok(Arc::new(LocalExecProcess {
|
||||
process_id: response.process_id.into(),
|
||||
events: StdMutex::new(events),
|
||||
events_tx,
|
||||
initial_events_rx: StdMutex::new(Some(events_rx)),
|
||||
backend: self.clone(),
|
||||
}))
|
||||
}
|
||||
@@ -477,10 +490,13 @@ impl ExecProcess for LocalExecProcess {
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> broadcast::Receiver<ExecSessionEvent> {
|
||||
self.events
|
||||
let mut initial_events_rx = self
|
||||
.initial_events_rx
|
||||
.lock()
|
||||
.expect("local exec process events mutex should not be poisoned")
|
||||
.resubscribe()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
initial_events_rx
|
||||
.take()
|
||||
.unwrap_or_else(|| self.events_tx.subscribe())
|
||||
}
|
||||
|
||||
async fn write(&self, chunk: Vec<u8>) -> Result<(), ExecServerError> {
|
||||
|
||||
@@ -19,7 +19,8 @@ pub(crate) struct RemoteProcess {
|
||||
|
||||
struct RemoteExecProcess {
|
||||
process_id: ProcessId,
|
||||
events: StdMutex<broadcast::Receiver<ExecSessionEvent>>,
|
||||
events_tx: broadcast::Sender<ExecSessionEvent>,
|
||||
initial_events_rx: StdMutex<Option<broadcast::Receiver<ExecSessionEvent>>>,
|
||||
backend: RemoteProcess,
|
||||
}
|
||||
|
||||
@@ -49,7 +50,7 @@ impl RemoteProcess {
|
||||
impl ExecBackend for RemoteProcess {
|
||||
async fn start(&self, params: ExecParams) -> Result<Arc<dyn ExecProcess>, ExecServerError> {
|
||||
let process_id = params.process_id.clone();
|
||||
let events = self.client.register_session(&process_id).await?;
|
||||
let (events_tx, events_rx) = self.client.register_session(&process_id).await?;
|
||||
if let Err(err) = self.client.exec(params).await {
|
||||
self.client.unregister_session(&process_id).await;
|
||||
return Err(err);
|
||||
@@ -57,7 +58,8 @@ impl ExecBackend for RemoteProcess {
|
||||
|
||||
Ok(Arc::new(RemoteExecProcess {
|
||||
process_id: process_id.into(),
|
||||
events: StdMutex::new(events),
|
||||
events_tx,
|
||||
initial_events_rx: StdMutex::new(Some(events_rx)),
|
||||
backend: self.clone(),
|
||||
}))
|
||||
}
|
||||
@@ -70,10 +72,13 @@ impl ExecProcess for RemoteExecProcess {
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> broadcast::Receiver<ExecSessionEvent> {
|
||||
self.events
|
||||
let mut initial_events_rx = self
|
||||
.initial_events_rx
|
||||
.lock()
|
||||
.expect("remote exec process events mutex should not be poisoned")
|
||||
.resubscribe()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
initial_events_rx
|
||||
.take()
|
||||
.unwrap_or_else(|| self.events_tx.subscribe())
|
||||
}
|
||||
|
||||
async fn write(&self, chunk: Vec<u8>) -> Result<(), ExecServerError> {
|
||||
|
||||
@@ -6,10 +6,9 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_exec_server::Environment;
|
||||
use codex_exec_server::ExecBackend;
|
||||
use codex_exec_server::ExecParams;
|
||||
use codex_exec_server::ExecProcess;
|
||||
use codex_exec_server::ExecSessionEvent;
|
||||
use codex_exec_server::ExecSessionHandle;
|
||||
use pretty_assertions::assert_eq;
|
||||
use test_case::test_case;
|
||||
use tokio::time::Duration;
|
||||
@@ -19,7 +18,7 @@ use common::exec_server::ExecServerHarness;
|
||||
use common::exec_server::exec_server;
|
||||
|
||||
struct ProcessContext {
|
||||
process: Arc<dyn ExecProcess>,
|
||||
backend: Arc<dyn ExecBackend>,
|
||||
_server: Option<ExecServerHarness>,
|
||||
}
|
||||
|
||||
@@ -28,13 +27,13 @@ async fn create_process_context(use_remote: bool) -> Result<ProcessContext> {
|
||||
let server = exec_server().await?;
|
||||
let environment = Environment::create(Some(server.websocket_url().to_string())).await?;
|
||||
Ok(ProcessContext {
|
||||
process: environment.get_executor(),
|
||||
backend: environment.get_exec_backend(),
|
||||
_server: Some(server),
|
||||
})
|
||||
} else {
|
||||
let environment = Environment::create(None).await?;
|
||||
Ok(ProcessContext {
|
||||
process: environment.get_executor(),
|
||||
backend: environment.get_exec_backend(),
|
||||
_server: None,
|
||||
})
|
||||
}
|
||||
@@ -42,8 +41,8 @@ async fn create_process_context(use_remote: bool) -> Result<ProcessContext> {
|
||||
|
||||
async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
|
||||
let context = create_process_context(use_remote).await?;
|
||||
let mut session = context
|
||||
.process
|
||||
let session = context
|
||||
.backend
|
||||
.start(ExecParams {
|
||||
process_id: "proc-1".to_string(),
|
||||
argv: vec!["true".to_string()],
|
||||
@@ -53,11 +52,12 @@ async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
assert_eq!(session.process_id, "proc-1");
|
||||
assert_eq!(session.process_id().as_str(), "proc-1");
|
||||
let mut events = session.subscribe();
|
||||
|
||||
let mut exit_code = None;
|
||||
loop {
|
||||
match timeout(Duration::from_secs(2), session.events.recv()).await?? {
|
||||
match timeout(Duration::from_secs(2), events.recv()).await?? {
|
||||
ExecSessionEvent::Exited {
|
||||
exit_code: code, ..
|
||||
} => exit_code = Some(code),
|
||||
@@ -71,12 +71,13 @@ async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
|
||||
}
|
||||
|
||||
async fn collect_process_output_from_events(
|
||||
mut session: ExecSessionHandle,
|
||||
session: Arc<dyn codex_exec_server::ExecProcess>,
|
||||
) -> Result<(String, i32, bool)> {
|
||||
let mut events = session.subscribe();
|
||||
let mut output = String::new();
|
||||
let mut exit_code = None;
|
||||
loop {
|
||||
match timeout(Duration::from_secs(2), session.events.recv()).await?? {
|
||||
match timeout(Duration::from_secs(2), events.recv()).await?? {
|
||||
ExecSessionEvent::Output { chunk, .. } => {
|
||||
output.push_str(&String::from_utf8_lossy(&chunk));
|
||||
}
|
||||
@@ -95,7 +96,7 @@ async fn assert_exec_process_streams_output(use_remote: bool) -> Result<()> {
|
||||
let context = create_process_context(use_remote).await?;
|
||||
let process_id = "proc-stream".to_string();
|
||||
let session = context
|
||||
.process
|
||||
.backend
|
||||
.start(ExecParams {
|
||||
process_id: process_id.clone(),
|
||||
argv: vec![
|
||||
@@ -109,7 +110,7 @@ async fn assert_exec_process_streams_output(use_remote: bool) -> Result<()> {
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
assert_eq!(session.process_id, process_id);
|
||||
assert_eq!(session.process_id().as_str(), process_id);
|
||||
|
||||
let (output, exit_code, closed) = collect_process_output_from_events(session).await?;
|
||||
assert_eq!(output, "session output\n");
|
||||
@@ -122,7 +123,7 @@ async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> {
|
||||
let context = create_process_context(use_remote).await?;
|
||||
let process_id = "proc-stdin".to_string();
|
||||
let session = context
|
||||
.process
|
||||
.backend
|
||||
.start(ExecParams {
|
||||
process_id: process_id.clone(),
|
||||
argv: vec![
|
||||
@@ -136,10 +137,10 @@ async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> {
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
assert_eq!(session.process_id, process_id);
|
||||
assert_eq!(session.process_id().as_str(), process_id);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
session.write_stdin(b"hello\n".to_vec()).await?;
|
||||
session.write(b"hello\n".to_vec()).await?;
|
||||
let (output, exit_code, closed) = collect_process_output_from_events(session).await?;
|
||||
|
||||
assert!(
|
||||
@@ -151,6 +152,35 @@ async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn assert_exec_process_preserves_queued_events_before_subscribe(
|
||||
use_remote: bool,
|
||||
) -> Result<()> {
|
||||
let context = create_process_context(use_remote).await?;
|
||||
let session = context
|
||||
.backend
|
||||
.start(ExecParams {
|
||||
process_id: "proc-queued".to_string(),
|
||||
argv: vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
"printf 'queued output\\n'".to_string(),
|
||||
],
|
||||
cwd: std::env::current_dir()?,
|
||||
env: Default::default(),
|
||||
tty: false,
|
||||
arg0: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
|
||||
let (output, exit_code, closed) = collect_process_output_from_events(session).await?;
|
||||
assert_eq!(output, "queued output\n");
|
||||
assert_eq!(exit_code, 0);
|
||||
assert!(closed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test_case(false ; "local")]
|
||||
#[test_case(true ; "remote")]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
@@ -171,3 +201,10 @@ async fn exec_process_streams_output(use_remote: bool) -> Result<()> {
|
||||
async fn exec_process_write_then_read(use_remote: bool) -> Result<()> {
|
||||
assert_exec_process_write_then_read(use_remote).await
|
||||
}
|
||||
|
||||
#[test_case(false ; "local")]
|
||||
#[test_case(true ; "remote")]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_process_preserves_queued_events_before_subscribe(use_remote: bool) -> Result<()> {
|
||||
assert_exec_process_preserves_queued_events_before_subscribe(use_remote).await
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user