diff --git a/codex-rs/core/src/tools/handlers/agent_jobs/report_agent_job_result.rs b/codex-rs/core/src/tools/handlers/agent_jobs/report_agent_job_result.rs
index a7a36a49d5..6458ae4b77 100644
--- a/codex-rs/core/src/tools/handlers/agent_jobs/report_agent_job_result.rs
+++ b/codex-rs/core/src/tools/handlers/agent_jobs/report_agent_job_result.rs
@@ -61,27 +61,31 @@ pub async fn handle(
     }
     let db = required_state_db(&session)?;
     let reporting_thread_id = session.conversation_id.to_string();
-    let accepted = db
-        .report_agent_job_item_result(
+    let accepted = if args.stop.unwrap_or(false) {
+        db.report_agent_job_item_result_and_cancel_job(
+            args.job_id.as_str(),
+            args.item_id.as_str(),
+            reporting_thread_id.as_str(),
+            &args.result,
+            "cancelled by worker request",
+        )
+        .await
+    } else {
+        db.report_agent_job_item_result(
             args.job_id.as_str(),
             args.item_id.as_str(),
             reporting_thread_id.as_str(),
             &args.result,
         )
         .await
-        .map_err(|err| {
-            let job_id = args.job_id.as_str();
-            let item_id = args.item_id.as_str();
-            FunctionCallError::RespondToModel(format!(
-                "failed to record agent job result for {job_id} / {item_id}: {err}"
-            ))
-        })?;
-    if accepted && args.stop.unwrap_or(false) {
-        let message = "cancelled by worker request";
-        let _ = db
-            .mark_agent_job_cancelled(args.job_id.as_str(), message)
-            .await;
     }
+    .map_err(|err| {
+        let job_id = args.job_id.as_str();
+        let item_id = args.item_id.as_str();
+        FunctionCallError::RespondToModel(format!(
+            "failed to record agent job result for {job_id} / {item_id}: {err}"
+        ))
+    })?;
     let content =
         serde_json::to_string(&ReportAgentJobResultToolResult { accepted }).map_err(|err| {
             FunctionCallError::Fatal(format!(
diff --git a/codex-rs/state/src/runtime/agent_jobs.rs b/codex-rs/state/src/runtime/agent_jobs.rs
index 3f5526c58d..fc0e75640e 100644
--- a/codex-rs/state/src/runtime/agent_jobs.rs
+++ b/codex-rs/state/src/runtime/agent_jobs.rs
@@ -227,22 +227,23 @@ WHERE id = ?
         Ok(())
     }
 
-    pub async fn mark_agent_job_completed(&self, job_id: &str) -> anyhow::Result<()> {
+    pub async fn mark_agent_job_completed(&self, job_id: &str) -> anyhow::Result<bool> {
         let now = Utc::now().timestamp();
-        sqlx::query(
+        let result = sqlx::query(
             r#"
 UPDATE agent_jobs
 SET status = ?, updated_at = ?, completed_at = ?, last_error = NULL
-WHERE id = ?
+WHERE id = ? AND status = ?
             "#,
         )
         .bind(AgentJobStatus::Completed.as_str())
         .bind(now)
         .bind(now)
         .bind(job_id)
+        .bind(AgentJobStatus::Running.as_str())
         .execute(self.pool.as_ref())
         .await?;
-        Ok(())
+        Ok(result.rows_affected() > 0)
     }
 
     pub async fn mark_agent_job_failed(
@@ -428,9 +429,46 @@ WHERE job_id = ? AND item_id = ? AND status = ?
         item_id: &str,
         reporting_thread_id: &str,
         result_json: &Value,
+    ) -> anyhow::Result<bool> {
+        self.report_agent_job_item_result_inner(
+            job_id,
+            item_id,
+            reporting_thread_id,
+            result_json,
+            /*cancel_job_reason*/ None,
+        )
+        .await
+    }
+
+    pub async fn report_agent_job_item_result_and_cancel_job(
+        &self,
+        job_id: &str,
+        item_id: &str,
+        reporting_thread_id: &str,
+        result_json: &Value,
+        cancel_job_reason: &str,
+    ) -> anyhow::Result<bool> {
+        self.report_agent_job_item_result_inner(
+            job_id,
+            item_id,
+            reporting_thread_id,
+            result_json,
+            Some(cancel_job_reason),
+        )
+        .await
+    }
+
+    async fn report_agent_job_item_result_inner(
+        &self,
+        job_id: &str,
+        item_id: &str,
+        reporting_thread_id: &str,
+        result_json: &Value,
+        cancel_job_reason: Option<&str>,
     ) -> anyhow::Result<bool> {
         let now = Utc::now().timestamp();
         let serialized = serde_json::to_string(result_json)?;
+        let mut tx = self.pool.begin().await?;
         let result = sqlx::query(
             r#"
 UPDATE agent_job_items
@@ -458,9 +496,29 @@ WHERE
         .bind(item_id)
         .bind(AgentJobItemStatus::Running.as_str())
         .bind(reporting_thread_id)
-        .execute(self.pool.as_ref())
+        .execute(&mut *tx)
         .await?;
-        Ok(result.rows_affected() > 0)
+        let accepted = result.rows_affected() > 0;
+        if accepted && let Some(reason) = cancel_job_reason {
+            sqlx::query(
+                r#"
+UPDATE agent_jobs
+SET status = ?, updated_at = ?, completed_at = ?, last_error = ?
+WHERE id = ? AND status IN (?, ?)
+                "#,
+            )
+            .bind(AgentJobStatus::Cancelled.as_str())
+            .bind(now)
+            .bind(now)
+            .bind(reason)
+            .bind(job_id)
+            .bind(AgentJobStatus::Pending.as_str())
+            .bind(AgentJobStatus::Running.as_str())
+            .execute(&mut *tx)
+            .await?;
+        }
+        tx.commit().await?;
+        Ok(accepted)
     }
 
     pub async fn mark_agent_job_item_completed(
@@ -652,6 +710,52 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn report_agent_job_item_result_can_cancel_job_atomically() -> anyhow::Result<()> {
+        let codex_home = unique_temp_dir();
+        let runtime = StateRuntime::init(codex_home, "test-provider".to_string()).await?;
+        let (job_id, item_id, thread_id) = create_running_single_item_job(runtime.as_ref()).await?;
+
+        let accepted = runtime
+            .report_agent_job_item_result_and_cancel_job(
+                job_id.as_str(),
+                item_id.as_str(),
+                thread_id.as_str(),
+                &json!({"ok": true}),
+                "cancelled by worker request",
+            )
+            .await?;
+        assert!(accepted);
+
+        let job = runtime
+            .get_agent_job(job_id.as_str())
+            .await?
+            .expect("job should exist");
+        assert_eq!(job.status, AgentJobStatus::Cancelled);
+        assert_eq!(
+            job.last_error,
+            Some("cancelled by worker request".to_string())
+        );
+
+        let item = runtime
+            .get_agent_job_item(job_id.as_str(), item_id.as_str())
+            .await?
+            .expect("job item should exist");
+        assert_eq!(item.status, AgentJobItemStatus::Completed);
+        assert_eq!(item.result_json, Some(json!({"ok": true})));
+        assert_eq!(item.assigned_thread_id, None);
+
+        let completed = runtime.mark_agent_job_completed(job_id.as_str()).await?;
+        assert!(!completed);
+        let job = runtime
+            .get_agent_job(job_id.as_str())
+            .await?
+            .expect("job should exist");
+        assert_eq!(job.status, AgentJobStatus::Cancelled);
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn report_agent_job_item_result_rejects_late_reports() -> anyhow::Result<()> {
         let codex_home = unique_temp_dir();