mirror of
https://github.com/openai/codex.git
synced 2026-05-23 12:34:25 +00:00
code-mode: introduce durable session interface
This commit is contained in:
@@ -29,9 +29,18 @@ pub use runtime::WaitOutcome;
|
||||
pub use runtime::WaitRequest;
|
||||
pub use runtime::WaitToPendingOutcome;
|
||||
pub use runtime::WaitToPendingRequest;
|
||||
pub use service::CellId;
|
||||
pub use service::CodeModeService;
|
||||
pub use service::CodeModeTurnHost;
|
||||
pub use service::CodeModeTurnWorker;
|
||||
pub use service::CodeModeSession;
|
||||
pub use service::CodeModeSessionDelegate;
|
||||
pub use service::CodeModeSessionProvider;
|
||||
pub use service::CodeModeSessionProviderFuture;
|
||||
pub use service::CodeModeSessionResultFuture;
|
||||
pub use service::InProcessCodeModeSessionProvider;
|
||||
pub use service::NoopCodeModeSessionDelegate;
|
||||
pub use service::NotificationFuture;
|
||||
pub use service::StartedCell;
|
||||
pub use service::ToolInvocationFuture;
|
||||
|
||||
pub const PUBLIC_TOOL_NAME: &str = "exec";
|
||||
pub const WAIT_TOOL_NAME: &str = "wait";
|
||||
|
||||
@@ -27,11 +27,6 @@ const EXIT_SENTINEL: &str = "__codex_code_mode_exit__";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ExecuteRequest {
|
||||
/// Runtime cell id for this execution.
|
||||
///
|
||||
/// Callers allocate this before execution so tracing, waits, and nested tool
|
||||
/// calls can refer to the cell as soon as JavaScript starts.
|
||||
pub cell_id: String,
|
||||
pub tool_call_id: String,
|
||||
pub enabled_tools: Vec<ToolDefinition>,
|
||||
pub source: String,
|
||||
@@ -43,7 +38,6 @@ pub struct ExecuteRequest {
|
||||
pub struct WaitRequest {
|
||||
pub cell_id: String,
|
||||
pub yield_time_ms: u64,
|
||||
pub terminate: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -133,16 +127,6 @@ pub struct CodeModeNestedToolCall {
|
||||
pub input: Option<JsonValue>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum TurnMessage {
|
||||
ToolCall(CodeModeNestedToolCall),
|
||||
Notify {
|
||||
cell_id: String,
|
||||
call_id: String,
|
||||
text: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum RuntimeCommand {
|
||||
ToolResponse { id: String, result: JsonValue },
|
||||
@@ -460,7 +444,6 @@ mod tests {
|
||||
|
||||
fn execute_request(source: &str) -> ExecuteRequest {
|
||||
ExecuteRequest {
|
||||
cell_id: "1".to_string(),
|
||||
tool_call_id: "call_1".to_string(),
|
||||
enabled_tools: Vec::new(),
|
||||
source: source.to_string(),
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value as JsonValue;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -22,35 +27,168 @@ use crate::runtime::RuntimeCommand;
|
||||
use crate::runtime::RuntimeControlCommand;
|
||||
use crate::runtime::RuntimeEvent;
|
||||
use crate::runtime::RuntimeResponse;
|
||||
use crate::runtime::TurnMessage;
|
||||
use crate::runtime::WaitOutcome;
|
||||
use crate::runtime::WaitRequest;
|
||||
use crate::runtime::WaitToPendingOutcome;
|
||||
use crate::runtime::WaitToPendingRequest;
|
||||
use crate::runtime::spawn_runtime;
|
||||
|
||||
#[async_trait]
|
||||
pub trait CodeModeTurnHost: Send + Sync {
|
||||
async fn invoke_tool(
|
||||
&self,
|
||||
pub type CodeModeSessionResultFuture<'a, T> =
|
||||
Pin<Box<dyn Future<Output = Result<T, String>> + Send + 'a>>;
|
||||
pub type CodeModeSessionProviderFuture<'a> =
|
||||
CodeModeSessionResultFuture<'a, Arc<dyn CodeModeSession>>;
|
||||
pub type ToolInvocationFuture<'a> =
|
||||
Pin<Box<dyn Future<Output = Result<JsonValue, String>> + Send + 'a>>;
|
||||
pub type NotificationFuture<'a> = Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'a>>;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
|
||||
pub struct CellId(String);
|
||||
|
||||
impl CellId {
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<str> for CellId {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for CellId {
|
||||
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
formatter.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for CellId {
|
||||
fn from(value: String) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CellId> for String {
|
||||
fn from(value: CellId) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StartedCell {
|
||||
pub cell_id: CellId,
|
||||
initial_response_rx: oneshot::Receiver<RuntimeResponse>,
|
||||
}
|
||||
|
||||
impl StartedCell {
|
||||
pub async fn initial_response(self) -> Result<RuntimeResponse, String> {
|
||||
self.initial_response_rx
|
||||
.await
|
||||
.map_err(|_| "exec runtime ended unexpectedly".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Host callbacks used by a code-mode session while cells are executing.
|
||||
pub trait CodeModeSessionDelegate: Send + Sync {
|
||||
fn invoke_tool<'a>(
|
||||
&'a self,
|
||||
invocation: CodeModeNestedToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Result<JsonValue, String>;
|
||||
) -> ToolInvocationFuture<'a>;
|
||||
|
||||
async fn notify(&self, call_id: String, cell_id: String, text: String) -> Result<(), String>;
|
||||
fn notify<'a>(
|
||||
&'a self,
|
||||
call_id: String,
|
||||
cell_id: CellId,
|
||||
text: String,
|
||||
) -> NotificationFuture<'a>;
|
||||
}
|
||||
|
||||
pub struct NoopCodeModeSessionDelegate;
|
||||
|
||||
impl CodeModeSessionDelegate for NoopCodeModeSessionDelegate {
|
||||
fn invoke_tool<'a>(
|
||||
&'a self,
|
||||
_invocation: CodeModeNestedToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> ToolInvocationFuture<'a> {
|
||||
Box::pin(async move {
|
||||
cancellation_token.cancelled().await;
|
||||
Err("code mode nested tools are unavailable".to_string())
|
||||
})
|
||||
}
|
||||
|
||||
fn notify<'a>(
|
||||
&'a self,
|
||||
_call_id: String,
|
||||
_cell_id: CellId,
|
||||
_text: String,
|
||||
) -> NotificationFuture<'a> {
|
||||
Box::pin(async { Ok(()) })
|
||||
}
|
||||
}
|
||||
|
||||
/// A durable code-mode session. Implementations may execute cells in-process or remotely.
|
||||
pub trait CodeModeSession: Send + Sync {
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
request: ExecuteRequest,
|
||||
) -> CodeModeSessionResultFuture<'a, StartedCell>;
|
||||
|
||||
fn wait<'a>(&'a self, request: WaitRequest) -> CodeModeSessionResultFuture<'a, WaitOutcome>;
|
||||
|
||||
fn terminate<'a>(&'a self, cell_id: String) -> CodeModeSessionResultFuture<'a, WaitOutcome>;
|
||||
|
||||
fn stored_values<'a>(
|
||||
&'a self,
|
||||
) -> Pin<Box<dyn Future<Output = HashMap<String, JsonValue>> + Send + 'a>>;
|
||||
|
||||
fn replace_stored_values<'a>(
|
||||
&'a self,
|
||||
values: HashMap<String, JsonValue>,
|
||||
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
|
||||
|
||||
fn shutdown<'a>(&'a self) -> CodeModeSessionResultFuture<'a, ()>;
|
||||
}
|
||||
|
||||
/// Creates code-mode sessions for one host.
|
||||
///
|
||||
/// Providers choose where a session executes and receive the host delegate that
|
||||
/// the session should use for nested tool calls and notifications.
|
||||
pub trait CodeModeSessionProvider: Send + Sync {
|
||||
fn create_session<'a>(
|
||||
&'a self,
|
||||
delegate: Arc<dyn CodeModeSessionDelegate>,
|
||||
) -> CodeModeSessionProviderFuture<'a>;
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct InProcessCodeModeSessionProvider;
|
||||
|
||||
impl CodeModeSessionProvider for InProcessCodeModeSessionProvider {
|
||||
fn create_session<'a>(
|
||||
&'a self,
|
||||
delegate: Arc<dyn CodeModeSessionDelegate>,
|
||||
) -> CodeModeSessionProviderFuture<'a> {
|
||||
Box::pin(async move {
|
||||
let session: Arc<dyn CodeModeSession> =
|
||||
Arc::new(CodeModeService::with_delegate(delegate));
|
||||
Ok(session)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SessionHandle {
|
||||
control_tx: mpsc::UnboundedSender<SessionControlCommand>,
|
||||
runtime_tx: std::sync::mpsc::Sender<RuntimeCommand>,
|
||||
cancellation_token: CancellationToken,
|
||||
}
|
||||
|
||||
struct Inner {
|
||||
stored_values: Mutex<HashMap<String, JsonValue>>,
|
||||
sessions: Mutex<HashMap<String, SessionHandle>>,
|
||||
turn_message_tx: async_channel::Sender<TurnMessage>,
|
||||
turn_message_rx: async_channel::Receiver<TurnMessage>,
|
||||
delegate: Arc<dyn CodeModeSessionDelegate>,
|
||||
shutting_down: AtomicBool,
|
||||
next_cell_id: AtomicU64,
|
||||
}
|
||||
|
||||
@@ -60,14 +198,16 @@ pub struct CodeModeService {
|
||||
|
||||
impl CodeModeService {
|
||||
pub fn new() -> Self {
|
||||
let (turn_message_tx, turn_message_rx) = async_channel::unbounded();
|
||||
Self::with_delegate(Arc::new(NoopCodeModeSessionDelegate))
|
||||
}
|
||||
|
||||
pub fn with_delegate(delegate: Arc<dyn CodeModeSessionDelegate>) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(Inner {
|
||||
stored_values: Mutex::new(HashMap::new()),
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
turn_message_tx,
|
||||
turn_message_rx,
|
||||
delegate,
|
||||
shutting_down: AtomicBool::new(false),
|
||||
next_cell_id: AtomicU64::new(1),
|
||||
}),
|
||||
}
|
||||
@@ -77,23 +217,28 @@ impl CodeModeService {
|
||||
self.inner.stored_values.lock().await.clone()
|
||||
}
|
||||
|
||||
/// Reserves the runtime cell id for a future `execute` request.
|
||||
///
|
||||
/// The runtime can issue nested tool calls before the first `execute`
|
||||
/// response is returned. Hosts that need a parent trace object for those
|
||||
/// nested calls should allocate the cell id up front and pass it back on the
|
||||
/// `ExecuteRequest`.
|
||||
pub fn allocate_cell_id(&self) -> String {
|
||||
self.inner
|
||||
.next_cell_id
|
||||
.fetch_add(1, Ordering::Relaxed)
|
||||
.to_string()
|
||||
pub async fn replace_stored_values(&self, values: HashMap<String, JsonValue>) {
|
||||
*self.inner.stored_values.lock().await = values;
|
||||
}
|
||||
|
||||
pub async fn execute(&self, request: ExecuteRequest) -> Result<RuntimeResponse, String> {
|
||||
fn allocate_cell_id(&self) -> CellId {
|
||||
CellId(
|
||||
self.inner
|
||||
.next_cell_id
|
||||
.fetch_add(1, Ordering::Relaxed)
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn execute(&self, request: ExecuteRequest) -> Result<StartedCell, String> {
|
||||
if self.inner.shutting_down.load(Ordering::Acquire) {
|
||||
return Err("code mode session is shutting down".to_string());
|
||||
}
|
||||
let initial_yield_time_ms = request.yield_time_ms.unwrap_or(DEFAULT_EXEC_YIELD_TIME_MS);
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
let cell_id = self.allocate_cell_id();
|
||||
self.start_session(
|
||||
cell_id.clone(),
|
||||
request,
|
||||
SessionResponseSender::Runtime(response_tx),
|
||||
Some(initial_yield_time_ms),
|
||||
@@ -101,9 +246,10 @@ impl CodeModeService {
|
||||
)
|
||||
.await?;
|
||||
|
||||
response_rx
|
||||
.await
|
||||
.map_err(|_| "exec runtime ended unexpectedly".to_string())
|
||||
Ok(StartedCell {
|
||||
cell_id,
|
||||
initial_response_rx: response_rx,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn execute_to_pending(
|
||||
@@ -111,7 +257,9 @@ impl CodeModeService {
|
||||
request: ExecuteRequest,
|
||||
) -> Result<ExecuteToPendingOutcome, String> {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
let cell_id = self.allocate_cell_id();
|
||||
self.start_session(
|
||||
cell_id,
|
||||
request,
|
||||
SessionResponseSender::ExecuteToPending(response_tx),
|
||||
/*initial_yield_time_ms*/ None,
|
||||
@@ -126,32 +274,32 @@ impl CodeModeService {
|
||||
|
||||
async fn start_session(
|
||||
&self,
|
||||
cell_id: CellId,
|
||||
request: ExecuteRequest,
|
||||
initial_response_tx: SessionResponseSender,
|
||||
initial_yield_time_ms: Option<u64>,
|
||||
pending_mode: PendingRuntimeMode,
|
||||
) -> Result<(), String> {
|
||||
let cell_id = request.cell_id.clone();
|
||||
let cell_id_text = cell_id.to_string();
|
||||
let (event_tx, event_rx) = mpsc::unbounded_channel();
|
||||
let (control_tx, control_rx) = mpsc::unbounded_channel();
|
||||
let stored_values = self.stored_values().await;
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = {
|
||||
let mut sessions = self.inner.sessions.lock().await;
|
||||
if sessions.contains_key(&cell_id) {
|
||||
if sessions.contains_key(&cell_id_text) {
|
||||
return Err(format!("exec cell {cell_id} already exists"));
|
||||
}
|
||||
|
||||
let (runtime_tx, runtime_control_tx, runtime_terminate_handle) =
|
||||
spawn_runtime(stored_values, request, event_tx, pending_mode)?;
|
||||
|
||||
// Keep the session registry locked through insertion so a
|
||||
// caller-owned cell id cannot race with another execute and replace
|
||||
// a live runtime.
|
||||
sessions.insert(
|
||||
cell_id.clone(),
|
||||
cell_id_text.clone(),
|
||||
SessionHandle {
|
||||
control_tx,
|
||||
runtime_tx: runtime_tx.clone(),
|
||||
cancellation_token: cancellation_token.clone(),
|
||||
},
|
||||
);
|
||||
(runtime_tx, runtime_control_tx, runtime_terminate_handle)
|
||||
@@ -160,11 +308,12 @@ impl CodeModeService {
|
||||
tokio::spawn(run_session_control(
|
||||
Arc::clone(&self.inner),
|
||||
SessionControlContext {
|
||||
cell_id: cell_id.clone(),
|
||||
cell_id: cell_id_text,
|
||||
runtime_tx,
|
||||
runtime_control_tx,
|
||||
pending_mode,
|
||||
runtime_terminate_handle,
|
||||
cancellation_token,
|
||||
},
|
||||
event_rx,
|
||||
control_rx,
|
||||
@@ -188,13 +337,9 @@ impl CodeModeService {
|
||||
return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id)));
|
||||
};
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
let control_message = if request.terminate {
|
||||
SessionControlCommand::Terminate { response_tx }
|
||||
} else {
|
||||
SessionControlCommand::Poll {
|
||||
yield_time_ms: request.yield_time_ms,
|
||||
response_tx,
|
||||
}
|
||||
let control_message = SessionControlCommand::Poll {
|
||||
yield_time_ms: request.yield_time_ms,
|
||||
response_tx,
|
||||
};
|
||||
if handle.control_tx.send(control_message).is_err() {
|
||||
return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id)));
|
||||
@@ -207,6 +352,25 @@ impl CodeModeService {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn terminate(&self, cell_id: String) -> Result<WaitOutcome, String> {
|
||||
let handle = self.inner.sessions.lock().await.get(&cell_id).cloned();
|
||||
let Some(handle) = handle else {
|
||||
return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id)));
|
||||
};
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
if handle
|
||||
.control_tx
|
||||
.send(SessionControlCommand::Terminate { response_tx })
|
||||
.is_err()
|
||||
{
|
||||
return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id)));
|
||||
}
|
||||
match response_rx.await {
|
||||
Ok(response) => Ok(WaitOutcome::LiveCell(response)),
|
||||
Err(_) => Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn wait_to_pending(
|
||||
&self,
|
||||
request: WaitToPendingRequest,
|
||||
@@ -242,69 +406,28 @@ impl CodeModeService {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_turn_worker(&self, host: Arc<dyn CodeModeTurnHost>) -> CodeModeTurnWorker {
|
||||
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
|
||||
let inner = Arc::clone(&self.inner);
|
||||
let turn_message_rx = self.inner.turn_message_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let next_message = tokio::select! {
|
||||
_ = &mut shutdown_rx => break,
|
||||
message = turn_message_rx.recv() => message.ok(),
|
||||
};
|
||||
let Some(next_message) = next_message else {
|
||||
break;
|
||||
};
|
||||
match next_message {
|
||||
TurnMessage::Notify {
|
||||
cell_id,
|
||||
call_id,
|
||||
text,
|
||||
} => {
|
||||
if let Err(err) = host.notify(call_id, cell_id.clone(), text).await {
|
||||
warn!(
|
||||
"failed to deliver code mode notification for cell {cell_id}: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
TurnMessage::ToolCall(invocation) => {
|
||||
let host = Arc::clone(&host);
|
||||
let inner = Arc::clone(&inner);
|
||||
tokio::spawn(async move {
|
||||
let cell_id = invocation.cell_id.clone();
|
||||
let runtime_tool_call_id = invocation.runtime_tool_call_id.clone();
|
||||
let response =
|
||||
host.invoke_tool(invocation, CancellationToken::new()).await;
|
||||
let runtime_tx = inner
|
||||
.sessions
|
||||
.lock()
|
||||
.await
|
||||
.get(&cell_id)
|
||||
.map(|handle| handle.runtime_tx.clone());
|
||||
let Some(runtime_tx) = runtime_tx else {
|
||||
return;
|
||||
};
|
||||
let command = match response {
|
||||
Ok(result) => RuntimeCommand::ToolResponse {
|
||||
id: runtime_tool_call_id,
|
||||
result,
|
||||
},
|
||||
Err(error_text) => RuntimeCommand::ToolError {
|
||||
id: runtime_tool_call_id,
|
||||
error_text,
|
||||
},
|
||||
};
|
||||
let _ = runtime_tx.send(command);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
CodeModeTurnWorker {
|
||||
shutdown_tx: Some(shutdown_tx),
|
||||
pub async fn shutdown(&self) -> Result<(), String> {
|
||||
self.inner.shutting_down.store(true, Ordering::Release);
|
||||
let handles = self
|
||||
.inner
|
||||
.sessions
|
||||
.lock()
|
||||
.await
|
||||
.values()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
for handle in handles {
|
||||
handle.cancellation_token.cancel();
|
||||
let (response_tx, _response_rx) = oneshot::channel();
|
||||
let _ = handle
|
||||
.control_tx
|
||||
.send(SessionControlCommand::Terminate { response_tx });
|
||||
let _ = handle.runtime_tx.send(RuntimeCommand::Terminate);
|
||||
}
|
||||
while !self.inner.sessions.lock().await.is_empty() {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,15 +437,53 @@ impl Default for CodeModeService {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CodeModeTurnWorker {
|
||||
shutdown_tx: Option<oneshot::Sender<()>>,
|
||||
impl Drop for CodeModeService {
|
||||
fn drop(&mut self) {
|
||||
self.inner.shutting_down.store(true, Ordering::Release);
|
||||
if let Ok(sessions) = self.inner.sessions.try_lock() {
|
||||
for handle in sessions.values() {
|
||||
handle.cancellation_token.cancel();
|
||||
let (response_tx, _response_rx) = oneshot::channel();
|
||||
let _ = handle
|
||||
.control_tx
|
||||
.send(SessionControlCommand::Terminate { response_tx });
|
||||
let _ = handle.runtime_tx.send(RuntimeCommand::Terminate);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CodeModeTurnWorker {
|
||||
fn drop(&mut self) {
|
||||
if let Some(shutdown_tx) = self.shutdown_tx.take() {
|
||||
let _ = shutdown_tx.send(());
|
||||
}
|
||||
impl CodeModeSession for CodeModeService {
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
request: ExecuteRequest,
|
||||
) -> CodeModeSessionResultFuture<'a, StartedCell> {
|
||||
Box::pin(CodeModeService::execute(self, request))
|
||||
}
|
||||
|
||||
fn wait<'a>(&'a self, request: WaitRequest) -> CodeModeSessionResultFuture<'a, WaitOutcome> {
|
||||
Box::pin(CodeModeService::wait(self, request))
|
||||
}
|
||||
|
||||
fn terminate<'a>(&'a self, cell_id: String) -> CodeModeSessionResultFuture<'a, WaitOutcome> {
|
||||
Box::pin(CodeModeService::terminate(self, cell_id))
|
||||
}
|
||||
|
||||
fn stored_values<'a>(
|
||||
&'a self,
|
||||
) -> Pin<Box<dyn Future<Output = HashMap<String, JsonValue>> + Send + 'a>> {
|
||||
Box::pin(CodeModeService::stored_values(self))
|
||||
}
|
||||
|
||||
fn replace_stored_values<'a>(
|
||||
&'a self,
|
||||
values: HashMap<String, JsonValue>,
|
||||
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
Box::pin(CodeModeService::replace_stored_values(self, values))
|
||||
}
|
||||
|
||||
fn shutdown<'a>(&'a self) -> CodeModeSessionResultFuture<'a, ()> {
|
||||
Box::pin(CodeModeService::shutdown(self))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -355,6 +516,7 @@ struct SessionControlContext {
|
||||
runtime_control_tx: std::sync::mpsc::Sender<RuntimeControlCommand>,
|
||||
pending_mode: PendingRuntimeMode,
|
||||
runtime_terminate_handle: v8::IsolateHandle,
|
||||
cancellation_token: CancellationToken,
|
||||
}
|
||||
|
||||
fn missing_cell_response(cell_id: String) -> RuntimeResponse {
|
||||
@@ -437,6 +599,7 @@ async fn run_session_control(
|
||||
runtime_control_tx,
|
||||
pending_mode,
|
||||
runtime_terminate_handle,
|
||||
cancellation_token,
|
||||
} = context;
|
||||
let mut content_items = Vec::new();
|
||||
let mut pending_tool_call_ids = Vec::new();
|
||||
@@ -516,11 +679,15 @@ async fn run_session_control(
|
||||
send_yield_response(&cell_id, &mut content_items, &mut response_tx);
|
||||
}
|
||||
RuntimeEvent::Notify { call_id, text } => {
|
||||
let _ = inner.turn_message_tx.send(TurnMessage::Notify {
|
||||
cell_id: cell_id.clone(),
|
||||
call_id,
|
||||
text,
|
||||
}).await;
|
||||
let delegate = Arc::clone(&inner.delegate);
|
||||
let cell_id = CellId::from(cell_id.clone());
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = delegate.notify(call_id, cell_id.clone(), text).await {
|
||||
warn!(
|
||||
"failed to deliver code mode notification for cell {cell_id}: {err}"
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
RuntimeEvent::ToolCall {
|
||||
id,
|
||||
@@ -533,15 +700,25 @@ async fn run_session_control(
|
||||
}
|
||||
let tool_call = CodeModeNestedToolCall {
|
||||
cell_id: cell_id.clone(),
|
||||
runtime_tool_call_id: id,
|
||||
runtime_tool_call_id: id.clone(),
|
||||
tool_name: name,
|
||||
tool_kind: kind,
|
||||
input,
|
||||
};
|
||||
let _ = inner
|
||||
.turn_message_tx
|
||||
.send(TurnMessage::ToolCall(tool_call))
|
||||
.await;
|
||||
let delegate = Arc::clone(&inner.delegate);
|
||||
let runtime_tx = runtime_tx.clone();
|
||||
let cancellation_token = cancellation_token.child_token();
|
||||
tokio::spawn(async move {
|
||||
let response = tokio::select! {
|
||||
response = delegate.invoke_tool(tool_call, cancellation_token.clone()) => response,
|
||||
_ = cancellation_token.cancelled() => return,
|
||||
};
|
||||
let command = match response {
|
||||
Ok(result) => RuntimeCommand::ToolResponse { id, result },
|
||||
Err(error_text) => RuntimeCommand::ToolError { id, error_text },
|
||||
};
|
||||
let _ = runtime_tx.send(command);
|
||||
});
|
||||
}
|
||||
RuntimeEvent::Result {
|
||||
stored_value_writes,
|
||||
@@ -562,7 +739,7 @@ async fn run_session_control(
|
||||
.stored_values
|
||||
.lock()
|
||||
.await
|
||||
.extend(stored_value_writes);
|
||||
.extend(stored_value_writes.clone());
|
||||
let result = PendingResult {
|
||||
content_items: std::mem::take(&mut content_items),
|
||||
error_text,
|
||||
@@ -617,6 +794,7 @@ async fn run_session_control(
|
||||
|
||||
response_tx = Some(SessionResponseSender::Runtime(next_response_tx));
|
||||
termination_requested = true;
|
||||
cancellation_token.cancel();
|
||||
yield_timer = None;
|
||||
let _ = runtime_tx.send(RuntimeCommand::Terminate);
|
||||
terminate_paused_runtime(&runtime_control_tx, pending_mode);
|
||||
@@ -650,6 +828,7 @@ async fn run_session_control(
|
||||
}
|
||||
|
||||
let _ = runtime_tx.send(RuntimeCommand::Terminate);
|
||||
cancellation_token.cancel();
|
||||
terminate_paused_runtime(&runtime_control_tx, pending_mode);
|
||||
inner.sessions.lock().await.remove(&cell_id);
|
||||
}
|
||||
@@ -687,6 +866,7 @@ mod tests {
|
||||
|
||||
use super::CodeModeService;
|
||||
use super::Inner;
|
||||
use super::NoopCodeModeSessionDelegate;
|
||||
use super::PendingRuntimeMode;
|
||||
use super::RuntimeCommand;
|
||||
use super::RuntimeResponse;
|
||||
@@ -708,7 +888,6 @@ mod tests {
|
||||
|
||||
fn execute_request(source: &str) -> ExecuteRequest {
|
||||
ExecuteRequest {
|
||||
cell_id: "1".to_string(),
|
||||
tool_call_id: "call_1".to_string(),
|
||||
enabled_tools: Vec::new(),
|
||||
source: source.to_string(),
|
||||
@@ -717,13 +896,22 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute(service: &CodeModeService, request: ExecuteRequest) -> RuntimeResponse {
|
||||
service
|
||||
.execute(request)
|
||||
.await
|
||||
.unwrap()
|
||||
.initial_response()
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn test_inner() -> Arc<Inner> {
|
||||
let (turn_message_tx, turn_message_rx) = async_channel::unbounded();
|
||||
Arc::new(Inner {
|
||||
stored_values: Mutex::new(HashMap::new()),
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
turn_message_tx,
|
||||
turn_message_rx,
|
||||
delegate: Arc::new(NoopCodeModeSessionDelegate),
|
||||
shutting_down: std::sync::atomic::AtomicBool::new(false),
|
||||
next_cell_id: AtomicU64::new(1),
|
||||
})
|
||||
}
|
||||
@@ -732,14 +920,15 @@ mod tests {
|
||||
async fn synchronous_exit_returns_successfully() {
|
||||
let service = CodeModeService::new();
|
||||
|
||||
let response = service
|
||||
.execute(ExecuteRequest {
|
||||
let response = execute(
|
||||
&service,
|
||||
ExecuteRequest {
|
||||
source: r#"text("before"); exit(); text("after");"#.to_string(),
|
||||
yield_time_ms: None,
|
||||
..execute_request("")
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
@@ -753,6 +942,31 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_interrupts_cpu_bound_cells() {
|
||||
let service = CodeModeService::new();
|
||||
|
||||
let cell = service
|
||||
.execute(ExecuteRequest {
|
||||
source: "while (true) {}".to_string(),
|
||||
..execute_request("")
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
cell.initial_response().await.unwrap(),
|
||||
RuntimeResponse::Yielded {
|
||||
cell_id: "1".to_string(),
|
||||
content_items: Vec::new(),
|
||||
}
|
||||
);
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(1), service.shutdown())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn execute_to_pending_returns_completed_for_synchronous_results() {
|
||||
let service = CodeModeService::new();
|
||||
@@ -805,14 +1019,7 @@ mod tests {
|
||||
}
|
||||
);
|
||||
|
||||
let termination = service
|
||||
.wait(WaitRequest {
|
||||
cell_id: "1".to_string(),
|
||||
yield_time_ms: 1,
|
||||
terminate: true,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let termination = service.terminate("1".to_string()).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
termination,
|
||||
@@ -859,14 +1066,7 @@ await Promise.all([
|
||||
}
|
||||
);
|
||||
|
||||
let termination = service
|
||||
.wait(WaitRequest {
|
||||
cell_id: "1".to_string(),
|
||||
yield_time_ms: 1,
|
||||
terminate: true,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let termination = service.terminate("1".to_string()).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
termination,
|
||||
@@ -948,14 +1148,7 @@ await Promise.all([
|
||||
})
|
||||
);
|
||||
|
||||
let termination = service
|
||||
.wait(WaitRequest {
|
||||
cell_id: "1".to_string(),
|
||||
yield_time_ms: 1,
|
||||
terminate: true,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let termination = service.terminate("1".to_string()).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
termination,
|
||||
@@ -1027,14 +1220,7 @@ await new Promise(() => {});
|
||||
})
|
||||
);
|
||||
|
||||
let termination = service
|
||||
.wait(WaitRequest {
|
||||
cell_id: "1".to_string(),
|
||||
yield_time_ms: 1,
|
||||
terminate: true,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let termination = service.terminate("1".to_string()).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
termination,
|
||||
@@ -1112,14 +1298,15 @@ text("done");
|
||||
async fn v8_console_is_not_exposed_on_global_this() {
|
||||
let service = CodeModeService::new();
|
||||
|
||||
let response = service
|
||||
.execute(ExecuteRequest {
|
||||
let response = execute(
|
||||
&service,
|
||||
ExecuteRequest {
|
||||
source: r#"text(String(Object.hasOwn(globalThis, "console")));"#.to_string(),
|
||||
yield_time_ms: None,
|
||||
..execute_request("")
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
@@ -1137,8 +1324,9 @@ text("done");
|
||||
async fn date_locale_string_formats_with_icu_data() {
|
||||
let service = CodeModeService::new();
|
||||
|
||||
let response = service
|
||||
.execute(ExecuteRequest {
|
||||
let response = execute(
|
||||
&service,
|
||||
ExecuteRequest {
|
||||
source: r#"
|
||||
const value = new Date("2025-01-02T03:04:05Z")
|
||||
.toLocaleString("fr-FR", {
|
||||
@@ -1156,9 +1344,9 @@ text(value);
|
||||
.to_string(),
|
||||
yield_time_ms: None,
|
||||
..execute_request("")
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
@@ -1176,8 +1364,9 @@ text(value);
|
||||
async fn intl_date_time_format_formats_with_icu_data() {
|
||||
let service = CodeModeService::new();
|
||||
|
||||
let response = service
|
||||
.execute(ExecuteRequest {
|
||||
let response = execute(
|
||||
&service,
|
||||
ExecuteRequest {
|
||||
source: r#"
|
||||
const formatter = new Intl.DateTimeFormat("fr-FR", {
|
||||
weekday: "long",
|
||||
@@ -1194,9 +1383,9 @@ text(formatter.format(new Date("2025-01-02T03:04:05Z")));
|
||||
.to_string(),
|
||||
yield_time_ms: None,
|
||||
..execute_request("")
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
@@ -1214,8 +1403,9 @@ text(formatter.format(new Date("2025-01-02T03:04:05Z")));
|
||||
async fn output_helpers_return_undefined() {
|
||||
let service = CodeModeService::new();
|
||||
|
||||
let response = service
|
||||
.execute(ExecuteRequest {
|
||||
let response = execute(
|
||||
&service,
|
||||
ExecuteRequest {
|
||||
source: r#"
|
||||
const returnsUndefined = [
|
||||
text("first"),
|
||||
@@ -1227,9 +1417,9 @@ text(JSON.stringify(returnsUndefined));
|
||||
.to_string(),
|
||||
yield_time_ms: None,
|
||||
..execute_request("")
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
@@ -1256,8 +1446,9 @@ text(JSON.stringify(returnsUndefined));
|
||||
async fn image_helper_accepts_raw_mcp_image_block_with_original_detail() {
|
||||
let service = CodeModeService::new();
|
||||
|
||||
let response = service
|
||||
.execute(ExecuteRequest {
|
||||
let response = execute(
|
||||
&service,
|
||||
ExecuteRequest {
|
||||
source: r#"
|
||||
image({
|
||||
type: "image",
|
||||
@@ -1269,9 +1460,9 @@ image({
|
||||
.to_string(),
|
||||
yield_time_ms: None,
|
||||
..execute_request("")
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
@@ -1290,8 +1481,9 @@ image({
|
||||
async fn image_helper_second_arg_overrides_explicit_object_detail() {
|
||||
let service = CodeModeService::new();
|
||||
|
||||
let response = service
|
||||
.execute(ExecuteRequest {
|
||||
let response = execute(
|
||||
&service,
|
||||
ExecuteRequest {
|
||||
source: r#"
|
||||
image(
|
||||
{
|
||||
@@ -1304,9 +1496,9 @@ image(
|
||||
.to_string(),
|
||||
yield_time_ms: None,
|
||||
..execute_request("")
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
@@ -1325,8 +1517,9 @@ image(
|
||||
async fn image_helper_second_arg_overrides_raw_mcp_image_detail() {
|
||||
let service = CodeModeService::new();
|
||||
|
||||
let response = service
|
||||
.execute(ExecuteRequest {
|
||||
let response = execute(
|
||||
&service,
|
||||
ExecuteRequest {
|
||||
source: r#"
|
||||
image(
|
||||
{
|
||||
@@ -1341,9 +1534,9 @@ image(
|
||||
.to_string(),
|
||||
yield_time_ms: None,
|
||||
..execute_request("")
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
@@ -1362,8 +1555,9 @@ image(
|
||||
async fn image_helper_rejects_unsupported_detail() {
|
||||
let service = CodeModeService::new();
|
||||
|
||||
let response = service
|
||||
.execute(ExecuteRequest {
|
||||
let response = execute(
|
||||
&service,
|
||||
ExecuteRequest {
|
||||
source: r#"
|
||||
image({
|
||||
image_url: "https://example.com/image.jpg",
|
||||
@@ -1373,9 +1567,9 @@ image({
|
||||
.to_string(),
|
||||
yield_time_ms: None,
|
||||
..execute_request("")
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
@@ -1391,8 +1585,9 @@ image({
|
||||
async fn image_helper_rejects_raw_mcp_result_container() {
|
||||
let service = CodeModeService::new();
|
||||
|
||||
let response = service
|
||||
.execute(ExecuteRequest {
|
||||
let response = execute(
|
||||
&service,
|
||||
ExecuteRequest {
|
||||
source: r#"
|
||||
image({
|
||||
content: [
|
||||
@@ -1409,9 +1604,9 @@ image({
|
||||
.to_string(),
|
||||
yield_time_ms: None,
|
||||
..execute_request("")
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
@@ -1433,7 +1628,6 @@ image({
|
||||
.wait(WaitRequest {
|
||||
cell_id: "missing".to_string(),
|
||||
yield_time_ms: 1,
|
||||
terminate: false,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1475,6 +1669,7 @@ image({
|
||||
runtime_control_tx,
|
||||
pending_mode: PendingRuntimeMode::Continue,
|
||||
runtime_terminate_handle,
|
||||
cancellation_token: tokio_util::sync::CancellationToken::new(),
|
||||
},
|
||||
event_rx,
|
||||
control_rx,
|
||||
|
||||
@@ -603,6 +603,9 @@ async fn shutdown_session_runtime(sess: &Arc<Session>) {
|
||||
.unified_exec_manager
|
||||
.terminate_all_processes()
|
||||
.await;
|
||||
if let Err(err) = sess.services.code_mode_service.shutdown().await {
|
||||
warn!("failed to shutdown code mode session: {err}");
|
||||
}
|
||||
let mcp_shutdown = {
|
||||
let mut manager = sess.services.mcp_connection_manager.write().await;
|
||||
manager.begin_shutdown()
|
||||
|
||||
@@ -909,16 +909,12 @@ async fn run_sampling_request(
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
);
|
||||
let _code_mode_worker = sess
|
||||
.services
|
||||
.code_mode_service
|
||||
.start_turn_worker(
|
||||
&sess,
|
||||
&turn_context,
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
)
|
||||
.await;
|
||||
let _code_mode_worker = sess.services.code_mode_service.start_turn_worker(
|
||||
&sess,
|
||||
&turn_context,
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
);
|
||||
let mut retries = 0;
|
||||
let mut initial_input = Some(input);
|
||||
loop {
|
||||
|
||||
234
codex-rs/core/src/tools/code_mode/delegate.rs
Normal file
234
codex-rs/core/src/tools/code_mode/delegate.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use codex_code_mode::CellId;
|
||||
use codex_code_mode::CodeModeNestedToolCall;
|
||||
use codex_code_mode::CodeModeSessionDelegate;
|
||||
use codex_code_mode::NotificationFuture;
|
||||
use codex_code_mode::ToolInvocationFuture;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use serde_json::Value as JsonValue;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::watch;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
|
||||
use super::ExecContext;
|
||||
use super::PUBLIC_TOOL_NAME;
|
||||
use super::call_nested_tool;
|
||||
use crate::tools::ToolRouter;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
|
||||
pub(super) struct CodeModeDispatchBroker {
|
||||
dispatch_tx: async_channel::Sender<DispatchMessage>,
|
||||
dispatch_rx: async_channel::Receiver<DispatchMessage>,
|
||||
dispatch_gates: Arc<Mutex<HashMap<String, watch::Sender<bool>>>>,
|
||||
}
|
||||
|
||||
impl CodeModeDispatchBroker {
|
||||
pub(super) fn new() -> Self {
|
||||
let (dispatch_tx, dispatch_rx) = async_channel::unbounded();
|
||||
Self {
|
||||
dispatch_tx,
|
||||
dispatch_rx,
|
||||
dispatch_gates: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn mark_cell_ready_for_dispatch(&self, cell_id: &CellId) {
|
||||
dispatch_gate(&self.dispatch_gates, cell_id.as_str()).send_replace(true);
|
||||
}
|
||||
|
||||
pub(super) fn start_turn_worker(
|
||||
&self,
|
||||
exec: ExecContext,
|
||||
router: Arc<ToolRouter>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
) -> CodeModeDispatchWorker {
|
||||
let tool_runtime = ToolCallRuntime::new(
|
||||
router,
|
||||
Arc::clone(&exec.session),
|
||||
Arc::clone(&exec.turn),
|
||||
tracker,
|
||||
);
|
||||
let host = Arc::new(CoreTurnHost { exec, tool_runtime });
|
||||
let dispatch_rx = self.dispatch_rx.clone();
|
||||
let dispatch_gates = Arc::clone(&self.dispatch_gates);
|
||||
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let message = tokio::select! {
|
||||
_ = &mut shutdown_rx => break,
|
||||
message = dispatch_rx.recv() => message.ok(),
|
||||
};
|
||||
let Some(message) = message else {
|
||||
break;
|
||||
};
|
||||
match message {
|
||||
DispatchMessage::Notify {
|
||||
call_id,
|
||||
cell_id,
|
||||
text,
|
||||
} => {
|
||||
wait_until_cell_ready_for_dispatch(&dispatch_gates, &cell_id).await;
|
||||
if let Err(err) = host.notify(call_id, cell_id.clone(), text).await {
|
||||
warn!(
|
||||
"failed to deliver code mode notification for cell {cell_id}: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
DispatchMessage::InvokeTool {
|
||||
invocation,
|
||||
cancellation_token,
|
||||
response_tx,
|
||||
} => {
|
||||
wait_until_cell_ready_for_dispatch(&dispatch_gates, &invocation.cell_id)
|
||||
.await;
|
||||
let host = Arc::clone(&host);
|
||||
tokio::spawn(async move {
|
||||
let response = host.invoke_tool(invocation, cancellation_token).await;
|
||||
let _ = response_tx.send(response);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
CodeModeDispatchWorker {
|
||||
shutdown_tx: Some(shutdown_tx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch_gate(
|
||||
dispatch_gates: &Mutex<HashMap<String, watch::Sender<bool>>>,
|
||||
cell_id: &str,
|
||||
) -> watch::Sender<bool> {
|
||||
let mut dispatch_gates = match dispatch_gates.lock() {
|
||||
Ok(dispatch_gates) => dispatch_gates,
|
||||
Err(poisoned) => poisoned.into_inner(),
|
||||
};
|
||||
dispatch_gates
|
||||
.entry(cell_id.to_string())
|
||||
.or_insert_with(|| watch::channel(false).0)
|
||||
.clone()
|
||||
}
|
||||
|
||||
async fn wait_until_cell_ready_for_dispatch(
|
||||
dispatch_gates: &Mutex<HashMap<String, watch::Sender<bool>>>,
|
||||
cell_id: &str,
|
||||
) {
|
||||
let mut ready_rx = dispatch_gate(dispatch_gates, cell_id).subscribe();
|
||||
while !*ready_rx.borrow_and_update() && ready_rx.changed().await.is_ok() {}
|
||||
}
|
||||
|
||||
impl CodeModeSessionDelegate for CodeModeDispatchBroker {
|
||||
fn invoke_tool<'a>(
|
||||
&'a self,
|
||||
invocation: CodeModeNestedToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> ToolInvocationFuture<'a> {
|
||||
Box::pin(async move {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
self.dispatch_tx
|
||||
.send(DispatchMessage::InvokeTool {
|
||||
invocation,
|
||||
cancellation_token: cancellation_token.clone(),
|
||||
response_tx,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| "code mode nested tool dispatcher is unavailable".to_string())?;
|
||||
tokio::select! {
|
||||
response = response_rx => response
|
||||
.map_err(|_| "code mode nested tool dispatcher stopped".to_string())?,
|
||||
_ = cancellation_token.cancelled() => {
|
||||
Err("code mode nested tool call cancelled".to_string())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn notify<'a>(
|
||||
&'a self,
|
||||
call_id: String,
|
||||
cell_id: CellId,
|
||||
text: String,
|
||||
) -> NotificationFuture<'a> {
|
||||
Box::pin(async move {
|
||||
self.dispatch_tx
|
||||
.send(DispatchMessage::Notify {
|
||||
call_id,
|
||||
cell_id: cell_id.to_string(),
|
||||
text,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| "code mode notification dispatcher is unavailable".to_string())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
enum DispatchMessage {
|
||||
InvokeTool {
|
||||
invocation: CodeModeNestedToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
response_tx: oneshot::Sender<Result<JsonValue, String>>,
|
||||
},
|
||||
Notify {
|
||||
call_id: String,
|
||||
cell_id: String,
|
||||
text: String,
|
||||
},
|
||||
}
|
||||
|
||||
pub(crate) struct CodeModeDispatchWorker {
|
||||
shutdown_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
impl Drop for CodeModeDispatchWorker {
|
||||
fn drop(&mut self) {
|
||||
if let Some(shutdown_tx) = self.shutdown_tx.take() {
|
||||
let _ = shutdown_tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct CoreTurnHost {
|
||||
exec: ExecContext,
|
||||
tool_runtime: ToolCallRuntime,
|
||||
}
|
||||
|
||||
impl CoreTurnHost {
|
||||
async fn invoke_tool(
|
||||
&self,
|
||||
invocation: CodeModeNestedToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Result<JsonValue, String> {
|
||||
call_nested_tool(
|
||||
self.exec.clone(),
|
||||
self.tool_runtime.clone(),
|
||||
invocation,
|
||||
cancellation_token,
|
||||
)
|
||||
.await
|
||||
.map_err(|error| error.to_string())
|
||||
}
|
||||
|
||||
async fn notify(&self, call_id: String, cell_id: String, text: String) -> Result<(), String> {
|
||||
if text.trim().is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
self.exec
|
||||
.session
|
||||
.inject_response_items(vec![ResponseInputItem::CustomToolCallOutput {
|
||||
call_id,
|
||||
name: Some(PUBLIC_TOOL_NAME.to_string()),
|
||||
output: FunctionCallOutputPayload::from_text(text),
|
||||
}])
|
||||
.await
|
||||
.map_err(|_| {
|
||||
format!("failed to inject exec notify message for cell {cell_id}: no active turn")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -38,9 +38,21 @@ impl CodeModeExecuteHandler {
|
||||
let exec = ExecContext { session, turn };
|
||||
let enabled_tools =
|
||||
codex_tools::collect_code_mode_tool_definitions(&self.nested_tool_specs);
|
||||
// Allocate before starting V8 so the trace can create the parent
|
||||
// CodeCell before model-authored JavaScript issues nested tool calls.
|
||||
let runtime_cell_id = exec.session.services.code_mode_service.allocate_cell_id();
|
||||
let started_at = std::time::Instant::now();
|
||||
let started_cell = exec
|
||||
.session
|
||||
.services
|
||||
.code_mode_service
|
||||
.execute(codex_code_mode::ExecuteRequest {
|
||||
tool_call_id: call_id.clone(),
|
||||
enabled_tools,
|
||||
source: args.code.clone(),
|
||||
yield_time_ms: args.yield_time_ms,
|
||||
max_output_tokens: args.max_output_tokens,
|
||||
})
|
||||
.await
|
||||
.map_err(FunctionCallError::RespondToModel)?;
|
||||
let runtime_cell_id = started_cell.cell_id.to_string();
|
||||
let code_cell_trace = exec
|
||||
.session
|
||||
.services
|
||||
@@ -51,19 +63,12 @@ impl CodeModeExecuteHandler {
|
||||
call_id.as_str(),
|
||||
args.code.as_str(),
|
||||
);
|
||||
let started_at = std::time::Instant::now();
|
||||
let response = exec
|
||||
.session
|
||||
exec.session
|
||||
.services
|
||||
.code_mode_service
|
||||
.execute(codex_code_mode::ExecuteRequest {
|
||||
cell_id: runtime_cell_id,
|
||||
tool_call_id: call_id,
|
||||
enabled_tools,
|
||||
source: args.code,
|
||||
yield_time_ms: args.yield_time_ms,
|
||||
max_output_tokens: args.max_output_tokens,
|
||||
})
|
||||
.mark_cell_ready_for_dispatch(&started_cell.cell_id);
|
||||
let response = started_cell
|
||||
.initial_response()
|
||||
.await
|
||||
.map_err(FunctionCallError::RespondToModel)?;
|
||||
// Record the raw runtime boundary. The model-visible custom-tool output
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
mod delegate;
|
||||
mod execute_handler;
|
||||
pub(crate) mod execute_spec;
|
||||
mod response_adapter;
|
||||
@@ -8,12 +9,10 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_code_mode::CodeModeNestedToolCall;
|
||||
use codex_code_mode::CodeModeSession;
|
||||
use codex_code_mode::CodeModeToolKind;
|
||||
use codex_code_mode::CodeModeTurnHost;
|
||||
use codex_code_mode::RuntimeResponse;
|
||||
use codex_protocol::models::FunctionCallOutputContentItem;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use serde_json::Value as JsonValue;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
@@ -36,6 +35,8 @@ use codex_utils_output_truncation::TruncationPolicy;
|
||||
use codex_utils_output_truncation::formatted_truncate_text_content_items_with_policy;
|
||||
use codex_utils_output_truncation::truncate_function_output_items_with_policy;
|
||||
|
||||
use delegate::CodeModeDispatchBroker;
|
||||
use delegate::CodeModeDispatchWorker;
|
||||
pub(crate) use execute_handler::CodeModeExecuteHandler;
|
||||
use response_adapter::into_function_call_output_content_items;
|
||||
pub(crate) use wait_handler::CodeModeWaitHandler;
|
||||
@@ -56,42 +57,61 @@ pub(crate) struct ExecContext {
|
||||
}
|
||||
|
||||
pub(crate) struct CodeModeService {
|
||||
inner: codex_code_mode::CodeModeService,
|
||||
session: Option<Arc<dyn CodeModeSession>>,
|
||||
dispatch_broker: Arc<CodeModeDispatchBroker>,
|
||||
}
|
||||
|
||||
impl CodeModeService {
|
||||
pub(crate) fn new() -> Self {
|
||||
let dispatch_broker = Arc::new(CodeModeDispatchBroker::new());
|
||||
Self {
|
||||
inner: codex_code_mode::CodeModeService::new(),
|
||||
session: Some(Arc::new(codex_code_mode::CodeModeService::with_delegate(
|
||||
dispatch_broker.clone(),
|
||||
))),
|
||||
dispatch_broker,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn allocate_cell_id(&self) -> String {
|
||||
self.inner.allocate_cell_id()
|
||||
}
|
||||
|
||||
pub(crate) async fn execute(
|
||||
&self,
|
||||
request: codex_code_mode::ExecuteRequest,
|
||||
) -> Result<RuntimeResponse, String> {
|
||||
self.inner.execute(request).await
|
||||
) -> Result<codex_code_mode::StartedCell, String> {
|
||||
self.session()?.execute(request).await
|
||||
}
|
||||
|
||||
pub(crate) async fn wait(
|
||||
&self,
|
||||
request: codex_code_mode::WaitRequest,
|
||||
) -> Result<codex_code_mode::WaitOutcome, String> {
|
||||
self.inner.wait(request).await
|
||||
self.session()?.wait(request).await
|
||||
}
|
||||
|
||||
pub(crate) async fn start_turn_worker(
|
||||
pub(crate) async fn terminate(
|
||||
&self,
|
||||
cell_id: String,
|
||||
) -> Result<codex_code_mode::WaitOutcome, String> {
|
||||
self.session()?.terminate(cell_id).await
|
||||
}
|
||||
|
||||
pub(crate) async fn shutdown(&self) -> Result<(), String> {
|
||||
match &self.session {
|
||||
Some(session) => session.shutdown().await,
|
||||
None => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn mark_cell_ready_for_dispatch(&self, cell_id: &codex_code_mode::CellId) {
|
||||
self.dispatch_broker.mark_cell_ready_for_dispatch(cell_id);
|
||||
}
|
||||
|
||||
pub(crate) fn start_turn_worker(
|
||||
&self,
|
||||
session: &Arc<Session>,
|
||||
turn: &Arc<TurnContext>,
|
||||
router: Arc<ToolRouter>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
) -> Option<codex_code_mode::CodeModeTurnWorker> {
|
||||
if !turn.features.enabled(Feature::CodeMode) {
|
||||
) -> Option<CodeModeDispatchWorker> {
|
||||
if !turn.features.enabled(Feature::CodeMode) || self.session.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
@@ -99,50 +119,16 @@ impl CodeModeService {
|
||||
session: Arc::clone(session),
|
||||
turn: Arc::clone(turn),
|
||||
};
|
||||
let tool_runtime =
|
||||
ToolCallRuntime::new(router, Arc::clone(session), Arc::clone(turn), tracker);
|
||||
let host = Arc::new(CoreTurnHost { exec, tool_runtime });
|
||||
Some(self.inner.start_turn_worker(host))
|
||||
}
|
||||
}
|
||||
|
||||
struct CoreTurnHost {
|
||||
exec: ExecContext,
|
||||
tool_runtime: ToolCallRuntime,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl CodeModeTurnHost for CoreTurnHost {
|
||||
async fn invoke_tool(
|
||||
&self,
|
||||
invocation: CodeModeNestedToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Result<JsonValue, String> {
|
||||
call_nested_tool(
|
||||
self.exec.clone(),
|
||||
self.tool_runtime.clone(),
|
||||
invocation,
|
||||
cancellation_token,
|
||||
Some(
|
||||
self.dispatch_broker
|
||||
.start_turn_worker(exec, router, tracker),
|
||||
)
|
||||
.await
|
||||
.map_err(|error| error.to_string())
|
||||
}
|
||||
|
||||
async fn notify(&self, call_id: String, cell_id: String, text: String) -> Result<(), String> {
|
||||
if text.trim().is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
self.exec
|
||||
.session
|
||||
.inject_response_items(vec![ResponseInputItem::CustomToolCallOutput {
|
||||
call_id,
|
||||
name: Some(PUBLIC_TOOL_NAME.to_string()),
|
||||
output: FunctionCallOutputPayload::from_text(text),
|
||||
}])
|
||||
.await
|
||||
.map_err(|_| {
|
||||
format!("failed to inject exec notify message for cell {cell_id}: no active turn")
|
||||
})
|
||||
fn session(&self) -> Result<&Arc<dyn CodeModeSession>, String> {
|
||||
self.session
|
||||
.as_ref()
|
||||
.ok_or_else(|| "code mode is unavailable".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -73,17 +73,23 @@ impl ToolExecutor<ToolInvocation> for CodeModeWaitHandler {
|
||||
let args: ExecWaitArgs = parse_arguments(&arguments)?;
|
||||
let exec = ExecContext { session, turn };
|
||||
let started_at = std::time::Instant::now();
|
||||
let wait_response = exec
|
||||
.session
|
||||
.services
|
||||
.code_mode_service
|
||||
.wait(codex_code_mode::WaitRequest {
|
||||
cell_id: args.cell_id,
|
||||
yield_time_ms: args.yield_time_ms,
|
||||
terminate: args.terminate,
|
||||
})
|
||||
.await
|
||||
.map_err(FunctionCallError::RespondToModel)?;
|
||||
let wait_response = if args.terminate {
|
||||
exec.session
|
||||
.services
|
||||
.code_mode_service
|
||||
.terminate(args.cell_id)
|
||||
.await
|
||||
} else {
|
||||
exec.session
|
||||
.services
|
||||
.code_mode_service
|
||||
.wait(codex_code_mode::WaitRequest {
|
||||
cell_id: args.cell_id,
|
||||
yield_time_ms: args.yield_time_ms,
|
||||
})
|
||||
.await
|
||||
}
|
||||
.map_err(FunctionCallError::RespondToModel)?;
|
||||
if let codex_code_mode::WaitOutcome::LiveCell(response) = &wait_response
|
||||
&& !matches!(response, codex_code_mode::RuntimeResponse::Yielded { .. })
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user