diff --git a/codex-rs/code-mode/src/lib.rs b/codex-rs/code-mode/src/lib.rs index 3da8c77325..37ba1d2a91 100644 --- a/codex-rs/code-mode/src/lib.rs +++ b/codex-rs/code-mode/src/lib.rs @@ -29,9 +29,18 @@ pub use runtime::WaitOutcome; pub use runtime::WaitRequest; pub use runtime::WaitToPendingOutcome; pub use runtime::WaitToPendingRequest; +pub use service::CellId; pub use service::CodeModeService; -pub use service::CodeModeTurnHost; -pub use service::CodeModeTurnWorker; +pub use service::CodeModeSession; +pub use service::CodeModeSessionDelegate; +pub use service::CodeModeSessionProvider; +pub use service::CodeModeSessionProviderFuture; +pub use service::CodeModeSessionResultFuture; +pub use service::InProcessCodeModeSessionProvider; +pub use service::NoopCodeModeSessionDelegate; +pub use service::NotificationFuture; +pub use service::StartedCell; +pub use service::ToolInvocationFuture; pub const PUBLIC_TOOL_NAME: &str = "exec"; pub const WAIT_TOOL_NAME: &str = "wait"; diff --git a/codex-rs/code-mode/src/runtime/mod.rs b/codex-rs/code-mode/src/runtime/mod.rs index 89893d2b25..0b5c91bf82 100644 --- a/codex-rs/code-mode/src/runtime/mod.rs +++ b/codex-rs/code-mode/src/runtime/mod.rs @@ -27,11 +27,6 @@ const EXIT_SENTINEL: &str = "__codex_code_mode_exit__"; #[derive(Clone, Debug)] pub struct ExecuteRequest { - /// Runtime cell id for this execution. - /// - /// Callers allocate this before execution so tracing, waits, and nested tool - /// calls can refer to the cell as soon as JavaScript starts. - pub cell_id: String, pub tool_call_id: String, pub enabled_tools: Vec, pub source: String, @@ -43,7 +38,6 @@ pub struct ExecuteRequest { pub struct WaitRequest { pub cell_id: String, pub yield_time_ms: u64, - pub terminate: bool, } #[derive(Clone, Debug)] @@ -133,16 +127,6 @@ pub struct CodeModeNestedToolCall { pub input: Option, } -#[derive(Debug)] -pub(crate) enum TurnMessage { - ToolCall(CodeModeNestedToolCall), - Notify { - cell_id: String, - call_id: String, - text: String, - }, -} - #[derive(Debug)] pub(crate) enum RuntimeCommand { ToolResponse { id: String, result: JsonValue }, @@ -460,7 +444,6 @@ mod tests { fn execute_request(source: &str) -> ExecuteRequest { ExecuteRequest { - cell_id: "1".to_string(), tool_call_id: "call_1".to_string(), enabled_tools: Vec::new(), source: source.to_string(), diff --git a/codex-rs/code-mode/src/service.rs b/codex-rs/code-mode/src/service.rs index de4ed13e58..d726c6a9c2 100644 --- a/codex-rs/code-mode/src/service.rs +++ b/codex-rs/code-mode/src/service.rs @@ -1,10 +1,15 @@ use std::collections::HashMap; +use std::fmt; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; use std::time::Duration; -use async_trait::async_trait; +use serde::Deserialize; +use serde::Serialize; use serde_json::Value as JsonValue; use tokio::sync::Mutex; use tokio::sync::mpsc; @@ -22,35 +27,168 @@ use crate::runtime::RuntimeCommand; use crate::runtime::RuntimeControlCommand; use crate::runtime::RuntimeEvent; use crate::runtime::RuntimeResponse; -use crate::runtime::TurnMessage; use crate::runtime::WaitOutcome; use crate::runtime::WaitRequest; use crate::runtime::WaitToPendingOutcome; use crate::runtime::WaitToPendingRequest; use crate::runtime::spawn_runtime; -#[async_trait] -pub trait CodeModeTurnHost: Send + Sync { - async fn invoke_tool( - &self, +pub type CodeModeSessionResultFuture<'a, T> = + Pin> + Send + 'a>>; +pub type CodeModeSessionProviderFuture<'a> = + CodeModeSessionResultFuture<'a, Arc>; +pub type ToolInvocationFuture<'a> = + Pin> + Send + 'a>>; +pub type NotificationFuture<'a> = Pin> + Send + 'a>>; + +#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct CellId(String); + +impl CellId { + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl AsRef for CellId { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl fmt::Display for CellId { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str(self.as_str()) + } +} + +impl From for CellId { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From for String { + fn from(value: CellId) -> Self { + value.0 + } +} + +pub struct StartedCell { + pub cell_id: CellId, + initial_response_rx: oneshot::Receiver, +} + +impl StartedCell { + pub async fn initial_response(self) -> Result { + self.initial_response_rx + .await + .map_err(|_| "exec runtime ended unexpectedly".to_string()) + } +} + +/// Host callbacks used by a code-mode session while cells are executing. +pub trait CodeModeSessionDelegate: Send + Sync { + fn invoke_tool<'a>( + &'a self, invocation: CodeModeNestedToolCall, cancellation_token: CancellationToken, - ) -> Result; + ) -> ToolInvocationFuture<'a>; - async fn notify(&self, call_id: String, cell_id: String, text: String) -> Result<(), String>; + fn notify<'a>( + &'a self, + call_id: String, + cell_id: CellId, + text: String, + ) -> NotificationFuture<'a>; +} + +pub struct NoopCodeModeSessionDelegate; + +impl CodeModeSessionDelegate for NoopCodeModeSessionDelegate { + fn invoke_tool<'a>( + &'a self, + _invocation: CodeModeNestedToolCall, + cancellation_token: CancellationToken, + ) -> ToolInvocationFuture<'a> { + Box::pin(async move { + cancellation_token.cancelled().await; + Err("code mode nested tools are unavailable".to_string()) + }) + } + + fn notify<'a>( + &'a self, + _call_id: String, + _cell_id: CellId, + _text: String, + ) -> NotificationFuture<'a> { + Box::pin(async { Ok(()) }) + } +} + +/// A durable code-mode session. Implementations may execute cells in-process or remotely. +pub trait CodeModeSession: Send + Sync { + fn execute<'a>( + &'a self, + request: ExecuteRequest, + ) -> CodeModeSessionResultFuture<'a, StartedCell>; + + fn wait<'a>(&'a self, request: WaitRequest) -> CodeModeSessionResultFuture<'a, WaitOutcome>; + + fn terminate<'a>(&'a self, cell_id: String) -> CodeModeSessionResultFuture<'a, WaitOutcome>; + + fn stored_values<'a>( + &'a self, + ) -> Pin> + Send + 'a>>; + + fn replace_stored_values<'a>( + &'a self, + values: HashMap, + ) -> Pin + Send + 'a>>; + + fn shutdown<'a>(&'a self) -> CodeModeSessionResultFuture<'a, ()>; +} + +/// Creates code-mode sessions for one host. +/// +/// Providers choose where a session executes and receive the host delegate that +/// the session should use for nested tool calls and notifications. +pub trait CodeModeSessionProvider: Send + Sync { + fn create_session<'a>( + &'a self, + delegate: Arc, + ) -> CodeModeSessionProviderFuture<'a>; +} + +#[derive(Default)] +pub struct InProcessCodeModeSessionProvider; + +impl CodeModeSessionProvider for InProcessCodeModeSessionProvider { + fn create_session<'a>( + &'a self, + delegate: Arc, + ) -> CodeModeSessionProviderFuture<'a> { + Box::pin(async move { + let session: Arc = + Arc::new(CodeModeService::with_delegate(delegate)); + Ok(session) + }) + } } #[derive(Clone)] struct SessionHandle { control_tx: mpsc::UnboundedSender, runtime_tx: std::sync::mpsc::Sender, + cancellation_token: CancellationToken, } struct Inner { stored_values: Mutex>, sessions: Mutex>, - turn_message_tx: async_channel::Sender, - turn_message_rx: async_channel::Receiver, + delegate: Arc, + shutting_down: AtomicBool, next_cell_id: AtomicU64, } @@ -60,14 +198,16 @@ pub struct CodeModeService { impl CodeModeService { pub fn new() -> Self { - let (turn_message_tx, turn_message_rx) = async_channel::unbounded(); + Self::with_delegate(Arc::new(NoopCodeModeSessionDelegate)) + } + pub fn with_delegate(delegate: Arc) -> Self { Self { inner: Arc::new(Inner { stored_values: Mutex::new(HashMap::new()), sessions: Mutex::new(HashMap::new()), - turn_message_tx, - turn_message_rx, + delegate, + shutting_down: AtomicBool::new(false), next_cell_id: AtomicU64::new(1), }), } @@ -77,23 +217,28 @@ impl CodeModeService { self.inner.stored_values.lock().await.clone() } - /// Reserves the runtime cell id for a future `execute` request. - /// - /// The runtime can issue nested tool calls before the first `execute` - /// response is returned. Hosts that need a parent trace object for those - /// nested calls should allocate the cell id up front and pass it back on the - /// `ExecuteRequest`. - pub fn allocate_cell_id(&self) -> String { - self.inner - .next_cell_id - .fetch_add(1, Ordering::Relaxed) - .to_string() + pub async fn replace_stored_values(&self, values: HashMap) { + *self.inner.stored_values.lock().await = values; } - pub async fn execute(&self, request: ExecuteRequest) -> Result { + fn allocate_cell_id(&self) -> CellId { + CellId( + self.inner + .next_cell_id + .fetch_add(1, Ordering::Relaxed) + .to_string(), + ) + } + + pub async fn execute(&self, request: ExecuteRequest) -> Result { + if self.inner.shutting_down.load(Ordering::Acquire) { + return Err("code mode session is shutting down".to_string()); + } let initial_yield_time_ms = request.yield_time_ms.unwrap_or(DEFAULT_EXEC_YIELD_TIME_MS); let (response_tx, response_rx) = oneshot::channel(); + let cell_id = self.allocate_cell_id(); self.start_session( + cell_id.clone(), request, SessionResponseSender::Runtime(response_tx), Some(initial_yield_time_ms), @@ -101,9 +246,10 @@ impl CodeModeService { ) .await?; - response_rx - .await - .map_err(|_| "exec runtime ended unexpectedly".to_string()) + Ok(StartedCell { + cell_id, + initial_response_rx: response_rx, + }) } pub async fn execute_to_pending( @@ -111,7 +257,9 @@ impl CodeModeService { request: ExecuteRequest, ) -> Result { let (response_tx, response_rx) = oneshot::channel(); + let cell_id = self.allocate_cell_id(); self.start_session( + cell_id, request, SessionResponseSender::ExecuteToPending(response_tx), /*initial_yield_time_ms*/ None, @@ -126,32 +274,32 @@ impl CodeModeService { async fn start_session( &self, + cell_id: CellId, request: ExecuteRequest, initial_response_tx: SessionResponseSender, initial_yield_time_ms: Option, pending_mode: PendingRuntimeMode, ) -> Result<(), String> { - let cell_id = request.cell_id.clone(); + let cell_id_text = cell_id.to_string(); let (event_tx, event_rx) = mpsc::unbounded_channel(); let (control_tx, control_rx) = mpsc::unbounded_channel(); let stored_values = self.stored_values().await; + let cancellation_token = CancellationToken::new(); let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = { let mut sessions = self.inner.sessions.lock().await; - if sessions.contains_key(&cell_id) { + if sessions.contains_key(&cell_id_text) { return Err(format!("exec cell {cell_id} already exists")); } let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = spawn_runtime(stored_values, request, event_tx, pending_mode)?; - // Keep the session registry locked through insertion so a - // caller-owned cell id cannot race with another execute and replace - // a live runtime. sessions.insert( - cell_id.clone(), + cell_id_text.clone(), SessionHandle { control_tx, runtime_tx: runtime_tx.clone(), + cancellation_token: cancellation_token.clone(), }, ); (runtime_tx, runtime_control_tx, runtime_terminate_handle) @@ -160,11 +308,12 @@ impl CodeModeService { tokio::spawn(run_session_control( Arc::clone(&self.inner), SessionControlContext { - cell_id: cell_id.clone(), + cell_id: cell_id_text, runtime_tx, runtime_control_tx, pending_mode, runtime_terminate_handle, + cancellation_token, }, event_rx, control_rx, @@ -188,13 +337,9 @@ impl CodeModeService { return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))); }; let (response_tx, response_rx) = oneshot::channel(); - let control_message = if request.terminate { - SessionControlCommand::Terminate { response_tx } - } else { - SessionControlCommand::Poll { - yield_time_ms: request.yield_time_ms, - response_tx, - } + let control_message = SessionControlCommand::Poll { + yield_time_ms: request.yield_time_ms, + response_tx, }; if handle.control_tx.send(control_message).is_err() { return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))); @@ -207,6 +352,25 @@ impl CodeModeService { } } + pub async fn terminate(&self, cell_id: String) -> Result { + let handle = self.inner.sessions.lock().await.get(&cell_id).cloned(); + let Some(handle) = handle else { + return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))); + }; + let (response_tx, response_rx) = oneshot::channel(); + if handle + .control_tx + .send(SessionControlCommand::Terminate { response_tx }) + .is_err() + { + return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))); + } + match response_rx.await { + Ok(response) => Ok(WaitOutcome::LiveCell(response)), + Err(_) => Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))), + } + } + pub async fn wait_to_pending( &self, request: WaitToPendingRequest, @@ -242,69 +406,28 @@ impl CodeModeService { } } - pub fn start_turn_worker(&self, host: Arc) -> CodeModeTurnWorker { - let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); - let inner = Arc::clone(&self.inner); - let turn_message_rx = self.inner.turn_message_rx.clone(); - - tokio::spawn(async move { - loop { - let next_message = tokio::select! { - _ = &mut shutdown_rx => break, - message = turn_message_rx.recv() => message.ok(), - }; - let Some(next_message) = next_message else { - break; - }; - match next_message { - TurnMessage::Notify { - cell_id, - call_id, - text, - } => { - if let Err(err) = host.notify(call_id, cell_id.clone(), text).await { - warn!( - "failed to deliver code mode notification for cell {cell_id}: {err}" - ); - } - } - TurnMessage::ToolCall(invocation) => { - let host = Arc::clone(&host); - let inner = Arc::clone(&inner); - tokio::spawn(async move { - let cell_id = invocation.cell_id.clone(); - let runtime_tool_call_id = invocation.runtime_tool_call_id.clone(); - let response = - host.invoke_tool(invocation, CancellationToken::new()).await; - let runtime_tx = inner - .sessions - .lock() - .await - .get(&cell_id) - .map(|handle| handle.runtime_tx.clone()); - let Some(runtime_tx) = runtime_tx else { - return; - }; - let command = match response { - Ok(result) => RuntimeCommand::ToolResponse { - id: runtime_tool_call_id, - result, - }, - Err(error_text) => RuntimeCommand::ToolError { - id: runtime_tool_call_id, - error_text, - }, - }; - let _ = runtime_tx.send(command); - }); - } - } - } - }); - - CodeModeTurnWorker { - shutdown_tx: Some(shutdown_tx), + pub async fn shutdown(&self) -> Result<(), String> { + self.inner.shutting_down.store(true, Ordering::Release); + let handles = self + .inner + .sessions + .lock() + .await + .values() + .cloned() + .collect::>(); + for handle in handles { + handle.cancellation_token.cancel(); + let (response_tx, _response_rx) = oneshot::channel(); + let _ = handle + .control_tx + .send(SessionControlCommand::Terminate { response_tx }); + let _ = handle.runtime_tx.send(RuntimeCommand::Terminate); } + while !self.inner.sessions.lock().await.is_empty() { + tokio::task::yield_now().await; + } + Ok(()) } } @@ -314,15 +437,53 @@ impl Default for CodeModeService { } } -pub struct CodeModeTurnWorker { - shutdown_tx: Option>, +impl Drop for CodeModeService { + fn drop(&mut self) { + self.inner.shutting_down.store(true, Ordering::Release); + if let Ok(sessions) = self.inner.sessions.try_lock() { + for handle in sessions.values() { + handle.cancellation_token.cancel(); + let (response_tx, _response_rx) = oneshot::channel(); + let _ = handle + .control_tx + .send(SessionControlCommand::Terminate { response_tx }); + let _ = handle.runtime_tx.send(RuntimeCommand::Terminate); + } + } + } } -impl Drop for CodeModeTurnWorker { - fn drop(&mut self) { - if let Some(shutdown_tx) = self.shutdown_tx.take() { - let _ = shutdown_tx.send(()); - } +impl CodeModeSession for CodeModeService { + fn execute<'a>( + &'a self, + request: ExecuteRequest, + ) -> CodeModeSessionResultFuture<'a, StartedCell> { + Box::pin(CodeModeService::execute(self, request)) + } + + fn wait<'a>(&'a self, request: WaitRequest) -> CodeModeSessionResultFuture<'a, WaitOutcome> { + Box::pin(CodeModeService::wait(self, request)) + } + + fn terminate<'a>(&'a self, cell_id: String) -> CodeModeSessionResultFuture<'a, WaitOutcome> { + Box::pin(CodeModeService::terminate(self, cell_id)) + } + + fn stored_values<'a>( + &'a self, + ) -> Pin> + Send + 'a>> { + Box::pin(CodeModeService::stored_values(self)) + } + + fn replace_stored_values<'a>( + &'a self, + values: HashMap, + ) -> Pin + Send + 'a>> { + Box::pin(CodeModeService::replace_stored_values(self, values)) + } + + fn shutdown<'a>(&'a self) -> CodeModeSessionResultFuture<'a, ()> { + Box::pin(CodeModeService::shutdown(self)) } } @@ -355,6 +516,7 @@ struct SessionControlContext { runtime_control_tx: std::sync::mpsc::Sender, pending_mode: PendingRuntimeMode, runtime_terminate_handle: v8::IsolateHandle, + cancellation_token: CancellationToken, } fn missing_cell_response(cell_id: String) -> RuntimeResponse { @@ -437,6 +599,7 @@ async fn run_session_control( runtime_control_tx, pending_mode, runtime_terminate_handle, + cancellation_token, } = context; let mut content_items = Vec::new(); let mut pending_tool_call_ids = Vec::new(); @@ -516,11 +679,15 @@ async fn run_session_control( send_yield_response(&cell_id, &mut content_items, &mut response_tx); } RuntimeEvent::Notify { call_id, text } => { - let _ = inner.turn_message_tx.send(TurnMessage::Notify { - cell_id: cell_id.clone(), - call_id, - text, - }).await; + let delegate = Arc::clone(&inner.delegate); + let cell_id = CellId::from(cell_id.clone()); + tokio::spawn(async move { + if let Err(err) = delegate.notify(call_id, cell_id.clone(), text).await { + warn!( + "failed to deliver code mode notification for cell {cell_id}: {err}" + ); + } + }); } RuntimeEvent::ToolCall { id, @@ -533,15 +700,25 @@ async fn run_session_control( } let tool_call = CodeModeNestedToolCall { cell_id: cell_id.clone(), - runtime_tool_call_id: id, + runtime_tool_call_id: id.clone(), tool_name: name, tool_kind: kind, input, }; - let _ = inner - .turn_message_tx - .send(TurnMessage::ToolCall(tool_call)) - .await; + let delegate = Arc::clone(&inner.delegate); + let runtime_tx = runtime_tx.clone(); + let cancellation_token = cancellation_token.child_token(); + tokio::spawn(async move { + let response = tokio::select! { + response = delegate.invoke_tool(tool_call, cancellation_token.clone()) => response, + _ = cancellation_token.cancelled() => return, + }; + let command = match response { + Ok(result) => RuntimeCommand::ToolResponse { id, result }, + Err(error_text) => RuntimeCommand::ToolError { id, error_text }, + }; + let _ = runtime_tx.send(command); + }); } RuntimeEvent::Result { stored_value_writes, @@ -562,7 +739,7 @@ async fn run_session_control( .stored_values .lock() .await - .extend(stored_value_writes); + .extend(stored_value_writes.clone()); let result = PendingResult { content_items: std::mem::take(&mut content_items), error_text, @@ -617,6 +794,7 @@ async fn run_session_control( response_tx = Some(SessionResponseSender::Runtime(next_response_tx)); termination_requested = true; + cancellation_token.cancel(); yield_timer = None; let _ = runtime_tx.send(RuntimeCommand::Terminate); terminate_paused_runtime(&runtime_control_tx, pending_mode); @@ -650,6 +828,7 @@ async fn run_session_control( } let _ = runtime_tx.send(RuntimeCommand::Terminate); + cancellation_token.cancel(); terminate_paused_runtime(&runtime_control_tx, pending_mode); inner.sessions.lock().await.remove(&cell_id); } @@ -687,6 +866,7 @@ mod tests { use super::CodeModeService; use super::Inner; + use super::NoopCodeModeSessionDelegate; use super::PendingRuntimeMode; use super::RuntimeCommand; use super::RuntimeResponse; @@ -708,7 +888,6 @@ mod tests { fn execute_request(source: &str) -> ExecuteRequest { ExecuteRequest { - cell_id: "1".to_string(), tool_call_id: "call_1".to_string(), enabled_tools: Vec::new(), source: source.to_string(), @@ -717,13 +896,22 @@ mod tests { } } + async fn execute(service: &CodeModeService, request: ExecuteRequest) -> RuntimeResponse { + service + .execute(request) + .await + .unwrap() + .initial_response() + .await + .unwrap() + } + fn test_inner() -> Arc { - let (turn_message_tx, turn_message_rx) = async_channel::unbounded(); Arc::new(Inner { stored_values: Mutex::new(HashMap::new()), sessions: Mutex::new(HashMap::new()), - turn_message_tx, - turn_message_rx, + delegate: Arc::new(NoopCodeModeSessionDelegate), + shutting_down: std::sync::atomic::AtomicBool::new(false), next_cell_id: AtomicU64::new(1), }) } @@ -732,14 +920,15 @@ mod tests { async fn synchronous_exit_returns_successfully() { let service = CodeModeService::new(); - let response = service - .execute(ExecuteRequest { + let response = execute( + &service, + ExecuteRequest { source: r#"text("before"); exit(); text("after");"#.to_string(), yield_time_ms: None, ..execute_request("") - }) - .await - .unwrap(); + }, + ) + .await; assert_eq!( response, @@ -753,6 +942,31 @@ mod tests { ); } + #[tokio::test] + async fn shutdown_interrupts_cpu_bound_cells() { + let service = CodeModeService::new(); + + let cell = service + .execute(ExecuteRequest { + source: "while (true) {}".to_string(), + ..execute_request("") + }) + .await + .unwrap(); + assert_eq!( + cell.initial_response().await.unwrap(), + RuntimeResponse::Yielded { + cell_id: "1".to_string(), + content_items: Vec::new(), + } + ); + + tokio::time::timeout(Duration::from_secs(1), service.shutdown()) + .await + .unwrap() + .unwrap(); + } + #[tokio::test] async fn execute_to_pending_returns_completed_for_synchronous_results() { let service = CodeModeService::new(); @@ -805,14 +1019,7 @@ mod tests { } ); - let termination = service - .wait(WaitRequest { - cell_id: "1".to_string(), - yield_time_ms: 1, - terminate: true, - }) - .await - .unwrap(); + let termination = service.terminate("1".to_string()).await.unwrap(); assert_eq!( termination, @@ -859,14 +1066,7 @@ await Promise.all([ } ); - let termination = service - .wait(WaitRequest { - cell_id: "1".to_string(), - yield_time_ms: 1, - terminate: true, - }) - .await - .unwrap(); + let termination = service.terminate("1".to_string()).await.unwrap(); assert_eq!( termination, @@ -948,14 +1148,7 @@ await Promise.all([ }) ); - let termination = service - .wait(WaitRequest { - cell_id: "1".to_string(), - yield_time_ms: 1, - terminate: true, - }) - .await - .unwrap(); + let termination = service.terminate("1".to_string()).await.unwrap(); assert_eq!( termination, @@ -1027,14 +1220,7 @@ await new Promise(() => {}); }) ); - let termination = service - .wait(WaitRequest { - cell_id: "1".to_string(), - yield_time_ms: 1, - terminate: true, - }) - .await - .unwrap(); + let termination = service.terminate("1".to_string()).await.unwrap(); assert_eq!( termination, @@ -1112,14 +1298,15 @@ text("done"); async fn v8_console_is_not_exposed_on_global_this() { let service = CodeModeService::new(); - let response = service - .execute(ExecuteRequest { + let response = execute( + &service, + ExecuteRequest { source: r#"text(String(Object.hasOwn(globalThis, "console")));"#.to_string(), yield_time_ms: None, ..execute_request("") - }) - .await - .unwrap(); + }, + ) + .await; assert_eq!( response, @@ -1137,8 +1324,9 @@ text("done"); async fn date_locale_string_formats_with_icu_data() { let service = CodeModeService::new(); - let response = service - .execute(ExecuteRequest { + let response = execute( + &service, + ExecuteRequest { source: r#" const value = new Date("2025-01-02T03:04:05Z") .toLocaleString("fr-FR", { @@ -1156,9 +1344,9 @@ text(value); .to_string(), yield_time_ms: None, ..execute_request("") - }) - .await - .unwrap(); + }, + ) + .await; assert_eq!( response, @@ -1176,8 +1364,9 @@ text(value); async fn intl_date_time_format_formats_with_icu_data() { let service = CodeModeService::new(); - let response = service - .execute(ExecuteRequest { + let response = execute( + &service, + ExecuteRequest { source: r#" const formatter = new Intl.DateTimeFormat("fr-FR", { weekday: "long", @@ -1194,9 +1383,9 @@ text(formatter.format(new Date("2025-01-02T03:04:05Z"))); .to_string(), yield_time_ms: None, ..execute_request("") - }) - .await - .unwrap(); + }, + ) + .await; assert_eq!( response, @@ -1214,8 +1403,9 @@ text(formatter.format(new Date("2025-01-02T03:04:05Z"))); async fn output_helpers_return_undefined() { let service = CodeModeService::new(); - let response = service - .execute(ExecuteRequest { + let response = execute( + &service, + ExecuteRequest { source: r#" const returnsUndefined = [ text("first"), @@ -1227,9 +1417,9 @@ text(JSON.stringify(returnsUndefined)); .to_string(), yield_time_ms: None, ..execute_request("") - }) - .await - .unwrap(); + }, + ) + .await; assert_eq!( response, @@ -1256,8 +1446,9 @@ text(JSON.stringify(returnsUndefined)); async fn image_helper_accepts_raw_mcp_image_block_with_original_detail() { let service = CodeModeService::new(); - let response = service - .execute(ExecuteRequest { + let response = execute( + &service, + ExecuteRequest { source: r#" image({ type: "image", @@ -1269,9 +1460,9 @@ image({ .to_string(), yield_time_ms: None, ..execute_request("") - }) - .await - .unwrap(); + }, + ) + .await; assert_eq!( response, @@ -1290,8 +1481,9 @@ image({ async fn image_helper_second_arg_overrides_explicit_object_detail() { let service = CodeModeService::new(); - let response = service - .execute(ExecuteRequest { + let response = execute( + &service, + ExecuteRequest { source: r#" image( { @@ -1304,9 +1496,9 @@ image( .to_string(), yield_time_ms: None, ..execute_request("") - }) - .await - .unwrap(); + }, + ) + .await; assert_eq!( response, @@ -1325,8 +1517,9 @@ image( async fn image_helper_second_arg_overrides_raw_mcp_image_detail() { let service = CodeModeService::new(); - let response = service - .execute(ExecuteRequest { + let response = execute( + &service, + ExecuteRequest { source: r#" image( { @@ -1341,9 +1534,9 @@ image( .to_string(), yield_time_ms: None, ..execute_request("") - }) - .await - .unwrap(); + }, + ) + .await; assert_eq!( response, @@ -1362,8 +1555,9 @@ image( async fn image_helper_rejects_unsupported_detail() { let service = CodeModeService::new(); - let response = service - .execute(ExecuteRequest { + let response = execute( + &service, + ExecuteRequest { source: r#" image({ image_url: "https://example.com/image.jpg", @@ -1373,9 +1567,9 @@ image({ .to_string(), yield_time_ms: None, ..execute_request("") - }) - .await - .unwrap(); + }, + ) + .await; assert_eq!( response, @@ -1391,8 +1585,9 @@ image({ async fn image_helper_rejects_raw_mcp_result_container() { let service = CodeModeService::new(); - let response = service - .execute(ExecuteRequest { + let response = execute( + &service, + ExecuteRequest { source: r#" image({ content: [ @@ -1409,9 +1604,9 @@ image({ .to_string(), yield_time_ms: None, ..execute_request("") - }) - .await - .unwrap(); + }, + ) + .await; assert_eq!( response, @@ -1433,7 +1628,6 @@ image({ .wait(WaitRequest { cell_id: "missing".to_string(), yield_time_ms: 1, - terminate: false, }) .await .unwrap(); @@ -1475,6 +1669,7 @@ image({ runtime_control_tx, pending_mode: PendingRuntimeMode::Continue, runtime_terminate_handle, + cancellation_token: tokio_util::sync::CancellationToken::new(), }, event_rx, control_rx, diff --git a/codex-rs/core/src/session/handlers.rs b/codex-rs/core/src/session/handlers.rs index b1e36b0347..f30e8d0b83 100644 --- a/codex-rs/core/src/session/handlers.rs +++ b/codex-rs/core/src/session/handlers.rs @@ -603,6 +603,9 @@ async fn shutdown_session_runtime(sess: &Arc) { .unified_exec_manager .terminate_all_processes() .await; + if let Err(err) = sess.services.code_mode_service.shutdown().await { + warn!("failed to shutdown code mode session: {err}"); + } let mcp_shutdown = { let mut manager = sess.services.mcp_connection_manager.write().await; manager.begin_shutdown() diff --git a/codex-rs/core/src/session/turn.rs b/codex-rs/core/src/session/turn.rs index 5a2ed71dda..3ae8112125 100644 --- a/codex-rs/core/src/session/turn.rs +++ b/codex-rs/core/src/session/turn.rs @@ -909,16 +909,12 @@ async fn run_sampling_request( Arc::clone(&turn_context), Arc::clone(&turn_diff_tracker), ); - let _code_mode_worker = sess - .services - .code_mode_service - .start_turn_worker( - &sess, - &turn_context, - Arc::clone(&router), - Arc::clone(&turn_diff_tracker), - ) - .await; + let _code_mode_worker = sess.services.code_mode_service.start_turn_worker( + &sess, + &turn_context, + Arc::clone(&router), + Arc::clone(&turn_diff_tracker), + ); let mut retries = 0; let mut initial_input = Some(input); loop { diff --git a/codex-rs/core/src/tools/code_mode/delegate.rs b/codex-rs/core/src/tools/code_mode/delegate.rs new file mode 100644 index 0000000000..030a150b22 --- /dev/null +++ b/codex-rs/core/src/tools/code_mode/delegate.rs @@ -0,0 +1,234 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Mutex; + +use codex_code_mode::CellId; +use codex_code_mode::CodeModeNestedToolCall; +use codex_code_mode::CodeModeSessionDelegate; +use codex_code_mode::NotificationFuture; +use codex_code_mode::ToolInvocationFuture; +use codex_protocol::models::FunctionCallOutputPayload; +use codex_protocol::models::ResponseInputItem; +use serde_json::Value as JsonValue; +use tokio::sync::oneshot; +use tokio::sync::watch; +use tokio_util::sync::CancellationToken; +use tracing::warn; + +use super::ExecContext; +use super::PUBLIC_TOOL_NAME; +use super::call_nested_tool; +use crate::tools::ToolRouter; +use crate::tools::context::SharedTurnDiffTracker; +use crate::tools::parallel::ToolCallRuntime; + +pub(super) struct CodeModeDispatchBroker { + dispatch_tx: async_channel::Sender, + dispatch_rx: async_channel::Receiver, + dispatch_gates: Arc>>>, +} + +impl CodeModeDispatchBroker { + pub(super) fn new() -> Self { + let (dispatch_tx, dispatch_rx) = async_channel::unbounded(); + Self { + dispatch_tx, + dispatch_rx, + dispatch_gates: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub(super) fn mark_cell_ready_for_dispatch(&self, cell_id: &CellId) { + dispatch_gate(&self.dispatch_gates, cell_id.as_str()).send_replace(true); + } + + pub(super) fn start_turn_worker( + &self, + exec: ExecContext, + router: Arc, + tracker: SharedTurnDiffTracker, + ) -> CodeModeDispatchWorker { + let tool_runtime = ToolCallRuntime::new( + router, + Arc::clone(&exec.session), + Arc::clone(&exec.turn), + tracker, + ); + let host = Arc::new(CoreTurnHost { exec, tool_runtime }); + let dispatch_rx = self.dispatch_rx.clone(); + let dispatch_gates = Arc::clone(&self.dispatch_gates); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + tokio::spawn(async move { + loop { + let message = tokio::select! { + _ = &mut shutdown_rx => break, + message = dispatch_rx.recv() => message.ok(), + }; + let Some(message) = message else { + break; + }; + match message { + DispatchMessage::Notify { + call_id, + cell_id, + text, + } => { + wait_until_cell_ready_for_dispatch(&dispatch_gates, &cell_id).await; + if let Err(err) = host.notify(call_id, cell_id.clone(), text).await { + warn!( + "failed to deliver code mode notification for cell {cell_id}: {err}" + ); + } + } + DispatchMessage::InvokeTool { + invocation, + cancellation_token, + response_tx, + } => { + wait_until_cell_ready_for_dispatch(&dispatch_gates, &invocation.cell_id) + .await; + let host = Arc::clone(&host); + tokio::spawn(async move { + let response = host.invoke_tool(invocation, cancellation_token).await; + let _ = response_tx.send(response); + }); + } + } + } + }); + CodeModeDispatchWorker { + shutdown_tx: Some(shutdown_tx), + } + } +} + +fn dispatch_gate( + dispatch_gates: &Mutex>>, + cell_id: &str, +) -> watch::Sender { + let mut dispatch_gates = match dispatch_gates.lock() { + Ok(dispatch_gates) => dispatch_gates, + Err(poisoned) => poisoned.into_inner(), + }; + dispatch_gates + .entry(cell_id.to_string()) + .or_insert_with(|| watch::channel(false).0) + .clone() +} + +async fn wait_until_cell_ready_for_dispatch( + dispatch_gates: &Mutex>>, + cell_id: &str, +) { + let mut ready_rx = dispatch_gate(dispatch_gates, cell_id).subscribe(); + while !*ready_rx.borrow_and_update() && ready_rx.changed().await.is_ok() {} +} + +impl CodeModeSessionDelegate for CodeModeDispatchBroker { + fn invoke_tool<'a>( + &'a self, + invocation: CodeModeNestedToolCall, + cancellation_token: CancellationToken, + ) -> ToolInvocationFuture<'a> { + Box::pin(async move { + let (response_tx, response_rx) = oneshot::channel(); + self.dispatch_tx + .send(DispatchMessage::InvokeTool { + invocation, + cancellation_token: cancellation_token.clone(), + response_tx, + }) + .await + .map_err(|_| "code mode nested tool dispatcher is unavailable".to_string())?; + tokio::select! { + response = response_rx => response + .map_err(|_| "code mode nested tool dispatcher stopped".to_string())?, + _ = cancellation_token.cancelled() => { + Err("code mode nested tool call cancelled".to_string()) + } + } + }) + } + + fn notify<'a>( + &'a self, + call_id: String, + cell_id: CellId, + text: String, + ) -> NotificationFuture<'a> { + Box::pin(async move { + self.dispatch_tx + .send(DispatchMessage::Notify { + call_id, + cell_id: cell_id.to_string(), + text, + }) + .await + .map_err(|_| "code mode notification dispatcher is unavailable".to_string()) + }) + } +} + +enum DispatchMessage { + InvokeTool { + invocation: CodeModeNestedToolCall, + cancellation_token: CancellationToken, + response_tx: oneshot::Sender>, + }, + Notify { + call_id: String, + cell_id: String, + text: String, + }, +} + +pub(crate) struct CodeModeDispatchWorker { + shutdown_tx: Option>, +} + +impl Drop for CodeModeDispatchWorker { + fn drop(&mut self) { + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } + } +} + +struct CoreTurnHost { + exec: ExecContext, + tool_runtime: ToolCallRuntime, +} + +impl CoreTurnHost { + async fn invoke_tool( + &self, + invocation: CodeModeNestedToolCall, + cancellation_token: CancellationToken, + ) -> Result { + call_nested_tool( + self.exec.clone(), + self.tool_runtime.clone(), + invocation, + cancellation_token, + ) + .await + .map_err(|error| error.to_string()) + } + + async fn notify(&self, call_id: String, cell_id: String, text: String) -> Result<(), String> { + if text.trim().is_empty() { + return Ok(()); + } + self.exec + .session + .inject_response_items(vec![ResponseInputItem::CustomToolCallOutput { + call_id, + name: Some(PUBLIC_TOOL_NAME.to_string()), + output: FunctionCallOutputPayload::from_text(text), + }]) + .await + .map_err(|_| { + format!("failed to inject exec notify message for cell {cell_id}: no active turn") + }) + } +} diff --git a/codex-rs/core/src/tools/code_mode/execute_handler.rs b/codex-rs/core/src/tools/code_mode/execute_handler.rs index 29f5623849..ab8a8a5107 100644 --- a/codex-rs/core/src/tools/code_mode/execute_handler.rs +++ b/codex-rs/core/src/tools/code_mode/execute_handler.rs @@ -38,9 +38,21 @@ impl CodeModeExecuteHandler { let exec = ExecContext { session, turn }; let enabled_tools = codex_tools::collect_code_mode_tool_definitions(&self.nested_tool_specs); - // Allocate before starting V8 so the trace can create the parent - // CodeCell before model-authored JavaScript issues nested tool calls. - let runtime_cell_id = exec.session.services.code_mode_service.allocate_cell_id(); + let started_at = std::time::Instant::now(); + let started_cell = exec + .session + .services + .code_mode_service + .execute(codex_code_mode::ExecuteRequest { + tool_call_id: call_id.clone(), + enabled_tools, + source: args.code.clone(), + yield_time_ms: args.yield_time_ms, + max_output_tokens: args.max_output_tokens, + }) + .await + .map_err(FunctionCallError::RespondToModel)?; + let runtime_cell_id = started_cell.cell_id.to_string(); let code_cell_trace = exec .session .services @@ -51,19 +63,12 @@ impl CodeModeExecuteHandler { call_id.as_str(), args.code.as_str(), ); - let started_at = std::time::Instant::now(); - let response = exec - .session + exec.session .services .code_mode_service - .execute(codex_code_mode::ExecuteRequest { - cell_id: runtime_cell_id, - tool_call_id: call_id, - enabled_tools, - source: args.code, - yield_time_ms: args.yield_time_ms, - max_output_tokens: args.max_output_tokens, - }) + .mark_cell_ready_for_dispatch(&started_cell.cell_id); + let response = started_cell + .initial_response() .await .map_err(FunctionCallError::RespondToModel)?; // Record the raw runtime boundary. The model-visible custom-tool output diff --git a/codex-rs/core/src/tools/code_mode/mod.rs b/codex-rs/core/src/tools/code_mode/mod.rs index 09d58ea8f3..228a33f6f9 100644 --- a/codex-rs/core/src/tools/code_mode/mod.rs +++ b/codex-rs/core/src/tools/code_mode/mod.rs @@ -1,3 +1,4 @@ +mod delegate; mod execute_handler; pub(crate) mod execute_spec; mod response_adapter; @@ -8,12 +9,10 @@ use std::sync::Arc; use std::time::Duration; use codex_code_mode::CodeModeNestedToolCall; +use codex_code_mode::CodeModeSession; use codex_code_mode::CodeModeToolKind; -use codex_code_mode::CodeModeTurnHost; use codex_code_mode::RuntimeResponse; use codex_protocol::models::FunctionCallOutputContentItem; -use codex_protocol::models::FunctionCallOutputPayload; -use codex_protocol::models::ResponseInputItem; use serde_json::Value as JsonValue; use tokio_util::sync::CancellationToken; @@ -36,6 +35,8 @@ use codex_utils_output_truncation::TruncationPolicy; use codex_utils_output_truncation::formatted_truncate_text_content_items_with_policy; use codex_utils_output_truncation::truncate_function_output_items_with_policy; +use delegate::CodeModeDispatchBroker; +use delegate::CodeModeDispatchWorker; pub(crate) use execute_handler::CodeModeExecuteHandler; use response_adapter::into_function_call_output_content_items; pub(crate) use wait_handler::CodeModeWaitHandler; @@ -56,42 +57,61 @@ pub(crate) struct ExecContext { } pub(crate) struct CodeModeService { - inner: codex_code_mode::CodeModeService, + session: Option>, + dispatch_broker: Arc, } impl CodeModeService { pub(crate) fn new() -> Self { + let dispatch_broker = Arc::new(CodeModeDispatchBroker::new()); Self { - inner: codex_code_mode::CodeModeService::new(), + session: Some(Arc::new(codex_code_mode::CodeModeService::with_delegate( + dispatch_broker.clone(), + ))), + dispatch_broker, } } - pub(crate) fn allocate_cell_id(&self) -> String { - self.inner.allocate_cell_id() - } - pub(crate) async fn execute( &self, request: codex_code_mode::ExecuteRequest, - ) -> Result { - self.inner.execute(request).await + ) -> Result { + self.session()?.execute(request).await } pub(crate) async fn wait( &self, request: codex_code_mode::WaitRequest, ) -> Result { - self.inner.wait(request).await + self.session()?.wait(request).await } - pub(crate) async fn start_turn_worker( + pub(crate) async fn terminate( + &self, + cell_id: String, + ) -> Result { + self.session()?.terminate(cell_id).await + } + + pub(crate) async fn shutdown(&self) -> Result<(), String> { + match &self.session { + Some(session) => session.shutdown().await, + None => Ok(()), + } + } + + pub(crate) fn mark_cell_ready_for_dispatch(&self, cell_id: &codex_code_mode::CellId) { + self.dispatch_broker.mark_cell_ready_for_dispatch(cell_id); + } + + pub(crate) fn start_turn_worker( &self, session: &Arc, turn: &Arc, router: Arc, tracker: SharedTurnDiffTracker, - ) -> Option { - if !turn.features.enabled(Feature::CodeMode) { + ) -> Option { + if !turn.features.enabled(Feature::CodeMode) || self.session.is_none() { return None; } @@ -99,50 +119,16 @@ impl CodeModeService { session: Arc::clone(session), turn: Arc::clone(turn), }; - let tool_runtime = - ToolCallRuntime::new(router, Arc::clone(session), Arc::clone(turn), tracker); - let host = Arc::new(CoreTurnHost { exec, tool_runtime }); - Some(self.inner.start_turn_worker(host)) - } -} - -struct CoreTurnHost { - exec: ExecContext, - tool_runtime: ToolCallRuntime, -} - -#[async_trait::async_trait] -impl CodeModeTurnHost for CoreTurnHost { - async fn invoke_tool( - &self, - invocation: CodeModeNestedToolCall, - cancellation_token: CancellationToken, - ) -> Result { - call_nested_tool( - self.exec.clone(), - self.tool_runtime.clone(), - invocation, - cancellation_token, + Some( + self.dispatch_broker + .start_turn_worker(exec, router, tracker), ) - .await - .map_err(|error| error.to_string()) } - async fn notify(&self, call_id: String, cell_id: String, text: String) -> Result<(), String> { - if text.trim().is_empty() { - return Ok(()); - } - self.exec - .session - .inject_response_items(vec![ResponseInputItem::CustomToolCallOutput { - call_id, - name: Some(PUBLIC_TOOL_NAME.to_string()), - output: FunctionCallOutputPayload::from_text(text), - }]) - .await - .map_err(|_| { - format!("failed to inject exec notify message for cell {cell_id}: no active turn") - }) + fn session(&self) -> Result<&Arc, String> { + self.session + .as_ref() + .ok_or_else(|| "code mode is unavailable".to_string()) } } diff --git a/codex-rs/core/src/tools/code_mode/wait_handler.rs b/codex-rs/core/src/tools/code_mode/wait_handler.rs index d0c0453df6..d78efca626 100644 --- a/codex-rs/core/src/tools/code_mode/wait_handler.rs +++ b/codex-rs/core/src/tools/code_mode/wait_handler.rs @@ -73,17 +73,23 @@ impl ToolExecutor for CodeModeWaitHandler { let args: ExecWaitArgs = parse_arguments(&arguments)?; let exec = ExecContext { session, turn }; let started_at = std::time::Instant::now(); - let wait_response = exec - .session - .services - .code_mode_service - .wait(codex_code_mode::WaitRequest { - cell_id: args.cell_id, - yield_time_ms: args.yield_time_ms, - terminate: args.terminate, - }) - .await - .map_err(FunctionCallError::RespondToModel)?; + let wait_response = if args.terminate { + exec.session + .services + .code_mode_service + .terminate(args.cell_id) + .await + } else { + exec.session + .services + .code_mode_service + .wait(codex_code_mode::WaitRequest { + cell_id: args.cell_id, + yield_time_ms: args.yield_time_ms, + }) + .await + } + .map_err(FunctionCallError::RespondToModel)?; if let codex_code_mode::WaitOutcome::LiveCell(response) = &wait_response && !matches!(response, codex_code_mode::RuntimeResponse::Yielded { .. }) {