Refactor code mode worker dispatch

This commit is contained in:
pakrym-oai
2026-03-11 23:14:52 -07:00
parent 2ddde61431
commit a0072cf521
3 changed files with 369 additions and 242 deletions

View File

@@ -5464,6 +5464,11 @@ pub(crate) async fn run_turn(
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
// many turns, from the perspective of the user, it is a single turn.
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let _code_mode_worker = sess
.services
.code_mode_service
.start_turn_worker(&sess, &turn_context, &turn_diff_tracker)
.await;
let mut server_model_warning_emitted_for_turn = false;
// `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse

View File

@@ -1,5 +1,4 @@
use std::collections::HashMap;
use std::collections::VecDeque;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
@@ -32,6 +31,8 @@ use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tracing::warn;
@@ -50,79 +51,109 @@ struct ExecContext {
pub(crate) struct CodeModeProcess {
child: tokio::process::Child,
stdin: tokio::process::ChildStdin,
stdout_lines: tokio::io::Lines<BufReader<tokio::process::ChildStdout>>,
stderr_task: Option<JoinHandle<()>>,
pending_messages: HashMap<i32, VecDeque<NodeToHostMessage>>,
stdin: Arc<Mutex<tokio::process::ChildStdin>>,
stdout_task: JoinHandle<()>,
response_waiters: Arc<Mutex<HashMap<String, oneshot::Sender<NodeToHostMessage>>>>,
tool_call_rx: Arc<Mutex<mpsc::UnboundedReceiver<CodeModeToolCall>>>,
}
pub(crate) struct CodeModeWorker {
shutdown_tx: Option<oneshot::Sender<()>>,
task: JoinHandle<()>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
struct CodeModeToolCall {
request_id: String,
id: String,
name: String,
#[serde(default)]
input: Option<JsonValue>,
}
impl Drop for CodeModeWorker {
fn drop(&mut self) {
if let Some(shutdown_tx) = self.shutdown_tx.take() {
let _ = shutdown_tx.send(());
}
}
}
impl CodeModeProcess {
async fn write(&mut self, message: &HostToNodeMessage) -> Result<(), std::io::Error> {
let line = serde_json::to_string(message).map_err(std::io::Error::other)?;
self.stdin.write_all(line.as_bytes()).await?;
self.stdin.write_all(b"\n").await?;
self.stdin.flush().await?;
Ok(())
}
async fn read(&mut self, session_id: i32) -> Result<NodeToHostMessage, std::io::Error> {
if let Some(message) = self
.pending_messages
.get_mut(&session_id)
.and_then(VecDeque::pop_front)
{
return Ok(message);
}
loop {
let Some(line) = self.stdout_lines.next_line().await? else {
match self.wait_for_exit().await {
Ok(status) => {
self.join_stderr_task().await;
return Err(std::io::Error::other(format!(
"{PUBLIC_TOOL_NAME} runner exited without returning a result (status {status})"
)));
}
Err(err) => return Err(std::io::Error::other(err)),
fn worker(&self, exec: ExecContext) -> CodeModeWorker {
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
let stdin = Arc::clone(&self.stdin);
let tool_call_rx = Arc::clone(&self.tool_call_rx);
let task = tokio::spawn(async move {
loop {
let tool_call = tokio::select! {
_ = &mut shutdown_rx => break,
tool_call = async {
let mut tool_call_rx = tool_call_rx.lock().await;
tool_call_rx.recv().await
} => tool_call,
};
let Some(tool_call) = tool_call else {
break;
};
let response = HostToNodeMessage::Response {
request_id: tool_call.request_id,
id: tool_call.id,
code_mode_result: call_nested_tool(
exec.clone(),
tool_call.name,
tool_call.input,
)
.await,
};
if let Err(err) = write_message(&stdin, &response).await {
warn!("failed to write {PUBLIC_TOOL_NAME} tool response: {err}");
break;
}
};
if line.trim().is_empty() {
continue;
}
let message: NodeToHostMessage =
serde_json::from_str(&line).map_err(std::io::Error::other)?;
let message_session_id = message_session_id(&message);
if message_session_id == session_id {
return Ok(message);
}
self.pending_messages
.entry(message_session_id)
.or_default()
.push_back(message);
});
CodeModeWorker {
shutdown_tx: Some(shutdown_tx),
task,
}
}
fn has_exited(&mut self) -> Result<bool, String> {
async fn send(
&mut self,
request_id: &str,
message: &HostToNodeMessage,
) -> Result<NodeToHostMessage, std::io::Error> {
if self.stdout_task.is_finished() {
return Err(std::io::Error::other(format!(
"{PUBLIC_TOOL_NAME} runner is not available"
)));
}
let (tx, rx) = oneshot::channel();
self.response_waiters
.lock()
.await
.insert(request_id.to_string(), tx);
if let Err(err) = write_message(&self.stdin, message).await {
self.response_waiters.lock().await.remove(request_id);
return Err(err);
}
match rx.await {
Ok(message) => Ok(message),
Err(_) => Err(std::io::Error::other(format!(
"{PUBLIC_TOOL_NAME} runner is not available"
))),
}
}
fn has_exited(&mut self) -> Result<bool, std::io::Error> {
self.child
.try_wait()
.map(|status| status.is_some())
.map_err(|err| format!("failed to inspect {PUBLIC_TOOL_NAME} runner: {err}"))
}
async fn wait_for_exit(&mut self) -> Result<std::process::ExitStatus, String> {
self.child
.wait()
.await
.map_err(|err| format!("failed to wait for {PUBLIC_TOOL_NAME} runner: {err}"))
}
async fn join_stderr_task(&mut self) {
let Some(stderr_task) = self.stderr_task.take() else {
return;
};
if let Err(err) = stderr_task.await {
warn!("failed to join {PUBLIC_TOOL_NAME} stderr task: {err}");
}
.map_err(std::io::Error::other)
}
}
@@ -153,26 +184,62 @@ impl CodeModeService {
async fn ensure_started(
&self,
) -> Result<tokio::sync::OwnedMutexGuard<Option<CodeModeProcess>>, String> {
) -> Result<tokio::sync::OwnedMutexGuard<Option<CodeModeProcess>>, std::io::Error> {
let mut process_slot = self.process.lock().await;
let needs_spawn = match process_slot.as_mut() {
Some(process) => !matches!(process.has_exited(), Ok(false)),
None => true,
};
if needs_spawn {
let node_path = resolve_compatible_node(self.js_repl_node_path.as_deref()).await?;
let node_path = resolve_compatible_node(self.js_repl_node_path.as_deref())
.await
.map_err(std::io::Error::other)?;
*process_slot = Some(spawn_code_mode_process(&node_path).await?);
}
drop(process_slot);
Ok(self.process.clone().lock_owned().await)
}
pub(crate) async fn start_turn_worker(
&self,
session: &Arc<Session>,
turn: &Arc<TurnContext>,
tracker: &SharedTurnDiffTracker,
) -> Option<CodeModeWorker> {
if !turn.features.enabled(Feature::CodeMode) {
return None;
}
let exec = ExecContext {
session: Arc::clone(session),
turn: Arc::clone(turn),
tracker: Arc::clone(tracker),
};
let mut process_slot = match self.ensure_started().await {
Ok(process_slot) => process_slot,
Err(err) => {
warn!("failed to start {PUBLIC_TOOL_NAME} worker for turn: {err}");
return None;
}
};
let Some(process) = process_slot.as_mut() else {
warn!(
"failed to start {PUBLIC_TOOL_NAME} worker for turn: {PUBLIC_TOOL_NAME} runner failed to start"
);
return None;
};
Some(process.worker(exec))
}
pub(crate) async fn allocate_session_id(&self) -> i32 {
let mut next_session_id = self.next_session_id.lock().await;
let session_id = *next_session_id;
*next_session_id = next_session_id.saturating_add(1);
session_id
}
pub(crate) async fn allocate_request_id(&self) -> String {
uuid::Uuid::new_v4().to_string()
}
}
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
@@ -197,20 +264,23 @@ struct EnabledTool {
#[serde(tag = "type", rename_all = "snake_case")]
enum HostToNodeMessage {
Start {
request_id: String,
session_id: i32,
enabled_tools: Vec<EnabledTool>,
stored_values: HashMap<String, JsonValue>,
source: String,
},
Poll {
request_id: String,
session_id: i32,
yield_time_ms: u64,
},
Terminate {
request_id: String,
session_id: i32,
},
Response {
session_id: i32,
request_id: String,
id: String,
code_mode_result: JsonValue,
},
@@ -220,22 +290,19 @@ enum HostToNodeMessage {
#[serde(tag = "type", rename_all = "snake_case")]
enum NodeToHostMessage {
ToolCall {
session_id: i32,
id: String,
name: String,
#[serde(default)]
input: Option<JsonValue>,
#[serde(flatten)]
tool_call: CodeModeToolCall,
},
Yielded {
session_id: i32,
request_id: String,
content_items: Vec<JsonValue>,
},
Terminated {
session_id: i32,
request_id: String,
content_items: Vec<JsonValue>,
},
Result {
session_id: i32,
request_id: String,
content_items: Vec<JsonValue>,
stored_values: HashMap<String, JsonValue>,
#[serde(default)]
@@ -307,10 +374,19 @@ pub(crate) async fn execute(
let stored_values = service.stored_values().await;
let source = build_source(&code, &enabled_tools).map_err(FunctionCallError::RespondToModel)?;
let session_id = service.allocate_session_id().await;
let request_id = service.allocate_request_id().await;
let process_slot = service
.ensure_started()
.await
.map_err(FunctionCallError::RespondToModel)?;
.map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?;
let started_at = std::time::Instant::now();
let message = HostToNodeMessage::Start {
request_id: request_id.clone(),
session_id,
enabled_tools,
stored_values,
source,
};
let result = {
let mut process_slot = process_slot;
let Some(process) = process_slot.as_mut() else {
@@ -318,19 +394,15 @@ pub(crate) async fn execute(
"{PUBLIC_TOOL_NAME} runner failed to start"
)));
};
drive_code_mode_session(
&exec,
process,
HostToNodeMessage::Start {
session_id,
enabled_tools,
stored_values,
source,
},
None,
false,
)
.await
let message = process
.send(&request_id, &message)
.await
.map_err(|err| err.to_string());
let message = match message {
Ok(message) => message,
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
};
handle_node_message(&exec, session_id, message, None, started_at).await
};
match result {
Ok(CodeModeSessionProgress::Finished(output))
@@ -353,13 +425,32 @@ pub(crate) async fn wait(
turn,
tracker,
};
let request_id = exec
.session
.services
.code_mode_service
.allocate_request_id()
.await;
let started_at = std::time::Instant::now();
let message = if terminate {
HostToNodeMessage::Terminate {
request_id: request_id.clone(),
session_id,
}
} else {
HostToNodeMessage::Poll {
request_id: request_id.clone(),
session_id,
yield_time_ms,
}
};
let process_slot = exec
.session
.services
.code_mode_service
.ensure_started()
.await
.map_err(FunctionCallError::RespondToModel)?;
.map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?;
let result = {
let mut process_slot = process_slot;
let Some(process) = process_slot.as_mut() else {
@@ -372,19 +463,20 @@ pub(crate) async fn wait(
"{PUBLIC_TOOL_NAME} runner failed to start"
)));
}
drive_code_mode_session(
let message = process
.send(&request_id, &message)
.await
.map_err(|err| err.to_string());
let message = match message {
Ok(message) => message,
Err(error) => return Err(FunctionCallError::RespondToModel(error)),
};
handle_node_message(
&exec,
process,
if terminate {
HostToNodeMessage::Terminate { session_id }
} else {
HostToNodeMessage::Poll {
session_id,
yield_time_ms,
}
},
session_id,
message,
Some(max_output_tokens),
terminate,
started_at,
)
.await
};
@@ -395,131 +487,18 @@ pub(crate) async fn wait(
}
}
async fn spawn_code_mode_process(node_path: &std::path::Path) -> Result<CodeModeProcess, String> {
let mut cmd = tokio::process::Command::new(node_path);
cmd.arg("--experimental-vm-modules");
cmd.arg("--eval");
cmd.arg(CODE_MODE_RUNNER_SOURCE);
cmd.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
let mut child = cmd
.spawn()
.map_err(|err| format!("failed to start {PUBLIC_TOOL_NAME} Node runtime: {err}"))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stdout"))?;
let stderr = child
.stderr
.take()
.ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stderr"))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stdin"))?;
let stderr_task = tokio::spawn(async move {
let mut reader = BufReader::new(stderr);
let mut buf = Vec::new();
match reader.read_to_end(&mut buf).await {
Ok(_) => {
let stderr = String::from_utf8_lossy(&buf).trim().to_string();
if !stderr.is_empty() {
warn!("{PUBLIC_TOOL_NAME} runner stderr: {stderr}");
}
}
Err(err) => {
warn!("failed to read {PUBLIC_TOOL_NAME} stderr: {err}");
}
}
});
Ok(CodeModeProcess {
child,
stdin,
stdout_lines: BufReader::new(stdout).lines(),
stderr_task: Some(stderr_task),
pending_messages: HashMap::new(),
})
}
async fn drive_code_mode_session(
exec: &ExecContext,
process: &mut CodeModeProcess,
message: HostToNodeMessage,
poll_max_output_tokens: Option<Option<usize>>,
is_terminate: bool,
) -> Result<CodeModeSessionProgress, String> {
let started_at = std::time::Instant::now();
let session_id = match &message {
HostToNodeMessage::Start { session_id, .. }
| HostToNodeMessage::Poll { session_id, .. }
| HostToNodeMessage::Terminate { session_id }
| HostToNodeMessage::Response { session_id, .. } => *session_id,
};
process
.write(&message)
.await
.map_err(|err| err.to_string())?;
loop {
let message = process
.read(session_id)
.await
.map_err(|err| err.to_string())?;
if let Some(progress) = handle_node_message(
exec,
process,
session_id,
message,
poll_max_output_tokens,
started_at,
is_terminate,
)
.await?
{
return Ok(progress);
}
}
}
async fn handle_node_message(
exec: &ExecContext,
process: &mut CodeModeProcess,
session_id: i32,
message: NodeToHostMessage,
poll_max_output_tokens: Option<Option<usize>>,
started_at: std::time::Instant,
is_terminate: bool,
) -> Result<Option<CodeModeSessionProgress>, String> {
) -> Result<CodeModeSessionProgress, String> {
match message {
NodeToHostMessage::ToolCall {
session_id: message_session_id,
id,
name,
input,
} => {
if is_terminate {
return Ok(None);
}
let response = HostToNodeMessage::Response {
session_id: message_session_id,
id,
code_mode_result: call_nested_tool(exec.clone(), name, input).await,
};
process
.write(&response)
.await
.map_err(|err| err.to_string())?;
Ok(None)
}
NodeToHostMessage::ToolCall { .. } => Err(format!(
"{PUBLIC_TOOL_NAME} received an unexpected tool call response"
)),
NodeToHostMessage::Yielded { content_items, .. } => {
if is_terminate {
return Ok(None);
}
let mut delta_items = output_content_items_from_json_values(content_items)?;
delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten());
prepend_script_status(
@@ -527,9 +506,9 @@ async fn handle_node_message(
CodeModeExecutionStatus::Running(session_id),
started_at.elapsed(),
);
Ok(Some(CodeModeSessionProgress::Yielded {
Ok(CodeModeSessionProgress::Yielded {
output: FunctionToolOutput::from_content(delta_items, Some(true)),
}))
})
}
NodeToHostMessage::Terminated { content_items, .. } => {
let mut delta_items = output_content_items_from_json_values(content_items)?;
@@ -539,9 +518,9 @@ async fn handle_node_message(
CodeModeExecutionStatus::Terminated,
started_at.elapsed(),
);
Ok(Some(CodeModeSessionProgress::Finished(
Ok(CodeModeSessionProgress::Finished(
FunctionToolOutput::from_content(delta_items, Some(true)),
)))
))
}
NodeToHostMessage::Result {
content_items,
@@ -576,19 +555,127 @@ async fn handle_node_message(
},
started_at.elapsed(),
);
Ok(Some(CodeModeSessionProgress::Finished(
Ok(CodeModeSessionProgress::Finished(
FunctionToolOutput::from_content(delta_items, Some(success)),
)))
))
}
}
}
fn message_session_id(message: &NodeToHostMessage) -> i32 {
async fn spawn_code_mode_process(
node_path: &std::path::Path,
) -> Result<CodeModeProcess, std::io::Error> {
let mut cmd = tokio::process::Command::new(node_path);
cmd.arg("--experimental-vm-modules");
cmd.arg("--eval");
cmd.arg(CODE_MODE_RUNNER_SOURCE);
cmd.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
let mut child = cmd.spawn().map_err(std::io::Error::other)?;
let stdout = child.stdout.take().ok_or_else(|| {
std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stdout"))
})?;
let stderr = child.stderr.take().ok_or_else(|| {
std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stderr"))
})?;
let stdin = child
.stdin
.take()
.ok_or_else(|| std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stdin")))?;
let stdin = Arc::new(Mutex::new(stdin));
let response_waiters = Arc::new(Mutex::new(HashMap::<
String,
oneshot::Sender<NodeToHostMessage>,
>::new()));
let (tool_call_tx, tool_call_rx) = mpsc::unbounded_channel();
tokio::spawn(async move {
let mut reader = BufReader::new(stderr);
let mut buf = Vec::new();
match reader.read_to_end(&mut buf).await {
Ok(_) => {
let stderr = String::from_utf8_lossy(&buf).trim().to_string();
if !stderr.is_empty() {
warn!("{PUBLIC_TOOL_NAME} runner stderr: {stderr}");
}
}
Err(err) => {
warn!("failed to read {PUBLIC_TOOL_NAME} stderr: {err}");
}
}
});
let stdout_task = tokio::spawn({
let response_waiters = Arc::clone(&response_waiters);
let tool_call_tx = tool_call_tx.clone();
async move {
let mut stdout_lines = BufReader::new(stdout).lines();
loop {
let line = match stdout_lines.next_line().await {
Ok(line) => line,
Err(err) => {
warn!("failed to read {PUBLIC_TOOL_NAME} stdout: {err}");
break;
}
};
let Some(line) = line else {
break;
};
if line.trim().is_empty() {
continue;
}
let message: NodeToHostMessage = match serde_json::from_str(&line) {
Ok(message) => message,
Err(err) => {
warn!("failed to parse {PUBLIC_TOOL_NAME} stdout message: {err}");
break;
}
};
match message {
NodeToHostMessage::ToolCall { tool_call } => {
let _ = tool_call_tx.send(tool_call);
}
message => {
let request_id = message_request_id(&message).to_string();
if let Some(waiter) = response_waiters.lock().await.remove(&request_id) {
let _ = waiter.send(message);
}
}
}
}
response_waiters.lock().await.clear();
}
});
Ok(CodeModeProcess {
child,
stdin,
stdout_task,
response_waiters,
tool_call_rx: Arc::new(Mutex::new(tool_call_rx)),
})
}
async fn write_message(
stdin: &Arc<Mutex<tokio::process::ChildStdin>>,
message: &HostToNodeMessage,
) -> Result<(), std::io::Error> {
let line = serde_json::to_string(message).map_err(std::io::Error::other)?;
let mut stdin = stdin.lock().await;
stdin.write_all(line.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;
Ok(())
}
fn message_request_id(message: &NodeToHostMessage) -> &str {
match message {
NodeToHostMessage::ToolCall { session_id, .. }
| NodeToHostMessage::Yielded { session_id, .. }
| NodeToHostMessage::Terminated { session_id, .. }
| NodeToHostMessage::Result { session_id, .. } => *session_id,
NodeToHostMessage::ToolCall { tool_call } => &tool_call.request_id,
NodeToHostMessage::Yielded { request_id, .. }
| NodeToHostMessage::Terminated { request_id, .. }
| NodeToHostMessage::Result { request_id, .. } => request_id,
}
}

View File

@@ -466,11 +466,18 @@ function createProtocol() {
if (message.type === 'poll') {
const session = sessions.get(message.session_id);
if (session) {
schedulePollYield(protocol, session, normalizeYieldTime(message.yield_time_ms ?? 0));
session.request_id = String(message.request_id);
if (session.pending_result) {
void completeSession(protocol, sessions, session, session.pending_result);
} else if (session.pending_tool_call) {
void forwardToolCall(protocol, session, session.pending_tool_call);
} else {
schedulePollYield(protocol, session, normalizeYieldTime(message.yield_time_ms ?? 0));
}
} else {
void protocol.send({
type: 'result',
session_id: message.session_id,
request_id: message.request_id,
content_items: [],
stored_values: {},
error_text: `exec session ${message.session_id} not found`,
@@ -483,11 +490,12 @@ function createProtocol() {
if (message.type === 'terminate') {
const session = sessions.get(message.session_id);
if (session) {
session.request_id = String(message.request_id);
void terminateSession(protocol, sessions, session);
} else {
void protocol.send({
type: 'result',
session_id: message.session_id,
request_id: message.request_id,
content_items: [],
stored_values: {},
error_text: `exec session ${message.session_id} not found`,
@@ -498,11 +506,11 @@ function createProtocol() {
}
if (message.type === 'response') {
const entry = pending.get(message.session_id + ':' + message.id);
const entry = pending.get(message.request_id + ':' + message.id);
if (!entry) {
return;
}
pending.delete(message.session_id + ':' + message.id);
pending.delete(message.request_id + ':' + message.id);
entry.resolve(message.code_mode_result ?? '');
return;
}
@@ -537,12 +545,12 @@ function createProtocol() {
});
}
function request(sessionId, type, payload) {
function request(requestId, type, payload) {
const id = 'msg-' + ++nextId;
const pendingKey = sessionId + ':' + id;
const pendingKey = requestId + ':' + id;
return new Promise((resolve, reject) => {
pending.set(pendingKey, { resolve, reject });
void send({ type, session_id: sessionId, id, ...payload }).catch((error) => {
void send({ type, request_id: requestId, id, ...payload }).catch((error) => {
pending.delete(pendingKey);
reject(error);
});
@@ -565,7 +573,10 @@ function startSession(protocol, sessions, start) {
initial_yield_timer: null,
initial_yield_triggered: false,
max_output_tokens_per_exec_call: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL,
pending_result: null,
pending_tool_call: null,
poll_yield_timer: null,
request_id: String(start.request_id),
worker: new Worker(sessionWorkerSource(), {
eval: true,
workerData: start,
@@ -621,17 +632,28 @@ async function handleWorkerMessage(protocol, sessions, session, message) {
}
if (message.type === 'tool_call') {
if (session.request_id === null) {
session.pending_tool_call = message;
return;
}
void forwardToolCall(protocol, session, message);
return;
}
if (message.type === 'result') {
await completeSession(protocol, sessions, session, {
const result = {
type: 'result',
stored_values: cloneJsonValue(message.stored_values ?? {}),
error_text:
typeof message.error_text === 'string' ? message.error_text : undefined,
});
};
if (session.request_id === null) {
session.pending_result = result;
session.initial_yield_timer = clearTimer(session.initial_yield_timer);
session.poll_yield_timer = clearTimer(session.poll_yield_timer);
return;
}
await completeSession(protocol, sessions, session, result);
return;
}
@@ -640,10 +662,11 @@ async function handleWorkerMessage(protocol, sessions, session, message) {
async function forwardToolCall(protocol, session, message) {
try {
const result = await protocol.request(session.id, 'tool_call', {
const result = await protocol.request(session.request_id, 'tool_call', {
name: String(message.name),
input: message.input,
});
session.pending_tool_call = null;
if (session.completed) {
return;
}
@@ -655,6 +678,7 @@ async function forwardToolCall(protocol, session, message) {
});
} catch {}
} catch (error) {
session.pending_tool_call = null;
if (session.completed) {
return;
}
@@ -673,14 +697,16 @@ async function sendYielded(protocol, session) {
return;
}
const contentItems = takeContentItems(session);
const requestId = session.request_id;
try {
session.worker.postMessage({ type: 'clear_content' });
} catch {}
await protocol.send({
type: 'yielded',
session_id: session.id,
request_id: requestId,
content_items: contentItems,
});
session.request_id = null;
}
function scheduleInitialYield(protocol, session, yieldTime) {
@@ -711,17 +737,26 @@ async function completeSession(protocol, sessions, session, message) {
if (session.completed) {
return;
}
if (session.request_id === null) {
session.pending_result = message;
session.initial_yield_timer = clearTimer(session.initial_yield_timer);
session.poll_yield_timer = clearTimer(session.poll_yield_timer);
return;
}
const requestId = session.request_id;
session.completed = true;
session.initial_yield_timer = clearTimer(session.initial_yield_timer);
session.poll_yield_timer = clearTimer(session.poll_yield_timer);
sessions.delete(session.id);
const contentItems = takeContentItems(session);
session.pending_result = null;
session.pending_tool_call = null;
try {
session.worker.postMessage({ type: 'clear_content' });
} catch {}
await protocol.send({
...message,
session_id: session.id,
request_id: requestId,
content_items: contentItems,
max_output_tokens_per_exec_call: session.max_output_tokens_per_exec_call,
});
@@ -741,7 +776,7 @@ async function terminateSession(protocol, sessions, session) {
} catch {}
await protocol.send({
type: 'terminated',
session_id: session.id,
request_id: session.request_id,
content_items: contentItems,
});
}