This commit is contained in:
jimmyfraiture
2025-09-24 15:55:47 +01:00
parent e05540ea7f
commit fc179ecdf0
2 changed files with 28 additions and 52 deletions

View File

@@ -1004,9 +1004,7 @@ impl Session {
guard.clone()
};
if let Some(turn_state) = current_turn {
turn_state
.enqueue_user_input(input.clone(), readiness.clone())
.await;
turn_state.enqueue_user_input(input, readiness).await;
Ok(())
} else {
Err((input, readiness))
@@ -1026,20 +1024,17 @@ impl Session {
pub async fn interrupt_task(&self) {
info!("interrupt received: abort current task, if any");
{
let mut state = self.state.lock().await;
state.pending_approvals.clear();
}
let mut state = self.state.lock().await;
state.pending_approvals.clear();
let task = {
let mut current_task = self.current_task.lock().await;
current_task.take()
};
{
let mut current_turn = self.current_turn.lock().await;
current_turn.take();
}
let mut current_turn = self.current_turn.lock().await;
current_turn.take();
if let Some(task) = task {
task.abort(TurnAbortReason::Interrupted);
@@ -1710,7 +1705,7 @@ async fn run_task(sess: Arc<Session>, turn_state: Arc<TurnState>) {
}
loop {
let (pending_input, turn_readiness) = turn_state.drain_mailbox().await.into_parts();
let (pending_input, _turn_readiness) = turn_state.drain_mailbox().await.into_parts();
let turn_input: Vec<ResponseItem> = if is_review_mode {
if !pending_input.is_empty() {
@@ -1736,15 +1731,7 @@ async fn run_task(sess: Arc<Session>, turn_state: Arc<TurnState>) {
})
.collect();
match run_turn(
&sess,
turn_state.as_ref(),
sub_id.clone(),
turn_input,
turn_readiness.clone(),
)
.await
{
match run_turn(&sess, turn_state.as_ref(), sub_id.clone(), turn_input).await {
Ok(turn_output) => {
let TurnRunResult {
processed_items,
@@ -1976,7 +1963,6 @@ async fn run_turn(
turn_state: &TurnState,
sub_id: String,
input: Vec<ResponseItem>,
turn_readiness: Option<Arc<ReadinessFlag>>,
) -> CodexResult<TurnRunResult> {
let turn_context = turn_state.turn_context();
let tools = get_openai_tools(
@@ -1993,16 +1979,7 @@ async fn run_turn(
let mut retries = 0;
loop {
match try_run_turn(
sess,
turn_state,
turn_context.as_ref(),
&sub_id,
&prompt,
turn_readiness.clone(),
)
.await
{
match try_run_turn(sess, turn_state, turn_context.as_ref(), &sub_id, &prompt).await {
Ok(output) => return Ok(output),
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
@@ -2069,7 +2046,6 @@ async fn try_run_turn(
turn_context: &TurnContext,
sub_id: &str,
prompt: &Prompt,
turn_readiness: Option<Arc<ReadinessFlag>>,
) -> CodexResult<TurnRunResult> {
// call_ids that are part of this response.
let completed_call_ids = prompt
@@ -2165,15 +2141,9 @@ async fn try_run_turn(
match event {
ResponseEvent::Created => {}
ResponseEvent::OutputItemDone(item) => {
let response = handle_response_item(
sess,
turn_state,
turn_context,
sub_id,
item.clone(),
turn_readiness.clone(),
)
.await?;
let response =
handle_response_item(sess, turn_state, turn_context, sub_id, item.clone())
.await?;
output.push(ProcessedResponseItem { item, response });
}
ResponseEvent::WebSearchCallBegin { call_id } => {
@@ -2262,7 +2232,6 @@ async fn handle_response_item(
turn_context: &TurnContext,
sub_id: &str,
item: ResponseItem,
turn_readiness: Option<Arc<ReadinessFlag>>,
) -> CodexResult<Option<ResponseInputItem>> {
debug!(?item, "Output item");
let output = match item {
@@ -2272,9 +2241,8 @@ async fn handle_response_item(
call_id,
..
} => {
if let Some(flag) = turn_readiness.as_ref() {
flag.wait_ready().await;
}
// Gate tool invocation on readiness signal for this turn, if set.
turn_state.wait_on_readiness().await;
info!("FunctionCall: {name}({arguments})");
Some(
handle_function_call(
@@ -2296,9 +2264,7 @@ async fn handle_response_item(
action,
} => {
let LocalShellAction::Exec(action) = action;
if let Some(flag) = turn_readiness.as_ref() {
flag.wait_ready().await;
}
turn_state.wait_on_readiness().await;
tracing::info!("LocalShellCall: {action:?}");
let params = ShellToolCallParams {
command: action.command,
@@ -2342,9 +2308,7 @@ async fn handle_response_item(
input,
status: _,
} => {
if let Some(flag) = turn_readiness.as_ref() {
flag.wait_ready().await;
}
turn_state.wait_on_readiness().await;
Some(
handle_custom_tool_call(
sess,

View File

@@ -4,6 +4,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use anyhow::Result;
use codex_utils_readiness::Readiness;
use codex_utils_readiness::ReadinessFlag;
use serde_json::Value;
use tokio::sync::RwLock;
@@ -228,4 +229,15 @@ impl TurnState {
let mut runtime = self.runtime.write().await;
runtime.diff_tracker.get_unified_diff()
}
pub(crate) async fn latest_readiness(&self) -> Option<Arc<ReadinessFlag>> {
let runtime = self.runtime.read().await;
runtime.mailbox.latest_readiness.clone()
}
pub(crate) async fn wait_on_readiness(&self) {
if let Some(flag) = self.latest_readiness().await {
flag.wait_ready().await;
}
}
}