mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
Refactor code mode worker dispatch
This commit is contained in:
@@ -5464,6 +5464,11 @@ pub(crate) async fn run_turn(
|
||||
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
let _code_mode_worker = sess
|
||||
.services
|
||||
.code_mode_service
|
||||
.start_turn_worker(&sess, &turn_context, &turn_diff_tracker)
|
||||
.await;
|
||||
let mut server_model_warning_emitted_for_turn = false;
|
||||
|
||||
// `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -32,6 +31,8 @@ use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::warn;
|
||||
|
||||
@@ -50,79 +51,109 @@ struct ExecContext {
|
||||
|
||||
pub(crate) struct CodeModeProcess {
|
||||
child: tokio::process::Child,
|
||||
stdin: tokio::process::ChildStdin,
|
||||
stdout_lines: tokio::io::Lines<BufReader<tokio::process::ChildStdout>>,
|
||||
stderr_task: Option<JoinHandle<()>>,
|
||||
pending_messages: HashMap<i32, VecDeque<NodeToHostMessage>>,
|
||||
stdin: Arc<Mutex<tokio::process::ChildStdin>>,
|
||||
stdout_task: JoinHandle<()>,
|
||||
response_waiters: Arc<Mutex<HashMap<String, oneshot::Sender<NodeToHostMessage>>>>,
|
||||
tool_call_rx: Arc<Mutex<mpsc::UnboundedReceiver<CodeModeToolCall>>>,
|
||||
}
|
||||
|
||||
pub(crate) struct CodeModeWorker {
|
||||
shutdown_tx: Option<oneshot::Sender<()>>,
|
||||
task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
struct CodeModeToolCall {
|
||||
request_id: String,
|
||||
id: String,
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
input: Option<JsonValue>,
|
||||
}
|
||||
|
||||
impl Drop for CodeModeWorker {
|
||||
fn drop(&mut self) {
|
||||
if let Some(shutdown_tx) = self.shutdown_tx.take() {
|
||||
let _ = shutdown_tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CodeModeProcess {
|
||||
async fn write(&mut self, message: &HostToNodeMessage) -> Result<(), std::io::Error> {
|
||||
let line = serde_json::to_string(message).map_err(std::io::Error::other)?;
|
||||
self.stdin.write_all(line.as_bytes()).await?;
|
||||
self.stdin.write_all(b"\n").await?;
|
||||
self.stdin.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read(&mut self, session_id: i32) -> Result<NodeToHostMessage, std::io::Error> {
|
||||
if let Some(message) = self
|
||||
.pending_messages
|
||||
.get_mut(&session_id)
|
||||
.and_then(VecDeque::pop_front)
|
||||
{
|
||||
return Ok(message);
|
||||
}
|
||||
|
||||
loop {
|
||||
let Some(line) = self.stdout_lines.next_line().await? else {
|
||||
match self.wait_for_exit().await {
|
||||
Ok(status) => {
|
||||
self.join_stderr_task().await;
|
||||
return Err(std::io::Error::other(format!(
|
||||
"{PUBLIC_TOOL_NAME} runner exited without returning a result (status {status})"
|
||||
)));
|
||||
}
|
||||
Err(err) => return Err(std::io::Error::other(err)),
|
||||
fn worker(&self, exec: ExecContext) -> CodeModeWorker {
|
||||
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
|
||||
let stdin = Arc::clone(&self.stdin);
|
||||
let tool_call_rx = Arc::clone(&self.tool_call_rx);
|
||||
let task = tokio::spawn(async move {
|
||||
loop {
|
||||
let tool_call = tokio::select! {
|
||||
_ = &mut shutdown_rx => break,
|
||||
tool_call = async {
|
||||
let mut tool_call_rx = tool_call_rx.lock().await;
|
||||
tool_call_rx.recv().await
|
||||
} => tool_call,
|
||||
};
|
||||
let Some(tool_call) = tool_call else {
|
||||
break;
|
||||
};
|
||||
let response = HostToNodeMessage::Response {
|
||||
request_id: tool_call.request_id,
|
||||
id: tool_call.id,
|
||||
code_mode_result: call_nested_tool(
|
||||
exec.clone(),
|
||||
tool_call.name,
|
||||
tool_call.input,
|
||||
)
|
||||
.await,
|
||||
};
|
||||
if let Err(err) = write_message(&stdin, &response).await {
|
||||
warn!("failed to write {PUBLIC_TOOL_NAME} tool response: {err}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let message: NodeToHostMessage =
|
||||
serde_json::from_str(&line).map_err(std::io::Error::other)?;
|
||||
let message_session_id = message_session_id(&message);
|
||||
if message_session_id == session_id {
|
||||
return Ok(message);
|
||||
}
|
||||
self.pending_messages
|
||||
.entry(message_session_id)
|
||||
.or_default()
|
||||
.push_back(message);
|
||||
});
|
||||
|
||||
CodeModeWorker {
|
||||
shutdown_tx: Some(shutdown_tx),
|
||||
task,
|
||||
}
|
||||
}
|
||||
|
||||
fn has_exited(&mut self) -> Result<bool, String> {
|
||||
async fn send(
|
||||
&mut self,
|
||||
request_id: &str,
|
||||
message: &HostToNodeMessage,
|
||||
) -> Result<NodeToHostMessage, std::io::Error> {
|
||||
if self.stdout_task.is_finished() {
|
||||
return Err(std::io::Error::other(format!(
|
||||
"{PUBLIC_TOOL_NAME} runner is not available"
|
||||
)));
|
||||
}
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.response_waiters
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id.to_string(), tx);
|
||||
if let Err(err) = write_message(&self.stdin, message).await {
|
||||
self.response_waiters.lock().await.remove(request_id);
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
match rx.await {
|
||||
Ok(message) => Ok(message),
|
||||
Err(_) => Err(std::io::Error::other(format!(
|
||||
"{PUBLIC_TOOL_NAME} runner is not available"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn has_exited(&mut self) -> Result<bool, std::io::Error> {
|
||||
self.child
|
||||
.try_wait()
|
||||
.map(|status| status.is_some())
|
||||
.map_err(|err| format!("failed to inspect {PUBLIC_TOOL_NAME} runner: {err}"))
|
||||
}
|
||||
|
||||
async fn wait_for_exit(&mut self) -> Result<std::process::ExitStatus, String> {
|
||||
self.child
|
||||
.wait()
|
||||
.await
|
||||
.map_err(|err| format!("failed to wait for {PUBLIC_TOOL_NAME} runner: {err}"))
|
||||
}
|
||||
|
||||
async fn join_stderr_task(&mut self) {
|
||||
let Some(stderr_task) = self.stderr_task.take() else {
|
||||
return;
|
||||
};
|
||||
if let Err(err) = stderr_task.await {
|
||||
warn!("failed to join {PUBLIC_TOOL_NAME} stderr task: {err}");
|
||||
}
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,26 +184,62 @@ impl CodeModeService {
|
||||
|
||||
async fn ensure_started(
|
||||
&self,
|
||||
) -> Result<tokio::sync::OwnedMutexGuard<Option<CodeModeProcess>>, String> {
|
||||
) -> Result<tokio::sync::OwnedMutexGuard<Option<CodeModeProcess>>, std::io::Error> {
|
||||
let mut process_slot = self.process.lock().await;
|
||||
let needs_spawn = match process_slot.as_mut() {
|
||||
Some(process) => !matches!(process.has_exited(), Ok(false)),
|
||||
None => true,
|
||||
};
|
||||
if needs_spawn {
|
||||
let node_path = resolve_compatible_node(self.js_repl_node_path.as_deref()).await?;
|
||||
let node_path = resolve_compatible_node(self.js_repl_node_path.as_deref())
|
||||
.await
|
||||
.map_err(std::io::Error::other)?;
|
||||
*process_slot = Some(spawn_code_mode_process(&node_path).await?);
|
||||
}
|
||||
drop(process_slot);
|
||||
Ok(self.process.clone().lock_owned().await)
|
||||
}
|
||||
|
||||
pub(crate) async fn start_turn_worker(
|
||||
&self,
|
||||
session: &Arc<Session>,
|
||||
turn: &Arc<TurnContext>,
|
||||
tracker: &SharedTurnDiffTracker,
|
||||
) -> Option<CodeModeWorker> {
|
||||
if !turn.features.enabled(Feature::CodeMode) {
|
||||
return None;
|
||||
}
|
||||
let exec = ExecContext {
|
||||
session: Arc::clone(session),
|
||||
turn: Arc::clone(turn),
|
||||
tracker: Arc::clone(tracker),
|
||||
};
|
||||
let mut process_slot = match self.ensure_started().await {
|
||||
Ok(process_slot) => process_slot,
|
||||
Err(err) => {
|
||||
warn!("failed to start {PUBLIC_TOOL_NAME} worker for turn: {err}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
let Some(process) = process_slot.as_mut() else {
|
||||
warn!(
|
||||
"failed to start {PUBLIC_TOOL_NAME} worker for turn: {PUBLIC_TOOL_NAME} runner failed to start"
|
||||
);
|
||||
return None;
|
||||
};
|
||||
Some(process.worker(exec))
|
||||
}
|
||||
|
||||
pub(crate) async fn allocate_session_id(&self) -> i32 {
|
||||
let mut next_session_id = self.next_session_id.lock().await;
|
||||
let session_id = *next_session_id;
|
||||
*next_session_id = next_session_id.saturating_add(1);
|
||||
session_id
|
||||
}
|
||||
|
||||
pub(crate) async fn allocate_request_id(&self) -> String {
|
||||
uuid::Uuid::new_v4().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
|
||||
@@ -197,20 +264,23 @@ struct EnabledTool {
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum HostToNodeMessage {
|
||||
Start {
|
||||
request_id: String,
|
||||
session_id: i32,
|
||||
enabled_tools: Vec<EnabledTool>,
|
||||
stored_values: HashMap<String, JsonValue>,
|
||||
source: String,
|
||||
},
|
||||
Poll {
|
||||
request_id: String,
|
||||
session_id: i32,
|
||||
yield_time_ms: u64,
|
||||
},
|
||||
Terminate {
|
||||
request_id: String,
|
||||
session_id: i32,
|
||||
},
|
||||
Response {
|
||||
session_id: i32,
|
||||
request_id: String,
|
||||
id: String,
|
||||
code_mode_result: JsonValue,
|
||||
},
|
||||
@@ -220,22 +290,19 @@ enum HostToNodeMessage {
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum NodeToHostMessage {
|
||||
ToolCall {
|
||||
session_id: i32,
|
||||
id: String,
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
input: Option<JsonValue>,
|
||||
#[serde(flatten)]
|
||||
tool_call: CodeModeToolCall,
|
||||
},
|
||||
Yielded {
|
||||
session_id: i32,
|
||||
request_id: String,
|
||||
content_items: Vec<JsonValue>,
|
||||
},
|
||||
Terminated {
|
||||
session_id: i32,
|
||||
request_id: String,
|
||||
content_items: Vec<JsonValue>,
|
||||
},
|
||||
Result {
|
||||
session_id: i32,
|
||||
request_id: String,
|
||||
content_items: Vec<JsonValue>,
|
||||
stored_values: HashMap<String, JsonValue>,
|
||||
#[serde(default)]
|
||||
@@ -307,10 +374,19 @@ pub(crate) async fn execute(
|
||||
let stored_values = service.stored_values().await;
|
||||
let source = build_source(&code, &enabled_tools).map_err(FunctionCallError::RespondToModel)?;
|
||||
let session_id = service.allocate_session_id().await;
|
||||
let request_id = service.allocate_request_id().await;
|
||||
let process_slot = service
|
||||
.ensure_started()
|
||||
.await
|
||||
.map_err(FunctionCallError::RespondToModel)?;
|
||||
.map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?;
|
||||
let started_at = std::time::Instant::now();
|
||||
let message = HostToNodeMessage::Start {
|
||||
request_id: request_id.clone(),
|
||||
session_id,
|
||||
enabled_tools,
|
||||
stored_values,
|
||||
source,
|
||||
};
|
||||
let result = {
|
||||
let mut process_slot = process_slot;
|
||||
let Some(process) = process_slot.as_mut() else {
|
||||
@@ -318,19 +394,15 @@ pub(crate) async fn execute(
|
||||
"{PUBLIC_TOOL_NAME} runner failed to start"
|
||||
)));
|
||||
};
|
||||
drive_code_mode_session(
|
||||
&exec,
|
||||
process,
|
||||
HostToNodeMessage::Start {
|
||||
session_id,
|
||||
enabled_tools,
|
||||
stored_values,
|
||||
source,
|
||||
},
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
let message = process
|
||||
.send(&request_id, &message)
|
||||
.await
|
||||
.map_err(|err| err.to_string());
|
||||
let message = match message {
|
||||
Ok(message) => message,
|
||||
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
|
||||
};
|
||||
handle_node_message(&exec, session_id, message, None, started_at).await
|
||||
};
|
||||
match result {
|
||||
Ok(CodeModeSessionProgress::Finished(output))
|
||||
@@ -353,13 +425,32 @@ pub(crate) async fn wait(
|
||||
turn,
|
||||
tracker,
|
||||
};
|
||||
let request_id = exec
|
||||
.session
|
||||
.services
|
||||
.code_mode_service
|
||||
.allocate_request_id()
|
||||
.await;
|
||||
let started_at = std::time::Instant::now();
|
||||
let message = if terminate {
|
||||
HostToNodeMessage::Terminate {
|
||||
request_id: request_id.clone(),
|
||||
session_id,
|
||||
}
|
||||
} else {
|
||||
HostToNodeMessage::Poll {
|
||||
request_id: request_id.clone(),
|
||||
session_id,
|
||||
yield_time_ms,
|
||||
}
|
||||
};
|
||||
let process_slot = exec
|
||||
.session
|
||||
.services
|
||||
.code_mode_service
|
||||
.ensure_started()
|
||||
.await
|
||||
.map_err(FunctionCallError::RespondToModel)?;
|
||||
.map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?;
|
||||
let result = {
|
||||
let mut process_slot = process_slot;
|
||||
let Some(process) = process_slot.as_mut() else {
|
||||
@@ -372,19 +463,20 @@ pub(crate) async fn wait(
|
||||
"{PUBLIC_TOOL_NAME} runner failed to start"
|
||||
)));
|
||||
}
|
||||
drive_code_mode_session(
|
||||
let message = process
|
||||
.send(&request_id, &message)
|
||||
.await
|
||||
.map_err(|err| err.to_string());
|
||||
let message = match message {
|
||||
Ok(message) => message,
|
||||
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
|
||||
};
|
||||
handle_node_message(
|
||||
&exec,
|
||||
process,
|
||||
if terminate {
|
||||
HostToNodeMessage::Terminate { session_id }
|
||||
} else {
|
||||
HostToNodeMessage::Poll {
|
||||
session_id,
|
||||
yield_time_ms,
|
||||
}
|
||||
},
|
||||
session_id,
|
||||
message,
|
||||
Some(max_output_tokens),
|
||||
terminate,
|
||||
started_at,
|
||||
)
|
||||
.await
|
||||
};
|
||||
@@ -395,131 +487,18 @@ pub(crate) async fn wait(
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_code_mode_process(node_path: &std::path::Path) -> Result<CodeModeProcess, String> {
|
||||
let mut cmd = tokio::process::Command::new(node_path);
|
||||
cmd.arg("--experimental-vm-modules");
|
||||
cmd.arg("--eval");
|
||||
cmd.arg(CODE_MODE_RUNNER_SOURCE);
|
||||
cmd.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.kill_on_drop(true);
|
||||
|
||||
let mut child = cmd
|
||||
.spawn()
|
||||
.map_err(|err| format!("failed to start {PUBLIC_TOOL_NAME} Node runtime: {err}"))?;
|
||||
let stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stdout"))?;
|
||||
let stderr = child
|
||||
.stderr
|
||||
.take()
|
||||
.ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stderr"))?;
|
||||
let stdin = child
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stdin"))?;
|
||||
|
||||
let stderr_task = tokio::spawn(async move {
|
||||
let mut reader = BufReader::new(stderr);
|
||||
let mut buf = Vec::new();
|
||||
match reader.read_to_end(&mut buf).await {
|
||||
Ok(_) => {
|
||||
let stderr = String::from_utf8_lossy(&buf).trim().to_string();
|
||||
if !stderr.is_empty() {
|
||||
warn!("{PUBLIC_TOOL_NAME} runner stderr: {stderr}");
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("failed to read {PUBLIC_TOOL_NAME} stderr: {err}");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(CodeModeProcess {
|
||||
child,
|
||||
stdin,
|
||||
stdout_lines: BufReader::new(stdout).lines(),
|
||||
stderr_task: Some(stderr_task),
|
||||
pending_messages: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn drive_code_mode_session(
|
||||
exec: &ExecContext,
|
||||
process: &mut CodeModeProcess,
|
||||
message: HostToNodeMessage,
|
||||
poll_max_output_tokens: Option<Option<usize>>,
|
||||
is_terminate: bool,
|
||||
) -> Result<CodeModeSessionProgress, String> {
|
||||
let started_at = std::time::Instant::now();
|
||||
let session_id = match &message {
|
||||
HostToNodeMessage::Start { session_id, .. }
|
||||
| HostToNodeMessage::Poll { session_id, .. }
|
||||
| HostToNodeMessage::Terminate { session_id }
|
||||
| HostToNodeMessage::Response { session_id, .. } => *session_id,
|
||||
};
|
||||
process
|
||||
.write(&message)
|
||||
.await
|
||||
.map_err(|err| err.to_string())?;
|
||||
|
||||
loop {
|
||||
let message = process
|
||||
.read(session_id)
|
||||
.await
|
||||
.map_err(|err| err.to_string())?;
|
||||
if let Some(progress) = handle_node_message(
|
||||
exec,
|
||||
process,
|
||||
session_id,
|
||||
message,
|
||||
poll_max_output_tokens,
|
||||
started_at,
|
||||
is_terminate,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
return Ok(progress);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_node_message(
|
||||
exec: &ExecContext,
|
||||
process: &mut CodeModeProcess,
|
||||
session_id: i32,
|
||||
message: NodeToHostMessage,
|
||||
poll_max_output_tokens: Option<Option<usize>>,
|
||||
started_at: std::time::Instant,
|
||||
is_terminate: bool,
|
||||
) -> Result<Option<CodeModeSessionProgress>, String> {
|
||||
) -> Result<CodeModeSessionProgress, String> {
|
||||
match message {
|
||||
NodeToHostMessage::ToolCall {
|
||||
session_id: message_session_id,
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
} => {
|
||||
if is_terminate {
|
||||
return Ok(None);
|
||||
}
|
||||
let response = HostToNodeMessage::Response {
|
||||
session_id: message_session_id,
|
||||
id,
|
||||
code_mode_result: call_nested_tool(exec.clone(), name, input).await,
|
||||
};
|
||||
process
|
||||
.write(&response)
|
||||
.await
|
||||
.map_err(|err| err.to_string())?;
|
||||
Ok(None)
|
||||
}
|
||||
NodeToHostMessage::ToolCall { .. } => Err(format!(
|
||||
"{PUBLIC_TOOL_NAME} received an unexpected tool call response"
|
||||
)),
|
||||
NodeToHostMessage::Yielded { content_items, .. } => {
|
||||
if is_terminate {
|
||||
return Ok(None);
|
||||
}
|
||||
let mut delta_items = output_content_items_from_json_values(content_items)?;
|
||||
delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten());
|
||||
prepend_script_status(
|
||||
@@ -527,9 +506,9 @@ async fn handle_node_message(
|
||||
CodeModeExecutionStatus::Running(session_id),
|
||||
started_at.elapsed(),
|
||||
);
|
||||
Ok(Some(CodeModeSessionProgress::Yielded {
|
||||
Ok(CodeModeSessionProgress::Yielded {
|
||||
output: FunctionToolOutput::from_content(delta_items, Some(true)),
|
||||
}))
|
||||
})
|
||||
}
|
||||
NodeToHostMessage::Terminated { content_items, .. } => {
|
||||
let mut delta_items = output_content_items_from_json_values(content_items)?;
|
||||
@@ -539,9 +518,9 @@ async fn handle_node_message(
|
||||
CodeModeExecutionStatus::Terminated,
|
||||
started_at.elapsed(),
|
||||
);
|
||||
Ok(Some(CodeModeSessionProgress::Finished(
|
||||
Ok(CodeModeSessionProgress::Finished(
|
||||
FunctionToolOutput::from_content(delta_items, Some(true)),
|
||||
)))
|
||||
))
|
||||
}
|
||||
NodeToHostMessage::Result {
|
||||
content_items,
|
||||
@@ -576,19 +555,127 @@ async fn handle_node_message(
|
||||
},
|
||||
started_at.elapsed(),
|
||||
);
|
||||
Ok(Some(CodeModeSessionProgress::Finished(
|
||||
Ok(CodeModeSessionProgress::Finished(
|
||||
FunctionToolOutput::from_content(delta_items, Some(success)),
|
||||
)))
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn message_session_id(message: &NodeToHostMessage) -> i32 {
|
||||
async fn spawn_code_mode_process(
|
||||
node_path: &std::path::Path,
|
||||
) -> Result<CodeModeProcess, std::io::Error> {
|
||||
let mut cmd = tokio::process::Command::new(node_path);
|
||||
cmd.arg("--experimental-vm-modules");
|
||||
cmd.arg("--eval");
|
||||
cmd.arg(CODE_MODE_RUNNER_SOURCE);
|
||||
cmd.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.kill_on_drop(true);
|
||||
|
||||
let mut child = cmd.spawn().map_err(std::io::Error::other)?;
|
||||
let stdout = child.stdout.take().ok_or_else(|| {
|
||||
std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stdout"))
|
||||
})?;
|
||||
let stderr = child.stderr.take().ok_or_else(|| {
|
||||
std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stderr"))
|
||||
})?;
|
||||
let stdin = child
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stdin")))?;
|
||||
let stdin = Arc::new(Mutex::new(stdin));
|
||||
let response_waiters = Arc::new(Mutex::new(HashMap::<
|
||||
String,
|
||||
oneshot::Sender<NodeToHostMessage>,
|
||||
>::new()));
|
||||
let (tool_call_tx, tool_call_rx) = mpsc::unbounded_channel();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut reader = BufReader::new(stderr);
|
||||
let mut buf = Vec::new();
|
||||
match reader.read_to_end(&mut buf).await {
|
||||
Ok(_) => {
|
||||
let stderr = String::from_utf8_lossy(&buf).trim().to_string();
|
||||
if !stderr.is_empty() {
|
||||
warn!("{PUBLIC_TOOL_NAME} runner stderr: {stderr}");
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("failed to read {PUBLIC_TOOL_NAME} stderr: {err}");
|
||||
}
|
||||
}
|
||||
});
|
||||
let stdout_task = tokio::spawn({
|
||||
let response_waiters = Arc::clone(&response_waiters);
|
||||
let tool_call_tx = tool_call_tx.clone();
|
||||
async move {
|
||||
let mut stdout_lines = BufReader::new(stdout).lines();
|
||||
loop {
|
||||
let line = match stdout_lines.next_line().await {
|
||||
Ok(line) => line,
|
||||
Err(err) => {
|
||||
warn!("failed to read {PUBLIC_TOOL_NAME} stdout: {err}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
let Some(line) = line else {
|
||||
break;
|
||||
};
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let message: NodeToHostMessage = match serde_json::from_str(&line) {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
warn!("failed to parse {PUBLIC_TOOL_NAME} stdout message: {err}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
match message {
|
||||
NodeToHostMessage::ToolCall { tool_call } => {
|
||||
let _ = tool_call_tx.send(tool_call);
|
||||
}
|
||||
message => {
|
||||
let request_id = message_request_id(&message).to_string();
|
||||
if let Some(waiter) = response_waiters.lock().await.remove(&request_id) {
|
||||
let _ = waiter.send(message);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response_waiters.lock().await.clear();
|
||||
}
|
||||
});
|
||||
|
||||
Ok(CodeModeProcess {
|
||||
child,
|
||||
stdin,
|
||||
stdout_task,
|
||||
response_waiters,
|
||||
tool_call_rx: Arc::new(Mutex::new(tool_call_rx)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn write_message(
|
||||
stdin: &Arc<Mutex<tokio::process::ChildStdin>>,
|
||||
message: &HostToNodeMessage,
|
||||
) -> Result<(), std::io::Error> {
|
||||
let line = serde_json::to_string(message).map_err(std::io::Error::other)?;
|
||||
let mut stdin = stdin.lock().await;
|
||||
stdin.write_all(line.as_bytes()).await?;
|
||||
stdin.write_all(b"\n").await?;
|
||||
stdin.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn message_request_id(message: &NodeToHostMessage) -> &str {
|
||||
match message {
|
||||
NodeToHostMessage::ToolCall { session_id, .. }
|
||||
| NodeToHostMessage::Yielded { session_id, .. }
|
||||
| NodeToHostMessage::Terminated { session_id, .. }
|
||||
| NodeToHostMessage::Result { session_id, .. } => *session_id,
|
||||
NodeToHostMessage::ToolCall { tool_call } => &tool_call.request_id,
|
||||
NodeToHostMessage::Yielded { request_id, .. }
|
||||
| NodeToHostMessage::Terminated { request_id, .. }
|
||||
| NodeToHostMessage::Result { request_id, .. } => request_id,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -466,11 +466,18 @@ function createProtocol() {
|
||||
if (message.type === 'poll') {
|
||||
const session = sessions.get(message.session_id);
|
||||
if (session) {
|
||||
schedulePollYield(protocol, session, normalizeYieldTime(message.yield_time_ms ?? 0));
|
||||
session.request_id = String(message.request_id);
|
||||
if (session.pending_result) {
|
||||
void completeSession(protocol, sessions, session, session.pending_result);
|
||||
} else if (session.pending_tool_call) {
|
||||
void forwardToolCall(protocol, session, session.pending_tool_call);
|
||||
} else {
|
||||
schedulePollYield(protocol, session, normalizeYieldTime(message.yield_time_ms ?? 0));
|
||||
}
|
||||
} else {
|
||||
void protocol.send({
|
||||
type: 'result',
|
||||
session_id: message.session_id,
|
||||
request_id: message.request_id,
|
||||
content_items: [],
|
||||
stored_values: {},
|
||||
error_text: `exec session ${message.session_id} not found`,
|
||||
@@ -483,11 +490,12 @@ function createProtocol() {
|
||||
if (message.type === 'terminate') {
|
||||
const session = sessions.get(message.session_id);
|
||||
if (session) {
|
||||
session.request_id = String(message.request_id);
|
||||
void terminateSession(protocol, sessions, session);
|
||||
} else {
|
||||
void protocol.send({
|
||||
type: 'result',
|
||||
session_id: message.session_id,
|
||||
request_id: message.request_id,
|
||||
content_items: [],
|
||||
stored_values: {},
|
||||
error_text: `exec session ${message.session_id} not found`,
|
||||
@@ -498,11 +506,11 @@ function createProtocol() {
|
||||
}
|
||||
|
||||
if (message.type === 'response') {
|
||||
const entry = pending.get(message.session_id + ':' + message.id);
|
||||
const entry = pending.get(message.request_id + ':' + message.id);
|
||||
if (!entry) {
|
||||
return;
|
||||
}
|
||||
pending.delete(message.session_id + ':' + message.id);
|
||||
pending.delete(message.request_id + ':' + message.id);
|
||||
entry.resolve(message.code_mode_result ?? '');
|
||||
return;
|
||||
}
|
||||
@@ -537,12 +545,12 @@ function createProtocol() {
|
||||
});
|
||||
}
|
||||
|
||||
function request(sessionId, type, payload) {
|
||||
function request(requestId, type, payload) {
|
||||
const id = 'msg-' + ++nextId;
|
||||
const pendingKey = sessionId + ':' + id;
|
||||
const pendingKey = requestId + ':' + id;
|
||||
return new Promise((resolve, reject) => {
|
||||
pending.set(pendingKey, { resolve, reject });
|
||||
void send({ type, session_id: sessionId, id, ...payload }).catch((error) => {
|
||||
void send({ type, request_id: requestId, id, ...payload }).catch((error) => {
|
||||
pending.delete(pendingKey);
|
||||
reject(error);
|
||||
});
|
||||
@@ -565,7 +573,10 @@ function startSession(protocol, sessions, start) {
|
||||
initial_yield_timer: null,
|
||||
initial_yield_triggered: false,
|
||||
max_output_tokens_per_exec_call: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL,
|
||||
pending_result: null,
|
||||
pending_tool_call: null,
|
||||
poll_yield_timer: null,
|
||||
request_id: String(start.request_id),
|
||||
worker: new Worker(sessionWorkerSource(), {
|
||||
eval: true,
|
||||
workerData: start,
|
||||
@@ -621,17 +632,28 @@ async function handleWorkerMessage(protocol, sessions, session, message) {
|
||||
}
|
||||
|
||||
if (message.type === 'tool_call') {
|
||||
if (session.request_id === null) {
|
||||
session.pending_tool_call = message;
|
||||
return;
|
||||
}
|
||||
void forwardToolCall(protocol, session, message);
|
||||
return;
|
||||
}
|
||||
|
||||
if (message.type === 'result') {
|
||||
await completeSession(protocol, sessions, session, {
|
||||
const result = {
|
||||
type: 'result',
|
||||
stored_values: cloneJsonValue(message.stored_values ?? {}),
|
||||
error_text:
|
||||
typeof message.error_text === 'string' ? message.error_text : undefined,
|
||||
});
|
||||
};
|
||||
if (session.request_id === null) {
|
||||
session.pending_result = result;
|
||||
session.initial_yield_timer = clearTimer(session.initial_yield_timer);
|
||||
session.poll_yield_timer = clearTimer(session.poll_yield_timer);
|
||||
return;
|
||||
}
|
||||
await completeSession(protocol, sessions, session, result);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -640,10 +662,11 @@ async function handleWorkerMessage(protocol, sessions, session, message) {
|
||||
|
||||
async function forwardToolCall(protocol, session, message) {
|
||||
try {
|
||||
const result = await protocol.request(session.id, 'tool_call', {
|
||||
const result = await protocol.request(session.request_id, 'tool_call', {
|
||||
name: String(message.name),
|
||||
input: message.input,
|
||||
});
|
||||
session.pending_tool_call = null;
|
||||
if (session.completed) {
|
||||
return;
|
||||
}
|
||||
@@ -655,6 +678,7 @@ async function forwardToolCall(protocol, session, message) {
|
||||
});
|
||||
} catch {}
|
||||
} catch (error) {
|
||||
session.pending_tool_call = null;
|
||||
if (session.completed) {
|
||||
return;
|
||||
}
|
||||
@@ -673,14 +697,16 @@ async function sendYielded(protocol, session) {
|
||||
return;
|
||||
}
|
||||
const contentItems = takeContentItems(session);
|
||||
const requestId = session.request_id;
|
||||
try {
|
||||
session.worker.postMessage({ type: 'clear_content' });
|
||||
} catch {}
|
||||
await protocol.send({
|
||||
type: 'yielded',
|
||||
session_id: session.id,
|
||||
request_id: requestId,
|
||||
content_items: contentItems,
|
||||
});
|
||||
session.request_id = null;
|
||||
}
|
||||
|
||||
function scheduleInitialYield(protocol, session, yieldTime) {
|
||||
@@ -711,17 +737,26 @@ async function completeSession(protocol, sessions, session, message) {
|
||||
if (session.completed) {
|
||||
return;
|
||||
}
|
||||
if (session.request_id === null) {
|
||||
session.pending_result = message;
|
||||
session.initial_yield_timer = clearTimer(session.initial_yield_timer);
|
||||
session.poll_yield_timer = clearTimer(session.poll_yield_timer);
|
||||
return;
|
||||
}
|
||||
const requestId = session.request_id;
|
||||
session.completed = true;
|
||||
session.initial_yield_timer = clearTimer(session.initial_yield_timer);
|
||||
session.poll_yield_timer = clearTimer(session.poll_yield_timer);
|
||||
sessions.delete(session.id);
|
||||
const contentItems = takeContentItems(session);
|
||||
session.pending_result = null;
|
||||
session.pending_tool_call = null;
|
||||
try {
|
||||
session.worker.postMessage({ type: 'clear_content' });
|
||||
} catch {}
|
||||
await protocol.send({
|
||||
...message,
|
||||
session_id: session.id,
|
||||
request_id: requestId,
|
||||
content_items: contentItems,
|
||||
max_output_tokens_per_exec_call: session.max_output_tokens_per_exec_call,
|
||||
});
|
||||
@@ -741,7 +776,7 @@ async function terminateSession(protocol, sessions, session) {
|
||||
} catch {}
|
||||
await protocol.send({
|
||||
type: 'terminated',
|
||||
session_id: session.id,
|
||||
request_id: session.request_id,
|
||||
content_items: contentItems,
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user