From bfe33e5a7ab05154128beffa7b4465826896c274 Mon Sep 17 00:00:00 2001 From: starr-openai Date: Wed, 6 May 2026 23:56:19 -0700 Subject: [PATCH] Make agent job stop cancellation atomic A worker stop request used to record the item result and job cancellation in separate updates, so the job runner could observe the item completion first and continue spawning pending work. Commit both state updates together and prevent completion from overwriting a final cancellation. Co-authored-by: Codex --- .../agent_jobs/report_agent_job_result.rs | 32 ++--- codex-rs/state/src/runtime/agent_jobs.rs | 116 +++++++++++++++++- 2 files changed, 128 insertions(+), 20 deletions(-) diff --git a/codex-rs/core/src/tools/handlers/agent_jobs/report_agent_job_result.rs b/codex-rs/core/src/tools/handlers/agent_jobs/report_agent_job_result.rs index a7a36a49d5..6458ae4b77 100644 --- a/codex-rs/core/src/tools/handlers/agent_jobs/report_agent_job_result.rs +++ b/codex-rs/core/src/tools/handlers/agent_jobs/report_agent_job_result.rs @@ -61,27 +61,31 @@ pub async fn handle( } let db = required_state_db(&session)?; let reporting_thread_id = session.conversation_id.to_string(); - let accepted = db - .report_agent_job_item_result( + let accepted = if args.stop.unwrap_or(false) { + db.report_agent_job_item_result_and_cancel_job( + args.job_id.as_str(), + args.item_id.as_str(), + reporting_thread_id.as_str(), + &args.result, + "cancelled by worker request", + ) + .await + } else { + db.report_agent_job_item_result( args.job_id.as_str(), args.item_id.as_str(), reporting_thread_id.as_str(), &args.result, ) .await - .map_err(|err| { - let job_id = args.job_id.as_str(); - let item_id = args.item_id.as_str(); - FunctionCallError::RespondToModel(format!( - "failed to record agent job result for {job_id} / {item_id}: {err}" - )) - })?; - if accepted && args.stop.unwrap_or(false) { - let message = "cancelled by worker request"; - let _ = db - .mark_agent_job_cancelled(args.job_id.as_str(), message) - .await; } + .map_err(|err| { + let job_id = args.job_id.as_str(); + let item_id = args.item_id.as_str(); + FunctionCallError::RespondToModel(format!( + "failed to record agent job result for {job_id} / {item_id}: {err}" + )) + })?; let content = serde_json::to_string(&ReportAgentJobResultToolResult { accepted }).map_err(|err| { FunctionCallError::Fatal(format!( diff --git a/codex-rs/state/src/runtime/agent_jobs.rs b/codex-rs/state/src/runtime/agent_jobs.rs index 3f5526c58d..fc0e75640e 100644 --- a/codex-rs/state/src/runtime/agent_jobs.rs +++ b/codex-rs/state/src/runtime/agent_jobs.rs @@ -227,22 +227,23 @@ WHERE id = ? Ok(()) } - pub async fn mark_agent_job_completed(&self, job_id: &str) -> anyhow::Result<()> { + pub async fn mark_agent_job_completed(&self, job_id: &str) -> anyhow::Result { let now = Utc::now().timestamp(); - sqlx::query( + let result = sqlx::query( r#" UPDATE agent_jobs SET status = ?, updated_at = ?, completed_at = ?, last_error = NULL -WHERE id = ? +WHERE id = ? AND status = ? "#, ) .bind(AgentJobStatus::Completed.as_str()) .bind(now) .bind(now) .bind(job_id) + .bind(AgentJobStatus::Running.as_str()) .execute(self.pool.as_ref()) .await?; - Ok(()) + Ok(result.rows_affected() > 0) } pub async fn mark_agent_job_failed( @@ -428,9 +429,46 @@ WHERE job_id = ? AND item_id = ? AND status = ? item_id: &str, reporting_thread_id: &str, result_json: &Value, + ) -> anyhow::Result { + self.report_agent_job_item_result_inner( + job_id, + item_id, + reporting_thread_id, + result_json, + /*cancel_job_reason*/ None, + ) + .await + } + + pub async fn report_agent_job_item_result_and_cancel_job( + &self, + job_id: &str, + item_id: &str, + reporting_thread_id: &str, + result_json: &Value, + cancel_job_reason: &str, + ) -> anyhow::Result { + self.report_agent_job_item_result_inner( + job_id, + item_id, + reporting_thread_id, + result_json, + Some(cancel_job_reason), + ) + .await + } + + async fn report_agent_job_item_result_inner( + &self, + job_id: &str, + item_id: &str, + reporting_thread_id: &str, + result_json: &Value, + cancel_job_reason: Option<&str>, ) -> anyhow::Result { let now = Utc::now().timestamp(); let serialized = serde_json::to_string(result_json)?; + let mut tx = self.pool.begin().await?; let result = sqlx::query( r#" UPDATE agent_job_items @@ -458,9 +496,29 @@ WHERE .bind(item_id) .bind(AgentJobItemStatus::Running.as_str()) .bind(reporting_thread_id) - .execute(self.pool.as_ref()) + .execute(&mut *tx) .await?; - Ok(result.rows_affected() > 0) + let accepted = result.rows_affected() > 0; + if accepted && let Some(reason) = cancel_job_reason { + sqlx::query( + r#" +UPDATE agent_jobs +SET status = ?, updated_at = ?, completed_at = ?, last_error = ? +WHERE id = ? AND status IN (?, ?) + "#, + ) + .bind(AgentJobStatus::Cancelled.as_str()) + .bind(now) + .bind(now) + .bind(reason) + .bind(job_id) + .bind(AgentJobStatus::Pending.as_str()) + .bind(AgentJobStatus::Running.as_str()) + .execute(&mut *tx) + .await?; + } + tx.commit().await?; + Ok(accepted) } pub async fn mark_agent_job_item_completed( @@ -652,6 +710,52 @@ mod tests { Ok(()) } + #[tokio::test] + async fn report_agent_job_item_result_can_cancel_job_atomically() -> anyhow::Result<()> { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home, "test-provider".to_string()).await?; + let (job_id, item_id, thread_id) = create_running_single_item_job(runtime.as_ref()).await?; + + let accepted = runtime + .report_agent_job_item_result_and_cancel_job( + job_id.as_str(), + item_id.as_str(), + thread_id.as_str(), + &json!({"ok": true}), + "cancelled by worker request", + ) + .await?; + assert!(accepted); + + let job = runtime + .get_agent_job(job_id.as_str()) + .await? + .expect("job should exist"); + assert_eq!(job.status, AgentJobStatus::Cancelled); + assert_eq!( + job.last_error, + Some("cancelled by worker request".to_string()) + ); + + let item = runtime + .get_agent_job_item(job_id.as_str(), item_id.as_str()) + .await? + .expect("job item should exist"); + assert_eq!(item.status, AgentJobItemStatus::Completed); + assert_eq!(item.result_json, Some(json!({"ok": true}))); + assert_eq!(item.assigned_thread_id, None); + + let completed = runtime.mark_agent_job_completed(job_id.as_str()).await?; + assert!(!completed); + let job = runtime + .get_agent_job(job_id.as_str()) + .await? + .expect("job should exist"); + assert_eq!(job.status, AgentJobStatus::Cancelled); + + Ok(()) + } + #[tokio::test] async fn report_agent_job_item_result_rejects_late_reports() -> anyhow::Result<()> { let codex_home = unique_temp_dir();