Files
codex/prs/bolinfest/PR-1387.md
2025-09-02 15:17:45 -07:00

22 KiB
Raw Blame History

PR #1387: [Rust] Allow resuming a session that was killed with ctrl + c

Description

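Previously, if you interrupted a conversation with Ctrl+C, every subsequent turn failed with HTTP 400. This happened because the Responses API never received an output for one of the call IDs it had issued. This change tracks pending call IDs. For any pending call whose output we are not sending back ourselves, it generates a synthetic "aborted" function-call output.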

Fixes #1244

https://github.com/user-attachments/assets/5126354f-b970-45f5-8c65-f811bca8294a

Full Diff

diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index 12c5b7afca..dfe06d1fec 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -425,7 +425,12 @@ where
                         response_id,
                         token_usage,
                     })));
-                } // No other `Ok` variants exist at the moment, continue polling.
+                }
+                Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
+                    // These events are exclusive to the Responses API and
+                    // will never appear in a Chat Completions stream.
+                    continue;
+                }
             }
         }
     }
@@ -439,7 +444,7 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
     ///
     /// ```ignore
     ///     OutputItemDone(<full message>)
-    ///     Completed { .. }
+    ///     Completed
     /// ```
     ///
     /// No other `OutputItemDone` events will be seen by the caller.
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 4770796dbb..6daa3a8969 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -168,7 +168,7 @@ impl ModelClient {
                     // negligible.
                     if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
                         // Surface the error body to callers. Use `unwrap_or_default` per Clippy.
-                        let body = (res.text().await).unwrap_or_default();
+                        let body = res.text().await.unwrap_or_default();
                         return Err(CodexErr::UnexpectedStatus(status, body));
                     }
 
@@ -208,6 +208,9 @@ struct SseEvent {
     item: Option<Value>,
 }
 
+#[derive(Debug, Deserialize)]
+struct ResponseCreated {}
+
 #[derive(Debug, Deserialize)]
 struct ResponseCompleted {
     id: String,
@@ -335,6 +338,11 @@ where
                     return;
                 }
             }
+            "response.created" => {
+                if event.response.is_some() {
+                    let _ = tx_event.send(Ok(ResponseEvent::Created {})).await;
+                }
+            }
             // Final response completed  includes array of output items & id
             "response.completed" => {
                 if let Some(resp_val) = event.response {
@@ -350,7 +358,6 @@ where
                 };
             }
             "response.content_part.done"
-            | "response.created"
             | "response.function_call_arguments.delta"
             | "response.in_progress"
             | "response.output_item.added"
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index e17cf22c59..b08880a0df 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -51,6 +51,7 @@ impl Prompt {
 
 #[derive(Debug)]
 pub enum ResponseEvent {
+    Created,
     OutputItemDone(ResponseItem),
     Completed {
         response_id: String,
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index a43f75a731..ec6e0bd185 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -1,6 +1,7 @@
 // Poisoned mutex should fail the program
 #![allow(clippy::unwrap_used)]
 
+use std::borrow::Cow;
 use std::collections::HashMap;
 use std::collections::HashSet;
 use std::path::Path;
@@ -188,7 +189,7 @@ pub(crate) struct Session {
 
     /// Optional rollout recorder for persisting the conversation transcript so
     /// sessions can be replayed or inspected later.
-    rollout: Mutex<Option<crate::rollout::RolloutRecorder>>,
+    rollout: Mutex<Option<RolloutRecorder>>,
     state: Mutex<State>,
     codex_linux_sandbox_exe: Option<PathBuf>,
 }
@@ -206,6 +207,9 @@ impl Session {
 struct State {
     approved_commands: HashSet<Vec<String>>,
     current_task: Option<AgentTask>,
+    /// Call IDs that have been sent from the Responses API but have not been sent back yet.
+    /// You CANNOT send a Responses API follow-up message unless you have sent back the output for all pending calls or else it will 400.
+    pending_call_ids: HashSet<String>,
     previous_response_id: Option<String>,
     pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
     pending_input: Vec<ResponseInputItem>,
@@ -312,7 +316,7 @@ impl Session {
     /// Append the given items to the session's rollout transcript (if enabled)
     /// and persist them to disk.
     async fn record_rollout_items(&self, items: &[ResponseItem]) {
-        // Clone the recorder outside of the mutex so we dont hold the lock
+        // Clone the recorder outside of the mutex so we don't hold the lock
         // across an await point (MutexGuard is not Send).
         let recorder = {
             let guard = self.rollout.lock().unwrap();
@@ -411,6 +415,8 @@ impl Session {
     pub fn abort(&self) {
         info!("Aborting existing session");
         let mut state = self.state.lock().unwrap();
+        // Don't clear pending_call_ids because we need to keep track of them to ensure we don't 400 on the next turn.
+        // We will generate a synthetic aborted response for each pending call id.
         state.pending_approvals.clear();
         state.pending_input.clear();
         if let Some(task) = state.current_task.take() {
@@ -431,7 +437,7 @@ impl Session {
         }
 
         let Ok(json) = serde_json::to_string(&notification) else {
-            tracing::error!("failed to serialise notification payload");
+            error!("failed to serialise notification payload");
             return;
         };
 
@@ -443,7 +449,7 @@ impl Session {
 
         // Fire-and-forget  we do not wait for completion.
         if let Err(e) = command.spawn() {
-            tracing::warn!("failed to spawn notifier '{}': {e}", notify_command[0]);
+            warn!("failed to spawn notifier '{}': {e}", notify_command[0]);
         }
     }
 }
@@ -647,7 +653,7 @@ async fn submission_loop(
                     match RolloutRecorder::new(&config, session_id, instructions.clone()).await {
                         Ok(r) => Some(r),
                         Err(e) => {
-                            tracing::warn!("failed to initialise rollout recorder: {e}");
+                            warn!("failed to initialise rollout recorder: {e}");
                             None
                         }
                     };
@@ -742,7 +748,7 @@ async fn submission_loop(
                 tokio::spawn(async move {
                     if let Err(e) = crate::message_history::append_entry(&text, &id, &config).await
                     {
-                        tracing::warn!("failed to append to message history: {e}");
+                        warn!("failed to append to message history: {e}");
                     }
                 });
             }
@@ -772,7 +778,7 @@ async fn submission_loop(
                     };
 
                     if let Err(e) = tx_event.send(event).await {
-                        tracing::warn!("failed to send GetHistoryEntryResponse event: {e}");
+                        warn!("failed to send GetHistoryEntryResponse event: {e}");
                     }
                 });
             }
@@ -1052,6 +1058,7 @@ async fn run_turn(
 /// events map to a `ResponseItem`. A `ResponseItem` may need to be
 /// "handled" such that it produces a `ResponseInputItem` that needs to be
 /// sent back to the model on the next turn.
+#[derive(Debug)]
 struct ProcessedResponseItem {
     item: ResponseItem,
     response: Option<ResponseInputItem>,
@@ -1062,7 +1069,57 @@ async fn try_run_turn(
     sub_id: &str,
     prompt: &Prompt,
 ) -> CodexResult<Vec<ProcessedResponseItem>> {
-    let mut stream = sess.client.clone().stream(prompt).await?;
+    // call_ids that are part of this response.
+    let completed_call_ids = prompt
+        .input
+        .iter()
+        .filter_map(|ri| match ri {
+            ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
+            ResponseItem::LocalShellCall {
+                call_id: Some(call_id),
+                ..
+            } => Some(call_id),
+            _ => None,
+        })
+        .collect::<Vec<_>>();
+
+    // call_ids that were pending but are not part of this response.
+    // This usually happens because the user interrupted the model before we responded to one of its tool calls
+    // and then the user sent a follow-up message.
+    let missing_calls = {
+        sess.state
+            .lock()
+            .unwrap()
+            .pending_call_ids
+            .iter()
+            .filter_map(|call_id| {
+                if completed_call_ids.contains(&call_id) {
+                    None
+                } else {
+                    Some(call_id.clone())
+                }
+            })
+            .map(|call_id| ResponseItem::FunctionCallOutput {
+                call_id: call_id.clone(),
+                output: FunctionCallOutputPayload {
+                    content: "aborted".to_string(),
+                    success: Some(false),
+                },
+            })
+            .collect::<Vec<_>>()
+    };
+    let prompt: Cow<Prompt> = if missing_calls.is_empty() {
+        Cow::Borrowed(prompt)
+    } else {
+        // Add the synthetic aborted missing calls to the beginning of the input to ensure all call ids have responses.
+        let input = [missing_calls, prompt.input.clone()].concat();
+        Cow::Owned(Prompt {
+            input,
+            ..prompt.clone()
+        })
+    };
+
+    let mut stream = sess.client.clone().stream(&prompt).await?;
 
     // Buffer all the incoming messages from the stream first, then execute them.
     // If we execute a function call in the middle of handling the stream, it can time out.
@@ -1074,8 +1131,27 @@ async fn try_run_turn(
     let mut output = Vec::new();
     for event in input {
         match event {
+            ResponseEvent::Created => {
+                let mut state = sess.state.lock().unwrap();
+                // We successfully created a new response and ensured that all pending calls were included so we can clear the pending call ids.
+                state.pending_call_ids.clear();
+            }
             ResponseEvent::OutputItemDone(item) => {
+                let call_id = match &item {
+                    ResponseItem::LocalShellCall {
+                        call_id: Some(call_id),
+                        ..
+                    } => Some(call_id),
+                    ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
+                    _ => None,
+                };
+                if let Some(call_id) = call_id {
+                    // We just got a new call id so we need to make sure to respond to it in the next turn.
+                    let mut state = sess.state.lock().unwrap();
+                    state.pending_call_ids.insert(call_id.clone());
+                }
                 let response = handle_response_item(sess, sub_id, item.clone()).await?;
+
                 output.push(ProcessedResponseItem { item, response });
             }
             ResponseEvent::Completed {
@@ -1138,7 +1214,7 @@ async fn handle_response_item(
             arguments,
             call_id,
         } => {
-            tracing::info!("FunctionCall: {arguments}");
+            info!("FunctionCall: {arguments}");
             Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
         }
         ResponseItem::LocalShellCall {
@@ -1220,7 +1296,7 @@ async fn handle_function_call(
                     // Unknown function: reply with structured failure so the model can adapt.
                     ResponseInputItem::FunctionCallOutput {
                         call_id,
-                        output: crate::models::FunctionCallOutputPayload {
+                        output: FunctionCallOutputPayload {
                             content: format!("unsupported call: {}", name),
                             success: None,
                         },
@@ -1252,7 +1328,7 @@ fn parse_container_exec_arguments(
             // allow model to re-sample
             let output = ResponseInputItem::FunctionCallOutput {
                 call_id: call_id.to_string(),
-                output: crate::models::FunctionCallOutputPayload {
+                output: FunctionCallOutputPayload {
                     content: format!("failed to parse function arguments: {e}"),
                     success: None,
                 },
@@ -1320,7 +1396,7 @@ async fn handle_container_exec_with_params(
                 ReviewDecision::Denied | ReviewDecision::Abort => {
                     return ResponseInputItem::FunctionCallOutput {
                         call_id,
-                        output: crate::models::FunctionCallOutputPayload {
+                        output: FunctionCallOutputPayload {
                             content: "exec command rejected by user".to_string(),
                             success: None,
                         },
@@ -1336,7 +1412,7 @@ async fn handle_container_exec_with_params(
         SafetyCheck::Reject { reason } => {
             return ResponseInputItem::FunctionCallOutput {
                 call_id,
-                output: crate::models::FunctionCallOutputPayload {
+                output: FunctionCallOutputPayload {
                     content: format!("exec command rejected: {reason}"),
                     success: None,
                 },
@@ -1870,7 +1946,7 @@ fn apply_changes_from_apply_patch(action: &ApplyPatchAction) -> anyhow::Result<A
     })
 }
 
-fn get_writable_roots(cwd: &Path) -> Vec<std::path::PathBuf> {
+fn get_writable_roots(cwd: &Path) -> Vec<PathBuf> {
     let mut writable_roots = Vec::new();
     if cfg!(target_os = "macos") {
         // On macOS, $TMPDIR is private to the user.
@@ -1898,7 +1974,7 @@ fn get_writable_roots(cwd: &Path) -> Vec<std::path::PathBuf> {
 }
 
 /// Exec output is a pre-serialized JSON payload
-fn format_exec_output(output: &str, exit_code: i32, duration: std::time::Duration) -> String {
+fn format_exec_output(output: &str, exit_code: i32, duration: Duration) -> String {
     #[derive(Serialize)]
     struct ExecMetadata {
         exit_code: i32,

Review Comments

codex-rs/core/src/client.rs

@@ -296,6 +299,11 @@ where
                     return;
                 }
             }
+            "response.created" => {
+                if let Some(_) = event.response {

I predict Clippy (the `redundant_pattern_matching` lint) wants `if event.response.is_some()` here?

I have this in my `settings.json` in VS Code:

```jsonc
"rust-analyzer.check.command": "clippy",
"rust-analyzer.check.extraArgs": ["--", "-D", "warnings"],
```

codex-rs/core/src/client_common.rs

@@ -50,6 +50,7 @@ impl Prompt {
 
 #[derive(Debug)]
 pub enum ResponseEvent {
+    Created {},
Suggested change: `Created,` (a unit variant needs no `{}`).

codex-rs/core/src/codex.rs

@@ -1062,7 +1068,59 @@ async fn try_run_turn(
     sub_id: &str,
     prompt: &Prompt,
 ) -> CodexResult<Vec<ProcessedResponseItem>> {
-    let mut stream = sess.client.clone().stream(prompt).await?;
+    // call_ids that are part of this response.
+    let completed_call_ids = prompt
+        .input
+        .iter()
+        .filter_map(|ri| match ri {
+            ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
+            ResponseItem::LocalShellCall {
+                call_id: Some(call_id),
+                ..
+            } => Some(call_id),
+            _ => None,
+        })
+        .collect::<Vec<_>>();
+
+    // call_ids that were pending but are not part of this response.
+    // This usually happens because the user interrupted the model before we responded to one of its tool calls
+    // and then the user sent a follow-up message.
+    let missing_calls = {
+        sess.state
+            .lock()
+            .unwrap()
+            .pending_call_ids
+            .iter()
+            .filter_map(|call_id| {
+                if completed_call_ids.contains(&call_id) {
+                    None
+                } else {
+                    Some(call_id.clone())
+                }
+            })
+            .map(|call_id| ResponseItem::FunctionCallOutput {
+                call_id: call_id.clone(),
+                output: FunctionCallOutputPayload {
+                    content: "aborted".to_string(),
+                    success: Some(false),
+                },
+            })
+            .collect::<Vec<_>>()
+    };
+    let prompt = if missing_calls.is_empty() {
+        prompt.clone()

I think you want `let prompt: Cow<'a, Prompt>`, if you can, to avoid the `clone()`? So in the consequent it's `Cow::Borrowed`, and in the alternative it's `Cow::Owned`?

@@ -1062,7 +1068,59 @@ async fn try_run_turn(
     sub_id: &str,
     prompt: &Prompt,
 ) -> CodexResult<Vec<ProcessedResponseItem>> {
-    let mut stream = sess.client.clone().stream(prompt).await?;
+    // call_ids that are part of this response.
+    let completed_call_ids = prompt
+        .input
+        .iter()
+        .filter_map(|ri| match ri {
+            ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
+            ResponseItem::LocalShellCall {
+                call_id: Some(call_id),
+                ..
+            } => Some(call_id),
+            _ => None,
+        })
+        .collect::<Vec<_>>();
+
+    // call_ids that were pending but are not part of this response.
+    // This usually happens because the user interrupted the model before we responded to one of its tool calls
+    // and then the user sent a follow-up message.
+    let missing_calls = {
+        sess.state
+            .lock()
+            .unwrap()
+            .pending_call_ids
+            .iter()
+            .filter_map(|call_id| {
+                if completed_call_ids.contains(&call_id) {
+                    None
+                } else {
+                    Some(call_id.clone())
+                }
+            })
+            .map(|call_id| ResponseItem::FunctionCallOutput {
+                call_id: call_id.clone(),
+                output: FunctionCallOutputPayload {
+                    content: "aborted".to_string(),
+                    success: Some(false),
+                },
+            })
+            .collect::<Vec<_>>()
+    };
+    let prompt = if missing_calls.is_empty() {

Can you add a comment explaining why we redefine `prompt` when `missing_calls` is non-empty?

@@ -1062,7 +1068,59 @@ async fn try_run_turn(
     sub_id: &str,
     prompt: &Prompt,
 ) -> CodexResult<Vec<ProcessedResponseItem>> {
-    let mut stream = sess.client.clone().stream(prompt).await?;
+    // call_ids that are part of this response.
+    let completed_call_ids = prompt
+        .input
+        .iter()
+        .filter_map(|ri| match ri {
+            ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
+            ResponseItem::LocalShellCall {
+                call_id: Some(call_id),
+                ..
+            } => Some(call_id),
+            _ => None,
+        })
+        .collect::<Vec<_>>();
+
+    // call_ids that were pending but are not part of this response.
+    // This usually happens because the user interrupted the model before we responded to one of its tool calls
+    // and then the user sent a follow-up message.
+    let missing_calls = {
+        sess.state
+            .lock()
+            .unwrap()
+            .pending_call_ids
+            .iter()
+            .filter_map(|call_id| {
+                if completed_call_ids.contains(&call_id) {
+                    None
+                } else {
+                    Some(call_id.clone())
+                }
+            })
+            .map(|call_id| ResponseItem::FunctionCallOutput {
+                call_id: call_id.clone(),
+                output: FunctionCallOutputPayload {
+                    content: "aborted".to_string(),
+                    success: Some(false),
+                },
+            })
+            .collect::<Vec<_>>()
+    };
+    let prompt = if missing_calls.is_empty() {
+        prompt.clone()
+    } else {
+        let input = [prompt.input.clone(), missing_calls].concat();
+        Prompt {
+            input,
+            prev_id: prompt.prev_id.clone(),
+            user_instructions: prompt.user_instructions.clone(),
+            store: prompt.store,
+            extra_tools: prompt.extra_tools.clone(),
+        }

Does this work?

```rust
Prompt {
    input: [prompt.input.clone(), missing_calls].concat(),
    ..prompt.clone()
}
```