Normalize the names Provider => Session (per thread) => Cell (per execution)

This commit is contained in:
Channing Conger
2026-05-22 21:53:28 -07:00
parent 64a9aa24f4
commit e190ed8f14

View File

@@ -127,7 +127,11 @@ impl CodeModeSessionDelegate for NoopCodeModeSessionDelegate {
}
}
/// A durable code-mode session. Implementations may execute cells in-process or remotely.
/// A durable code-mode session owned by one Codex thread.
///
/// Cells executed in the same session share stored values. Separate sessions
/// must keep those values isolated. Implementations may execute cells
/// in-process or remotely.
pub trait CodeModeSession: Send + Sync {
fn execute<'a>(
&'a self,
@@ -141,7 +145,7 @@ pub trait CodeModeSession: Send + Sync {
fn shutdown<'a>(&'a self) -> CodeModeSessionResultFuture<'a, ()>;
}
/// Creates code-mode sessions for one host.
/// Creates code-mode sessions for one Codex thread.
///
/// Providers choose where a session executes and receive the host delegate that
/// the session should use for nested tool calls and notifications.
@@ -169,15 +173,15 @@ impl CodeModeSessionProvider for InProcessCodeModeSessionProvider {
}
/// Per-cell handle stored in `Inner`'s cell map; cloned out of the map by
/// callers that need to talk to a running cell.
#[derive(Clone)]
struct SessionHandle {
control_tx: mpsc::UnboundedSender<SessionControlCommand>,
struct CellHandle {
// Sending half of the control channel whose receiver is handed to the
// cell's control task.
control_tx: mpsc::UnboundedSender<CellControlCommand>,
// Std (non-async) channel used to send `RuntimeCommand`s such as
// `Terminate` to the cell's runtime.
runtime_tx: std::sync::mpsc::Sender<RuntimeCommand>,
// Cancelled when the service is dropped or the cell is asked to terminate.
cancellation_token: CancellationToken,
}
struct Inner {
stored_values: Mutex<HashMap<String, JsonValue>>,
sessions: Mutex<HashMap<String, SessionHandle>>,
cells: Mutex<HashMap<CellId, CellHandle>>,
delegate: Arc<dyn CodeModeSessionDelegate>,
shutting_down: AtomicBool,
next_cell_id: AtomicU64,
@@ -196,7 +200,7 @@ impl CodeModeService {
Self {
inner: Arc::new(Inner {
stored_values: Mutex::new(HashMap::new()),
sessions: Mutex::new(HashMap::new()),
cells: Mutex::new(HashMap::new()),
delegate,
shutting_down: AtomicBool::new(false),
next_cell_id: AtomicU64::new(1),
@@ -220,10 +224,10 @@ impl CodeModeService {
let initial_yield_time_ms = request.yield_time_ms.unwrap_or(DEFAULT_EXEC_YIELD_TIME_MS);
let (response_tx, response_rx) = oneshot::channel();
let cell_id = self.allocate_cell_id();
self.start_session(
self.start_cell(
cell_id.clone(),
request,
SessionResponseSender::Runtime(response_tx),
CellResponseSender::Runtime(response_tx),
Some(initial_yield_time_ms),
PendingRuntimeMode::Continue,
)
@@ -241,10 +245,10 @@ impl CodeModeService {
) -> Result<ExecuteToPendingOutcome, String> {
let (response_tx, response_rx) = oneshot::channel();
let cell_id = self.allocate_cell_id();
self.start_session(
self.start_cell(
cell_id,
request,
SessionResponseSender::ExecuteToPending(response_tx),
CellResponseSender::ExecuteToPending(response_tx),
/*initial_yield_time_ms*/ None,
PendingRuntimeMode::PauseUntilResumed,
)
@@ -255,31 +259,30 @@ impl CodeModeService {
.map_err(|_| "exec runtime ended unexpectedly".to_string())
}
async fn start_session(
async fn start_cell(
&self,
cell_id: CellId,
request: ExecuteRequest,
initial_response_tx: SessionResponseSender,
initial_response_tx: CellResponseSender,
initial_yield_time_ms: Option<u64>,
pending_mode: PendingRuntimeMode,
) -> Result<(), String> {
let cell_id_text = cell_id.to_string();
let (event_tx, event_rx) = mpsc::unbounded_channel();
let (control_tx, control_rx) = mpsc::unbounded_channel();
let stored_values = self.inner.stored_values.lock().await.clone();
let cancellation_token = CancellationToken::new();
let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = {
let mut sessions = self.inner.sessions.lock().await;
if sessions.contains_key(&cell_id_text) {
let mut cells = self.inner.cells.lock().await;
if cells.contains_key(&cell_id) {
return Err(format!("exec cell {cell_id} already exists"));
}
let (runtime_tx, runtime_control_tx, runtime_terminate_handle) =
spawn_runtime(stored_values, request, event_tx, pending_mode)?;
sessions.insert(
cell_id_text.clone(),
SessionHandle {
cells.insert(
cell_id.clone(),
CellHandle {
control_tx,
runtime_tx: runtime_tx.clone(),
cancellation_token: cancellation_token.clone(),
@@ -288,10 +291,10 @@ impl CodeModeService {
(runtime_tx, runtime_control_tx, runtime_terminate_handle)
};
tokio::spawn(run_session_control(
tokio::spawn(run_cell_control(
Arc::clone(&self.inner),
SessionControlContext {
cell_id: cell_id_text,
CellControlContext {
cell_id,
runtime_tx,
runtime_control_tx,
pending_mode,
@@ -308,20 +311,18 @@ impl CodeModeService {
}
pub async fn wait(&self, request: WaitRequest) -> Result<WaitOutcome, String> {
let cell_id = request.cell_id.clone();
let handle = self
.inner
.sessions
.lock()
.await
.get(&request.cell_id)
.cloned();
let WaitRequest {
cell_id,
yield_time_ms,
} = request;
let cell_id = CellId::from(cell_id);
let handle = self.inner.cells.lock().await.get(&cell_id).cloned();
let Some(handle) = handle else {
return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id)));
};
let (response_tx, response_rx) = oneshot::channel();
let control_message = SessionControlCommand::Poll {
yield_time_ms: request.yield_time_ms,
let control_message = CellControlCommand::Poll {
yield_time_ms,
response_tx,
};
if handle.control_tx.send(control_message).is_err() {
@@ -329,21 +330,20 @@ impl CodeModeService {
}
match response_rx.await {
Ok(response) => Ok(WaitOutcome::LiveCell(response)),
Err(_) => Ok(WaitOutcome::MissingCell(missing_cell_response(
request.cell_id,
))),
Err(_) => Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))),
}
}
pub async fn terminate(&self, cell_id: String) -> Result<WaitOutcome, String> {
let handle = self.inner.sessions.lock().await.get(&cell_id).cloned();
let cell_id = CellId::from(cell_id);
let handle = self.inner.cells.lock().await.get(&cell_id).cloned();
let Some(handle) = handle else {
return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id)));
};
let (response_tx, response_rx) = oneshot::channel();
if handle
.control_tx
.send(SessionControlCommand::Terminate { response_tx })
.send(CellControlCommand::Terminate { response_tx })
.is_err()
{
return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id)));
@@ -358,14 +358,8 @@ impl CodeModeService {
&self,
request: WaitToPendingRequest,
) -> Result<WaitToPendingOutcome, String> {
let cell_id = request.cell_id.clone();
let handle = self
.inner
.sessions
.lock()
.await
.get(&request.cell_id)
.cloned();
let cell_id = CellId::from(request.cell_id);
let handle = self.inner.cells.lock().await.get(&cell_id).cloned();
let Some(handle) = handle else {
return Ok(WaitToPendingOutcome::MissingCell(missing_cell_response(
cell_id,
@@ -374,7 +368,7 @@ impl CodeModeService {
let (response_tx, response_rx) = oneshot::channel();
if handle
.control_tx
.send(SessionControlCommand::PollToPending { response_tx })
.send(CellControlCommand::PollToPending { response_tx })
.is_err()
{
return Ok(WaitToPendingOutcome::MissingCell(missing_cell_response(
@@ -384,7 +378,7 @@ impl CodeModeService {
match response_rx.await {
Ok(response) => Ok(WaitToPendingOutcome::LiveCell(response)),
Err(_) => Ok(WaitToPendingOutcome::MissingCell(missing_cell_response(
request.cell_id,
cell_id,
))),
}
}
@@ -393,7 +387,7 @@ impl CodeModeService {
self.inner.shutting_down.store(true, Ordering::Release);
let handles = self
.inner
.sessions
.cells
.lock()
.await
.values()
@@ -404,10 +398,10 @@ impl CodeModeService {
let (response_tx, _response_rx) = oneshot::channel();
let _ = handle
.control_tx
.send(SessionControlCommand::Terminate { response_tx });
.send(CellControlCommand::Terminate { response_tx });
let _ = handle.runtime_tx.send(RuntimeCommand::Terminate);
}
while !self.inner.sessions.lock().await.is_empty() {
while !self.inner.cells.lock().await.is_empty() {
tokio::task::yield_now().await;
}
Ok(())
@@ -423,13 +417,13 @@ impl Default for CodeModeService {
impl Drop for CodeModeService {
fn drop(&mut self) {
self.inner.shutting_down.store(true, Ordering::Release);
if let Ok(sessions) = self.inner.sessions.try_lock() {
for handle in sessions.values() {
if let Ok(cells) = self.inner.cells.try_lock() {
for handle in cells.values() {
handle.cancellation_token.cancel();
let (response_tx, _response_rx) = oneshot::channel();
let _ = handle
.control_tx
.send(SessionControlCommand::Terminate { response_tx });
.send(CellControlCommand::Terminate { response_tx });
let _ = handle.runtime_tx.send(RuntimeCommand::Terminate);
}
}
@@ -457,7 +451,7 @@ impl CodeModeSession for CodeModeService {
}
}
enum SessionControlCommand {
enum CellControlCommand {
Poll {
yield_time_ms: u64,
response_tx: oneshot::Sender<RuntimeResponse>,
@@ -470,7 +464,7 @@ enum SessionControlCommand {
},
}
enum SessionResponseSender {
enum CellResponseSender {
Runtime(oneshot::Sender<RuntimeResponse>),
ExecuteToPending(oneshot::Sender<ExecuteToPendingOutcome>),
}
@@ -480,8 +474,8 @@ struct PendingResult {
error_text: Option<String>,
}
struct SessionControlContext {
cell_id: String,
struct CellControlContext {
cell_id: CellId,
runtime_tx: std::sync::mpsc::Sender<RuntimeCommand>,
runtime_control_tx: std::sync::mpsc::Sender<RuntimeControlCommand>,
pending_mode: PendingRuntimeMode,
@@ -489,15 +483,15 @@ struct SessionControlContext {
cancellation_token: CancellationToken,
}
/// Builds the "not found" `RuntimeResponse::Result` for `cell_id`.
///
/// The response carries the error text `exec cell {cell_id} not found` and an
/// empty list of content items.
fn missing_cell_response(cell_id: String) -> RuntimeResponse {
fn missing_cell_response(cell_id: CellId) -> RuntimeResponse {
RuntimeResponse::Result {
error_text: Some(format!("exec cell {cell_id} not found")),
cell_id,
cell_id: cell_id.to_string(),
content_items: Vec::new(),
}
}
fn pending_result_response(cell_id: &str, result: PendingResult) -> RuntimeResponse {
fn pending_result_response(cell_id: &CellId, result: PendingResult) -> RuntimeResponse {
RuntimeResponse::Result {
cell_id: cell_id.to_string(),
content_items: result.content_items,
@@ -505,21 +499,21 @@ fn pending_result_response(cell_id: &str, result: PendingResult) -> RuntimeRespo
}
}
/// Delivers a final `response` through whichever kind of sender is waiting.
///
/// A `Runtime` sender receives the response as-is; an `ExecuteToPending`
/// sender receives it wrapped in `ExecuteToPendingOutcome::Completed`.
/// A failed send is deliberately ignored.
fn send_terminal_response(response_tx: SessionResponseSender, response: RuntimeResponse) {
fn send_terminal_response(response_tx: CellResponseSender, response: RuntimeResponse) {
match response_tx {
SessionResponseSender::Runtime(response_tx) => {
CellResponseSender::Runtime(response_tx) => {
let _ = response_tx.send(response);
}
SessionResponseSender::ExecuteToPending(response_tx) => {
CellResponseSender::ExecuteToPending(response_tx) => {
let _ = response_tx.send(ExecuteToPendingOutcome::Completed(response));
}
}
}
fn send_or_buffer_result(
cell_id: &str,
cell_id: &CellId,
result: PendingResult,
response_tx: &mut Option<SessionResponseSender>,
response_tx: &mut Option<CellResponseSender>,
pending_result: &mut Option<PendingResult>,
) -> bool {
if let Some(response_tx) = response_tx.take() {
@@ -533,37 +527,35 @@ fn send_or_buffer_result(
}
/// Answers a waiting `Runtime` sender with a `RuntimeResponse::Yielded`.
///
/// Does nothing when no sender is stored in `response_tx`. For a `Runtime`
/// sender the accumulated `content_items` are moved into the response, leaving
/// the vector empty. An `ExecuteToPending` sender is not answered here.
fn send_yield_response(
cell_id: &str,
cell_id: &CellId,
content_items: &mut Vec<FunctionCallOutputContentItem>,
response_tx: &mut Option<SessionResponseSender>,
response_tx: &mut Option<CellResponseSender>,
) {
// Take the sender out of the slot; it is restored below only for the
// `ExecuteToPending` variant.
let Some(current_response_tx) = response_tx.take() else {
return;
};
match current_response_tx {
SessionResponseSender::Runtime(response_tx) => {
CellResponseSender::Runtime(response_tx) => {
// A failed send is ignored.
let _ = response_tx.send(RuntimeResponse::Yielded {
cell_id: cell_id.to_string(),
content_items: std::mem::take(content_items),
});
}
SessionResponseSender::ExecuteToPending(execute_to_pending_tx) => {
*response_tx = Some(SessionResponseSender::ExecuteToPending(
execute_to_pending_tx,
));
CellResponseSender::ExecuteToPending(execute_to_pending_tx) => {
// Put the sender back untouched so it stays available after this call.
*response_tx = Some(CellResponseSender::ExecuteToPending(execute_to_pending_tx));
}
}
}
async fn run_session_control(
async fn run_cell_control(
inner: Arc<Inner>,
context: SessionControlContext,
context: CellControlContext,
mut event_rx: mpsc::UnboundedReceiver<RuntimeEvent>,
mut control_rx: mpsc::UnboundedReceiver<SessionControlCommand>,
initial_response_tx: SessionResponseSender,
mut control_rx: mpsc::UnboundedReceiver<CellControlCommand>,
initial_response_tx: CellResponseSender,
initial_yield_time_ms: Option<u64>,
) {
let SessionControlContext {
let CellControlContext {
cell_id,
runtime_tx,
runtime_control_tx,
@@ -593,7 +585,7 @@ async fn run_session_control(
if termination_requested {
if let Some(response_tx) = response_tx.take() {
let response = RuntimeResponse::Terminated {
cell_id: cell_id.clone(),
cell_id: cell_id.to_string(),
content_items: std::mem::take(&mut content_items),
};
send_terminal_response(response_tx, response);
@@ -625,13 +617,13 @@ async fn run_session_control(
RuntimeEvent::Pending => {
if let Some(current_response_tx) = response_tx.take() {
match current_response_tx {
SessionResponseSender::Runtime(runtime_response_tx) => {
CellResponseSender::Runtime(runtime_response_tx) => {
response_tx =
Some(SessionResponseSender::Runtime(runtime_response_tx));
Some(CellResponseSender::Runtime(runtime_response_tx));
}
SessionResponseSender::ExecuteToPending(response_tx) => {
CellResponseSender::ExecuteToPending(response_tx) => {
let _ = response_tx.send(ExecuteToPendingOutcome::Pending {
cell_id: cell_id.clone(),
cell_id: cell_id.to_string(),
content_items: std::mem::take(&mut content_items),
pending_tool_call_ids: std::mem::take(
&mut pending_tool_call_ids,
@@ -650,7 +642,7 @@ async fn run_session_control(
}
RuntimeEvent::Notify { call_id, text } => {
let delegate = Arc::clone(&inner.delegate);
let cell_id = CellId::from(cell_id.clone());
let cell_id = cell_id.clone();
tokio::spawn(async move {
if let Err(err) = delegate.notify(call_id, cell_id.clone(), text).await {
warn!(
@@ -669,7 +661,7 @@ async fn run_session_control(
pending_tool_call_ids.push(id.clone());
}
let tool_call = CodeModeNestedToolCall {
cell_id: cell_id.clone(),
cell_id: cell_id.to_string(),
runtime_tool_call_id: id.clone(),
tool_name: name,
tool_kind: kind,
@@ -698,7 +690,7 @@ async fn run_session_control(
if termination_requested {
if let Some(response_tx) = response_tx.take() {
let response = RuntimeResponse::Terminated {
cell_id: cell_id.clone(),
cell_id: cell_id.to_string(),
content_items: std::mem::take(&mut content_items),
};
send_terminal_response(response_tx, response);
@@ -730,7 +722,7 @@ async fn run_session_control(
break;
};
match command {
SessionControlCommand::Poll {
CellControlCommand::Poll {
yield_time_ms,
response_tx: next_response_tx,
} => {
@@ -738,11 +730,11 @@ async fn run_session_control(
let _ = next_response_tx.send(pending_result_response(&cell_id, result));
break;
}
response_tx = Some(SessionResponseSender::Runtime(next_response_tx));
response_tx = Some(CellResponseSender::Runtime(next_response_tx));
yield_timer = Some(Box::pin(tokio::time::sleep(Duration::from_millis(yield_time_ms))));
resume_paused_runtime(&runtime_control_tx, pending_mode);
}
SessionControlCommand::PollToPending {
CellControlCommand::PollToPending {
response_tx: next_response_tx,
} => {
if let Some(result) = pending_result.take() {
@@ -752,17 +744,17 @@ async fn run_session_control(
break;
}
response_tx =
Some(SessionResponseSender::ExecuteToPending(next_response_tx));
Some(CellResponseSender::ExecuteToPending(next_response_tx));
yield_timer = None;
resume_paused_runtime(&runtime_control_tx, pending_mode);
}
SessionControlCommand::Terminate { response_tx: next_response_tx } => {
CellControlCommand::Terminate { response_tx: next_response_tx } => {
if let Some(result) = pending_result.take() {
let _ = next_response_tx.send(pending_result_response(&cell_id, result));
break;
}
response_tx = Some(SessionResponseSender::Runtime(next_response_tx));
response_tx = Some(CellResponseSender::Runtime(next_response_tx));
termination_requested = true;
cancellation_token.cancel();
yield_timer = None;
@@ -772,7 +764,7 @@ async fn run_session_control(
if runtime_closed {
if let Some(response_tx) = response_tx.take() {
let response = RuntimeResponse::Terminated {
cell_id: cell_id.clone(),
cell_id: cell_id.to_string(),
content_items: std::mem::take(&mut content_items),
};
send_terminal_response(response_tx, response);
@@ -800,7 +792,7 @@ async fn run_session_control(
let _ = runtime_tx.send(RuntimeCommand::Terminate);
cancellation_token.cancel();
terminate_paused_runtime(&runtime_control_tx, pending_mode);
inner.sessions.lock().await.remove(&cell_id);
inner.cells.lock().await.remove(&cell_id);
}
fn resume_paused_runtime(
@@ -834,20 +826,21 @@ mod tests {
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use super::CellControlCommand;
use super::CellControlContext;
use super::CellId;
use super::CellResponseSender;
use super::CodeModeService;
use super::Inner;
use super::NoopCodeModeSessionDelegate;
use super::PendingRuntimeMode;
use super::RuntimeCommand;
use super::RuntimeResponse;
use super::SessionControlCommand;
use super::SessionControlContext;
use super::SessionResponseSender;
use super::WaitOutcome;
use super::WaitRequest;
use super::WaitToPendingOutcome;
use super::WaitToPendingRequest;
use super::run_session_control;
use super::run_cell_control;
use crate::CodeModeToolKind;
use crate::FunctionCallOutputContentItem;
use crate::ToolDefinition;
@@ -879,7 +872,7 @@ mod tests {
fn test_inner() -> Arc<Inner> {
Arc::new(Inner {
stored_values: Mutex::new(HashMap::new()),
sessions: Mutex::new(HashMap::new()),
cells: Mutex::new(HashMap::new()),
delegate: Arc::new(NoopCodeModeSessionDelegate),
shutting_down: std::sync::atomic::AtomicBool::new(false),
next_cell_id: AtomicU64::new(1),
@@ -912,6 +905,70 @@ mod tests {
);
}
#[tokio::test]
async fn stored_values_are_shared_between_cells_but_not_sessions() {
let first_session = CodeModeService::new();
let second_session = CodeModeService::new();
let write_response = execute(
&first_session,
ExecuteRequest {
source: r#"store("key", "visible");"#.to_string(),
yield_time_ms: None,
..execute_request("")
},
)
.await;
let same_session = execute(
&first_session,
ExecuteRequest {
source: r#"text(String(load("key")));"#.to_string(),
yield_time_ms: None,
..execute_request("")
},
)
.await;
let other_session = execute(
&second_session,
ExecuteRequest {
source: r#"text(String(load("key")));"#.to_string(),
yield_time_ms: None,
..execute_request("")
},
)
.await;
assert_eq!(
write_response,
RuntimeResponse::Result {
cell_id: "1".to_string(),
content_items: Vec::new(),
error_text: None,
}
);
assert_eq!(
same_session,
RuntimeResponse::Result {
cell_id: "2".to_string(),
content_items: vec![FunctionCallOutputContentItem::InputText {
text: "visible".to_string(),
}],
error_text: None,
}
);
assert_eq!(
other_session,
RuntimeResponse::Result {
cell_id: "1".to_string(),
content_items: vec![FunctionCallOutputContentItem::InputText {
text: "undefined".to_string(),
}],
error_text: None,
}
);
}
#[tokio::test]
async fn shutdown_interrupts_cpu_bound_cells() {
let service = CodeModeService::new();
@@ -1088,10 +1145,10 @@ await Promise.all([
let runtime_tx = service
.inner
.sessions
.cells
.lock()
.await
.get("1")
.get(&CellId::from("1".to_string()))
.unwrap()
.runtime_tx
.clone();
@@ -1158,10 +1215,10 @@ await new Promise(() => {});
let runtime_tx = service
.inner
.sessions
.cells
.lock()
.await
.get("1")
.get(&CellId::from("1".to_string()))
.unwrap()
.runtime_tx
.clone();
@@ -1229,10 +1286,10 @@ text("done");
let runtime_tx = service
.inner
.sessions
.cells
.lock()
.await
.get("1")
.get(&CellId::from("1".to_string()))
.unwrap()
.runtime_tx
.clone();
@@ -1631,10 +1688,10 @@ image({
)
.unwrap();
tokio::spawn(run_session_control(
tokio::spawn(run_cell_control(
inner,
SessionControlContext {
cell_id: "cell-1".to_string(),
CellControlContext {
cell_id: CellId::from("cell-1".to_string()),
runtime_tx: runtime_tx.clone(),
runtime_control_tx,
pending_mode: PendingRuntimeMode::Continue,
@@ -1643,7 +1700,7 @@ image({
},
event_rx,
control_rx,
SessionResponseSender::Runtime(initial_response_tx),
CellResponseSender::Runtime(initial_response_tx),
Some(/*initial_yield_time_ms*/ 60_000),
));
@@ -1659,7 +1716,7 @@ image({
let (terminate_response_tx, terminate_response_rx) = oneshot::channel();
control_tx
.send(SessionControlCommand::Terminate {
.send(CellControlCommand::Terminate {
response_tx: terminate_response_tx,
})
.unwrap();