diff --git a/codex-rs/code-mode/src/service.rs b/codex-rs/code-mode/src/service.rs index c58d3ee36e..aa3c3c0e44 100644 --- a/codex-rs/code-mode/src/service.rs +++ b/codex-rs/code-mode/src/service.rs @@ -127,7 +127,11 @@ impl CodeModeSessionDelegate for NoopCodeModeSessionDelegate { } } -/// A durable code-mode session. Implementations may execute cells in-process or remotely. +/// A durable code-mode session owned by one Codex thread. +/// +/// Cells executed in the same session share stored values. Separate sessions +/// must keep those values isolated. Implementations may execute cells +/// in-process or remotely. pub trait CodeModeSession: Send + Sync { fn execute<'a>( &'a self, @@ -141,7 +145,7 @@ pub trait CodeModeSession: Send + Sync { fn shutdown<'a>(&'a self) -> CodeModeSessionResultFuture<'a, ()>; } -/// Creates code-mode sessions for one host. +/// Creates code-mode sessions for one Codex thread. /// /// Providers choose where a session executes and receive the host delegate that /// the session should use for nested tool calls and notifications. @@ -169,15 +173,15 @@ impl CodeModeSessionProvider for InProcessCodeModeSessionProvider { } #[derive(Clone)] -struct SessionHandle { - control_tx: mpsc::UnboundedSender, +struct CellHandle { + control_tx: mpsc::UnboundedSender, runtime_tx: std::sync::mpsc::Sender, cancellation_token: CancellationToken, } struct Inner { stored_values: Mutex>, - sessions: Mutex>, + cells: Mutex>, delegate: Arc, shutting_down: AtomicBool, next_cell_id: AtomicU64, @@ -196,7 +200,7 @@ impl CodeModeService { Self { inner: Arc::new(Inner { stored_values: Mutex::new(HashMap::new()), - sessions: Mutex::new(HashMap::new()), + cells: Mutex::new(HashMap::new()), delegate, shutting_down: AtomicBool::new(false), next_cell_id: AtomicU64::new(1), @@ -220,10 +224,10 @@ impl CodeModeService { let initial_yield_time_ms = request.yield_time_ms.unwrap_or(DEFAULT_EXEC_YIELD_TIME_MS); let (response_tx, response_rx) = oneshot::channel(); let cell_id = self.allocate_cell_id(); - self.start_session( + self.start_cell( cell_id.clone(), request, - SessionResponseSender::Runtime(response_tx), + CellResponseSender::Runtime(response_tx), Some(initial_yield_time_ms), PendingRuntimeMode::Continue, ) @@ -241,10 +245,10 @@ impl CodeModeService { ) -> Result { let (response_tx, response_rx) = oneshot::channel(); let cell_id = self.allocate_cell_id(); - self.start_session( + self.start_cell( cell_id, request, - SessionResponseSender::ExecuteToPending(response_tx), + CellResponseSender::ExecuteToPending(response_tx), /*initial_yield_time_ms*/ None, PendingRuntimeMode::PauseUntilResumed, ) @@ -255,31 +259,30 @@ impl CodeModeService { .map_err(|_| "exec runtime ended unexpectedly".to_string()) } - async fn start_session( + async fn start_cell( &self, cell_id: CellId, request: ExecuteRequest, - initial_response_tx: SessionResponseSender, + initial_response_tx: CellResponseSender, initial_yield_time_ms: Option, pending_mode: PendingRuntimeMode, ) -> Result<(), String> { - let cell_id_text = cell_id.to_string(); let (event_tx, event_rx) = mpsc::unbounded_channel(); let (control_tx, control_rx) = mpsc::unbounded_channel(); let stored_values = self.inner.stored_values.lock().await.clone(); let cancellation_token = CancellationToken::new(); let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = { - let mut sessions = self.inner.sessions.lock().await; - if sessions.contains_key(&cell_id_text) { + let mut cells = self.inner.cells.lock().await; + if cells.contains_key(&cell_id) { return Err(format!("exec cell {cell_id} already exists")); } let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = spawn_runtime(stored_values, request, event_tx, pending_mode)?; - sessions.insert( - cell_id_text.clone(), - SessionHandle { + cells.insert( + cell_id.clone(), + CellHandle { control_tx, runtime_tx: runtime_tx.clone(), cancellation_token: cancellation_token.clone(), @@ -288,10 +291,10 @@ impl CodeModeService { (runtime_tx, runtime_control_tx, runtime_terminate_handle) }; - tokio::spawn(run_session_control( + tokio::spawn(run_cell_control( Arc::clone(&self.inner), - SessionControlContext { - cell_id: cell_id_text, + CellControlContext { + cell_id, runtime_tx, runtime_control_tx, pending_mode, @@ -308,20 +311,18 @@ impl CodeModeService { } pub async fn wait(&self, request: WaitRequest) -> Result { - let cell_id = request.cell_id.clone(); - let handle = self - .inner - .sessions - .lock() - .await - .get(&request.cell_id) - .cloned(); + let WaitRequest { + cell_id, + yield_time_ms, + } = request; + let cell_id = CellId::from(cell_id); + let handle = self.inner.cells.lock().await.get(&cell_id).cloned(); let Some(handle) = handle else { return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))); }; let (response_tx, response_rx) = oneshot::channel(); - let control_message = SessionControlCommand::Poll { - yield_time_ms: request.yield_time_ms, + let control_message = CellControlCommand::Poll { + yield_time_ms, response_tx, }; if handle.control_tx.send(control_message).is_err() { @@ -329,21 +330,20 @@ impl CodeModeService { } match response_rx.await { Ok(response) => Ok(WaitOutcome::LiveCell(response)), - Err(_) => Ok(WaitOutcome::MissingCell(missing_cell_response( - request.cell_id, - ))), + Err(_) => Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))), } } pub async fn terminate(&self, cell_id: String) -> Result { - let handle = self.inner.sessions.lock().await.get(&cell_id).cloned(); + let cell_id = CellId::from(cell_id); + let handle = self.inner.cells.lock().await.get(&cell_id).cloned(); let Some(handle) = handle else { return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))); }; let (response_tx, response_rx) = oneshot::channel(); if handle .control_tx - .send(SessionControlCommand::Terminate { response_tx }) + .send(CellControlCommand::Terminate { response_tx }) .is_err() { return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))); @@ -358,14 +358,8 @@ impl CodeModeService { &self, request: WaitToPendingRequest, ) -> Result { - let cell_id = request.cell_id.clone(); - let handle = self - .inner - .sessions - .lock() - .await - .get(&request.cell_id) - .cloned(); + let cell_id = CellId::from(request.cell_id); + let handle = self.inner.cells.lock().await.get(&cell_id).cloned(); let Some(handle) = handle else { return Ok(WaitToPendingOutcome::MissingCell(missing_cell_response( cell_id, @@ -374,7 +368,7 @@ impl CodeModeService { let (response_tx, response_rx) = oneshot::channel(); if handle .control_tx - .send(SessionControlCommand::PollToPending { response_tx }) + .send(CellControlCommand::PollToPending { response_tx }) .is_err() { return Ok(WaitToPendingOutcome::MissingCell(missing_cell_response( @@ -384,7 +378,7 @@ impl CodeModeService { match response_rx.await { Ok(response) => Ok(WaitToPendingOutcome::LiveCell(response)), Err(_) => Ok(WaitToPendingOutcome::MissingCell(missing_cell_response( - request.cell_id, + cell_id, ))), } } @@ -393,7 +387,7 @@ impl CodeModeService { self.inner.shutting_down.store(true, Ordering::Release); let handles = self .inner - .sessions + .cells .lock() .await .values() @@ -404,10 +398,10 @@ impl CodeModeService { let (response_tx, _response_rx) = oneshot::channel(); let _ = handle .control_tx - .send(SessionControlCommand::Terminate { response_tx }); + .send(CellControlCommand::Terminate { response_tx }); let _ = handle.runtime_tx.send(RuntimeCommand::Terminate); } - while !self.inner.sessions.lock().await.is_empty() { + while !self.inner.cells.lock().await.is_empty() { tokio::task::yield_now().await; } Ok(()) @@ -423,13 +417,13 @@ impl Default for CodeModeService { impl Drop for CodeModeService { fn drop(&mut self) { self.inner.shutting_down.store(true, Ordering::Release); - if let Ok(sessions) = self.inner.sessions.try_lock() { - for handle in sessions.values() { + if let Ok(cells) = self.inner.cells.try_lock() { + for handle in cells.values() { handle.cancellation_token.cancel(); let (response_tx, _response_rx) = oneshot::channel(); let _ = handle .control_tx - .send(SessionControlCommand::Terminate { response_tx }); + .send(CellControlCommand::Terminate { response_tx }); let _ = handle.runtime_tx.send(RuntimeCommand::Terminate); } } @@ -457,7 +451,7 @@ impl CodeModeSession for CodeModeService { } } -enum SessionControlCommand { +enum CellControlCommand { Poll { yield_time_ms: u64, response_tx: oneshot::Sender, @@ -470,7 +464,7 @@ enum SessionControlCommand { }, } -enum SessionResponseSender { +enum CellResponseSender { Runtime(oneshot::Sender), ExecuteToPending(oneshot::Sender), } @@ -480,8 +474,8 @@ struct PendingResult { error_text: Option, } -struct SessionControlContext { - cell_id: String, +struct CellControlContext { + cell_id: CellId, runtime_tx: std::sync::mpsc::Sender, runtime_control_tx: std::sync::mpsc::Sender, pending_mode: PendingRuntimeMode, @@ -489,15 +483,15 @@ struct SessionControlContext { cancellation_token: CancellationToken, } -fn missing_cell_response(cell_id: String) -> RuntimeResponse { +fn missing_cell_response(cell_id: CellId) -> RuntimeResponse { RuntimeResponse::Result { error_text: Some(format!("exec cell {cell_id} not found")), - cell_id, + cell_id: cell_id.to_string(), content_items: Vec::new(), } } -fn pending_result_response(cell_id: &str, result: PendingResult) -> RuntimeResponse { +fn pending_result_response(cell_id: &CellId, result: PendingResult) -> RuntimeResponse { RuntimeResponse::Result { cell_id: cell_id.to_string(), content_items: result.content_items, @@ -505,21 +499,21 @@ fn pending_result_response(cell_id: &str, result: PendingResult) -> RuntimeRespo } } -fn send_terminal_response(response_tx: SessionResponseSender, response: RuntimeResponse) { +fn send_terminal_response(response_tx: CellResponseSender, response: RuntimeResponse) { match response_tx { - SessionResponseSender::Runtime(response_tx) => { + CellResponseSender::Runtime(response_tx) => { let _ = response_tx.send(response); } - SessionResponseSender::ExecuteToPending(response_tx) => { + CellResponseSender::ExecuteToPending(response_tx) => { let _ = response_tx.send(ExecuteToPendingOutcome::Completed(response)); } } } fn send_or_buffer_result( - cell_id: &str, + cell_id: &CellId, result: PendingResult, - response_tx: &mut Option, + response_tx: &mut Option, pending_result: &mut Option, ) -> bool { if let Some(response_tx) = response_tx.take() { @@ -533,37 +527,35 @@ fn send_or_buffer_result( } fn send_yield_response( - cell_id: &str, + cell_id: &CellId, content_items: &mut Vec, - response_tx: &mut Option, + response_tx: &mut Option, ) { let Some(current_response_tx) = response_tx.take() else { return; }; match current_response_tx { - SessionResponseSender::Runtime(response_tx) => { + CellResponseSender::Runtime(response_tx) => { let _ = response_tx.send(RuntimeResponse::Yielded { cell_id: cell_id.to_string(), content_items: std::mem::take(content_items), }); } - SessionResponseSender::ExecuteToPending(execute_to_pending_tx) => { - *response_tx = Some(SessionResponseSender::ExecuteToPending( - execute_to_pending_tx, - )); + CellResponseSender::ExecuteToPending(execute_to_pending_tx) => { + *response_tx = Some(CellResponseSender::ExecuteToPending(execute_to_pending_tx)); } } } -async fn run_session_control( +async fn run_cell_control( inner: Arc, - context: SessionControlContext, + context: CellControlContext, mut event_rx: mpsc::UnboundedReceiver, - mut control_rx: mpsc::UnboundedReceiver, - initial_response_tx: SessionResponseSender, + mut control_rx: mpsc::UnboundedReceiver, + initial_response_tx: CellResponseSender, initial_yield_time_ms: Option, ) { - let SessionControlContext { + let CellControlContext { cell_id, runtime_tx, runtime_control_tx, @@ -593,7 +585,7 @@ async fn run_session_control( if termination_requested { if let Some(response_tx) = response_tx.take() { let response = RuntimeResponse::Terminated { - cell_id: cell_id.clone(), + cell_id: cell_id.to_string(), content_items: std::mem::take(&mut content_items), }; send_terminal_response(response_tx, response); @@ -625,13 +617,13 @@ async fn run_session_control( RuntimeEvent::Pending => { if let Some(current_response_tx) = response_tx.take() { match current_response_tx { - SessionResponseSender::Runtime(runtime_response_tx) => { + CellResponseSender::Runtime(runtime_response_tx) => { response_tx = - Some(SessionResponseSender::Runtime(runtime_response_tx)); + Some(CellResponseSender::Runtime(runtime_response_tx)); } - SessionResponseSender::ExecuteToPending(response_tx) => { + CellResponseSender::ExecuteToPending(response_tx) => { let _ = response_tx.send(ExecuteToPendingOutcome::Pending { - cell_id: cell_id.clone(), + cell_id: cell_id.to_string(), content_items: std::mem::take(&mut content_items), pending_tool_call_ids: std::mem::take( &mut pending_tool_call_ids, @@ -650,7 +642,7 @@ async fn run_session_control( } RuntimeEvent::Notify { call_id, text } => { let delegate = Arc::clone(&inner.delegate); - let cell_id = CellId::from(cell_id.clone()); + let cell_id = cell_id.clone(); tokio::spawn(async move { if let Err(err) = delegate.notify(call_id, cell_id.clone(), text).await { warn!( @@ -669,7 +661,7 @@ async fn run_session_control( pending_tool_call_ids.push(id.clone()); } let tool_call = CodeModeNestedToolCall { - cell_id: cell_id.clone(), + cell_id: cell_id.to_string(), runtime_tool_call_id: id.clone(), tool_name: name, tool_kind: kind, @@ -698,7 +690,7 @@ async fn run_session_control( if termination_requested { if let Some(response_tx) = response_tx.take() { let response = RuntimeResponse::Terminated { - cell_id: cell_id.clone(), + cell_id: cell_id.to_string(), content_items: std::mem::take(&mut content_items), }; send_terminal_response(response_tx, response); @@ -730,7 +722,7 @@ async fn run_session_control( break; }; match command { - SessionControlCommand::Poll { + CellControlCommand::Poll { yield_time_ms, response_tx: next_response_tx, } => { @@ -738,11 +730,11 @@ async fn run_session_control( let _ = next_response_tx.send(pending_result_response(&cell_id, result)); break; } - response_tx = Some(SessionResponseSender::Runtime(next_response_tx)); + response_tx = Some(CellResponseSender::Runtime(next_response_tx)); yield_timer = Some(Box::pin(tokio::time::sleep(Duration::from_millis(yield_time_ms)))); resume_paused_runtime(&runtime_control_tx, pending_mode); } - SessionControlCommand::PollToPending { + CellControlCommand::PollToPending { response_tx: next_response_tx, } => { if let Some(result) = pending_result.take() { @@ -752,17 +744,17 @@ async fn run_session_control( break; } response_tx = - Some(SessionResponseSender::ExecuteToPending(next_response_tx)); + Some(CellResponseSender::ExecuteToPending(next_response_tx)); yield_timer = None; resume_paused_runtime(&runtime_control_tx, pending_mode); } - SessionControlCommand::Terminate { response_tx: next_response_tx } => { + CellControlCommand::Terminate { response_tx: next_response_tx } => { if let Some(result) = pending_result.take() { let _ = next_response_tx.send(pending_result_response(&cell_id, result)); break; } - response_tx = Some(SessionResponseSender::Runtime(next_response_tx)); + response_tx = Some(CellResponseSender::Runtime(next_response_tx)); termination_requested = true; cancellation_token.cancel(); yield_timer = None; @@ -772,7 +764,7 @@ async fn run_session_control( if runtime_closed { if let Some(response_tx) = response_tx.take() { let response = RuntimeResponse::Terminated { - cell_id: cell_id.clone(), + cell_id: cell_id.to_string(), content_items: std::mem::take(&mut content_items), }; send_terminal_response(response_tx, response); @@ -800,7 +792,7 @@ async fn run_session_control( let _ = runtime_tx.send(RuntimeCommand::Terminate); cancellation_token.cancel(); terminate_paused_runtime(&runtime_control_tx, pending_mode); - inner.sessions.lock().await.remove(&cell_id); + inner.cells.lock().await.remove(&cell_id); } fn resume_paused_runtime( @@ -834,20 +826,21 @@ mod tests { use tokio::sync::mpsc; use tokio::sync::oneshot; + use super::CellControlCommand; + use super::CellControlContext; + use super::CellId; + use super::CellResponseSender; use super::CodeModeService; use super::Inner; use super::NoopCodeModeSessionDelegate; use super::PendingRuntimeMode; use super::RuntimeCommand; use super::RuntimeResponse; - use super::SessionControlCommand; - use super::SessionControlContext; - use super::SessionResponseSender; use super::WaitOutcome; use super::WaitRequest; use super::WaitToPendingOutcome; use super::WaitToPendingRequest; - use super::run_session_control; + use super::run_cell_control; use crate::CodeModeToolKind; use crate::FunctionCallOutputContentItem; use crate::ToolDefinition; @@ -879,7 +872,7 @@ mod tests { fn test_inner() -> Arc { Arc::new(Inner { stored_values: Mutex::new(HashMap::new()), - sessions: Mutex::new(HashMap::new()), + cells: Mutex::new(HashMap::new()), delegate: Arc::new(NoopCodeModeSessionDelegate), shutting_down: std::sync::atomic::AtomicBool::new(false), next_cell_id: AtomicU64::new(1), @@ -912,6 +905,70 @@ mod tests { ); } + #[tokio::test] + async fn stored_values_are_shared_between_cells_but_not_sessions() { + let first_session = CodeModeService::new(); + let second_session = CodeModeService::new(); + + let write_response = execute( + &first_session, + ExecuteRequest { + source: r#"store("key", "visible");"#.to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + let same_session = execute( + &first_session, + ExecuteRequest { + source: r#"text(String(load("key")));"#.to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + let other_session = execute( + &second_session, + ExecuteRequest { + source: r#"text(String(load("key")));"#.to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + write_response, + RuntimeResponse::Result { + cell_id: "1".to_string(), + content_items: Vec::new(), + error_text: None, + } + ); + assert_eq!( + same_session, + RuntimeResponse::Result { + cell_id: "2".to_string(), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "visible".to_string(), + }], + error_text: None, + } + ); + assert_eq!( + other_session, + RuntimeResponse::Result { + cell_id: "1".to_string(), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "undefined".to_string(), + }], + error_text: None, + } + ); + } + #[tokio::test] async fn shutdown_interrupts_cpu_bound_cells() { let service = CodeModeService::new(); @@ -1088,10 +1145,10 @@ await Promise.all([ let runtime_tx = service .inner - .sessions + .cells .lock() .await - .get("1") + .get(&CellId::from("1".to_string())) .unwrap() .runtime_tx .clone(); @@ -1158,10 +1215,10 @@ await new Promise(() => {}); let runtime_tx = service .inner - .sessions + .cells .lock() .await - .get("1") + .get(&CellId::from("1".to_string())) .unwrap() .runtime_tx .clone(); @@ -1229,10 +1286,10 @@ text("done"); let runtime_tx = service .inner - .sessions + .cells .lock() .await - .get("1") + .get(&CellId::from("1".to_string())) .unwrap() .runtime_tx .clone(); @@ -1631,10 +1688,10 @@ image({ ) .unwrap(); - tokio::spawn(run_session_control( + tokio::spawn(run_cell_control( inner, - SessionControlContext { - cell_id: "cell-1".to_string(), + CellControlContext { + cell_id: CellId::from("cell-1".to_string()), runtime_tx: runtime_tx.clone(), runtime_control_tx, pending_mode: PendingRuntimeMode::Continue, @@ -1643,7 +1700,7 @@ image({ }, event_rx, control_rx, - SessionResponseSender::Runtime(initial_response_tx), + CellResponseSender::Runtime(initial_response_tx), Some(/*initial_yield_time_ms*/ 60_000), )); @@ -1659,7 +1716,7 @@ image({ let (terminate_response_tx, terminate_response_rx) = oneshot::channel(); control_tx - .send(SessionControlCommand::Terminate { + .send(CellControlCommand::Terminate { response_tx: terminate_response_tx, }) .unwrap();