mirror of
https://github.com/openai/codex.git
synced 2026-06-02 19:31:59 +00:00
Compare commits
14 Commits
fcoury/hid
...
codex/fix-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f576ea4dfe | ||
|
|
1197683973 | ||
|
|
d09040cd31 | ||
|
|
5556d563bb | ||
|
|
3acc95b171 | ||
|
|
b9564552d3 | ||
|
|
937fd9c2f6 | ||
|
|
b4e16db47b | ||
|
|
47f9e3ed87 | ||
|
|
d5676aa35c | ||
|
|
8580e0e3e3 | ||
|
|
7918592dfa | ||
|
|
87d6aaeb6b | ||
|
|
da391e4bd2 |
@@ -4060,7 +4060,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -16345,7 +16345,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/v2/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -14239,7 +14239,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -1323,7 +1323,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -2181,7 +2181,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -1674,7 +1674,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -1674,7 +1674,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -1674,7 +1674,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -2181,7 +2181,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -1674,7 +1674,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -2181,7 +2181,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -1674,7 +1674,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -1323,7 +1323,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -1674,7 +1674,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -1323,7 +1323,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -1323,7 +1323,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -1323,7 +1323,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"items": {
|
||||
"description": "Only populated on a `thread/resume` or `thread/fork` response. For all other responses and notifications returning a Turn, the items field will be an empty list.",
|
||||
"description": "Populated on history-bearing responses such as `thread/resume` and `thread/fork`. The `turn/completed` notification may also include a compact terminal turn snapshot when it is still available from the live thread listener. Bulky command/tool outputs may be elided there to avoid replaying large payloads on the terminal notification.",
|
||||
"items": {
|
||||
"$ref": "#/definitions/ThreadItem"
|
||||
},
|
||||
|
||||
@@ -7,9 +7,10 @@ import type { TurnStatus } from "./TurnStatus";
|
||||
|
||||
export type Turn = { id: string,
|
||||
/**
|
||||
* Only populated on a `thread/resume` or `thread/fork` response.
|
||||
* For all other responses and notifications returning a Turn,
|
||||
* the items field will be an empty list.
|
||||
* Populated on history-bearing responses such as `thread/resume` and `thread/fork`.
|
||||
* The `turn/completed` notification may also include a compact terminal turn snapshot when
|
||||
* it is still available from the live thread listener. Bulky command/tool outputs may be
|
||||
* elided there to avoid replaying large payloads on the terminal notification.
|
||||
*/
|
||||
items: Array<ThreadItem>, status: TurnStatus,
|
||||
/**
|
||||
|
||||
@@ -4443,9 +4443,10 @@ impl From<CoreTokenUsage> for TokenUsageBreakdown {
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct Turn {
|
||||
pub id: String,
|
||||
/// Only populated on a `thread/resume` or `thread/fork` response.
|
||||
/// For all other responses and notifications returning a Turn,
|
||||
/// the items field will be an empty list.
|
||||
/// Populated on history-bearing responses such as `thread/resume` and `thread/fork`.
|
||||
/// The `turn/completed` notification may also include a compact terminal turn snapshot when
|
||||
/// it is still available from the live thread listener. Bulky command/tool outputs may be
|
||||
/// elided there to avoid replaying large payloads on the terminal notification.
|
||||
pub items: Vec<ThreadItem>,
|
||||
pub status: TurnStatus,
|
||||
/// Only populated when the Turn's status is failed.
|
||||
|
||||
@@ -1020,12 +1020,12 @@ Because audio is intentionally separate from `ThreadItem`, clients can opt out o
|
||||
The app-server streams JSON-RPC notifications while a turn is running. Each turn emits `turn/started` when it begins running and ends with `turn/completed` (final `turn` status). Token usage events stream separately via `thread/tokenUsage/updated`. Clients subscribe to the events they care about, rendering each item incrementally as updates arrive. The per-item lifecycle is always: `item/started` → zero or more item-specific deltas → `item/completed`.
|
||||
|
||||
- `turn/started` — `{ turn }` with the turn id, empty `items`, and `status: "inProgress"`.
|
||||
- `turn/completed` — `{ turn }` where `turn.status` is `completed`, `interrupted`, or `failed`; failures carry `{ error: { message, codexErrorInfo?, additionalDetails? } }`.
|
||||
- `turn/completed` — `{ turn }` where `turn.status` is `completed`, `interrupted`, or `failed`; failures carry `{ error: { message, codexErrorInfo?, additionalDetails? } }`. When the live listener still has the terminal turn snapshot, `turn.items` includes the completed turn items.
|
||||
- `turn/diff/updated` — `{ threadId, turnId, diff }` represents the up-to-date snapshot of the turn-level unified diff, emitted after every FileChange item. `diff` is the latest aggregated unified diff across every file change in the turn. UIs can render this to show the full "what changed" view without stitching individual `fileChange` items.
|
||||
- `turn/plan/updated` — `{ turnId, explanation?, plan }` whenever the agent shares or changes its plan; each `plan` entry is `{ step, status }` with `status` in `pending`, `inProgress`, or `completed`.
|
||||
- `model/rerouted` — `{ threadId, turnId, fromModel, toModel, reason }` when the backend reroutes a request to a different model (for example, due to high-risk cyber safety checks).
|
||||
|
||||
Today both notifications carry an empty `items` array even when item events were streamed; rely on `item/*` notifications for the canonical item list until this is fixed.
|
||||
`turn/started` still carries an empty `items` array. `turn/completed` includes the terminal turn items when app-server can recover them from the live thread listener, but clients should still treat `item/*` notifications as the canonical incremental stream.
|
||||
|
||||
#### Items
|
||||
|
||||
|
||||
@@ -2021,20 +2021,31 @@ struct TurnCompletionMetadata {
|
||||
async fn emit_turn_completed_with_status(
|
||||
conversation_id: ThreadId,
|
||||
event_turn_id: String,
|
||||
turn_snapshot: Option<Turn>,
|
||||
turn_completion_metadata: TurnCompletionMetadata,
|
||||
analytics_events_client: Option<&AnalyticsEventsClient>,
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
) {
|
||||
let turn = turn_snapshot.unwrap_or_else(|| Turn {
|
||||
id: event_turn_id.clone(),
|
||||
items: vec![],
|
||||
error: None,
|
||||
status: turn_completion_metadata.status.clone(),
|
||||
started_at: None,
|
||||
completed_at: None,
|
||||
duration_ms: None,
|
||||
});
|
||||
let turn = compact_turn_completed_turn(turn);
|
||||
let notification = TurnCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn: Turn {
|
||||
id: event_turn_id,
|
||||
items: vec![],
|
||||
error: turn_completion_metadata.error,
|
||||
items: turn.items,
|
||||
error: turn_completion_metadata.error.or(turn.error),
|
||||
status: turn_completion_metadata.status,
|
||||
started_at: turn_completion_metadata.started_at,
|
||||
completed_at: turn_completion_metadata.completed_at,
|
||||
duration_ms: turn_completion_metadata.duration_ms,
|
||||
started_at: turn_completion_metadata.started_at.or(turn.started_at),
|
||||
completed_at: turn_completion_metadata.completed_at.or(turn.completed_at),
|
||||
duration_ms: turn_completion_metadata.duration_ms.or(turn.duration_ms),
|
||||
},
|
||||
};
|
||||
if let Some(analytics_events_client) = analytics_events_client {
|
||||
@@ -2046,6 +2057,96 @@ async fn emit_turn_completed_with_status(
|
||||
.await;
|
||||
}
|
||||
|
||||
fn compact_turn_completed_turn(turn: Turn) -> Turn {
|
||||
let Turn {
|
||||
id,
|
||||
items,
|
||||
status,
|
||||
error,
|
||||
started_at,
|
||||
completed_at,
|
||||
duration_ms,
|
||||
} = turn;
|
||||
Turn {
|
||||
id,
|
||||
items: items.into_iter().map(compact_turn_completed_item).collect(),
|
||||
status,
|
||||
error,
|
||||
started_at,
|
||||
completed_at,
|
||||
duration_ms,
|
||||
}
|
||||
}
|
||||
|
||||
fn compact_turn_completed_item(item: ThreadItem) -> ThreadItem {
|
||||
match item {
|
||||
ThreadItem::CommandExecution {
|
||||
id,
|
||||
command,
|
||||
cwd,
|
||||
process_id,
|
||||
source,
|
||||
status,
|
||||
command_actions,
|
||||
aggregated_output: _,
|
||||
exit_code,
|
||||
duration_ms,
|
||||
} => ThreadItem::CommandExecution {
|
||||
id,
|
||||
command,
|
||||
cwd,
|
||||
process_id,
|
||||
source,
|
||||
status,
|
||||
command_actions,
|
||||
aggregated_output: None,
|
||||
exit_code,
|
||||
duration_ms,
|
||||
},
|
||||
ThreadItem::McpToolCall {
|
||||
id,
|
||||
server,
|
||||
tool,
|
||||
status,
|
||||
arguments,
|
||||
mcp_app_resource_uri,
|
||||
result: _,
|
||||
error,
|
||||
duration_ms,
|
||||
} => ThreadItem::McpToolCall {
|
||||
id,
|
||||
server,
|
||||
tool,
|
||||
status,
|
||||
arguments,
|
||||
mcp_app_resource_uri,
|
||||
result: None,
|
||||
error,
|
||||
duration_ms,
|
||||
},
|
||||
ThreadItem::DynamicToolCall {
|
||||
id,
|
||||
namespace,
|
||||
tool,
|
||||
arguments,
|
||||
status,
|
||||
content_items: _,
|
||||
success,
|
||||
duration_ms,
|
||||
} => ThreadItem::DynamicToolCall {
|
||||
id,
|
||||
namespace,
|
||||
tool,
|
||||
arguments,
|
||||
status,
|
||||
content_items: None,
|
||||
success,
|
||||
duration_ms,
|
||||
},
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
async fn complete_file_change_item(
|
||||
conversation_id: ThreadId,
|
||||
item_id: String,
|
||||
@@ -2224,6 +2325,7 @@ pub(crate) async fn maybe_emit_hook_prompt_item_completed(
|
||||
.await;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn find_and_remove_turn_summary(
|
||||
_conversation_id: ThreadId,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
@@ -2232,6 +2334,16 @@ async fn find_and_remove_turn_summary(
|
||||
std::mem::take(&mut state.turn_summary)
|
||||
}
|
||||
|
||||
async fn find_turn_completion_state(
|
||||
event_turn_id: &str,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) -> (TurnSummary, Option<Turn>) {
|
||||
let mut state = thread_state.lock().await;
|
||||
let turn_summary = std::mem::take(&mut state.turn_summary);
|
||||
let turn_snapshot = state.completion_turn_snapshot(event_turn_id);
|
||||
(turn_summary, turn_snapshot)
|
||||
}
|
||||
|
||||
async fn handle_turn_complete(
|
||||
conversation_id: ThreadId,
|
||||
event_turn_id: String,
|
||||
@@ -2240,7 +2352,8 @@ async fn handle_turn_complete(
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await;
|
||||
let (turn_summary, turn_snapshot) =
|
||||
find_turn_completion_state(&event_turn_id, thread_state).await;
|
||||
|
||||
let (status, error) = match turn_summary.last_error {
|
||||
Some(error) => (TurnStatus::Failed, Some(error)),
|
||||
@@ -2250,6 +2363,7 @@ async fn handle_turn_complete(
|
||||
emit_turn_completed_with_status(
|
||||
conversation_id,
|
||||
event_turn_id,
|
||||
turn_snapshot,
|
||||
TurnCompletionMetadata {
|
||||
status,
|
||||
error,
|
||||
@@ -2271,11 +2385,13 @@ async fn handle_turn_interrupted(
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await;
|
||||
let (turn_summary, turn_snapshot) =
|
||||
find_turn_completion_state(&event_turn_id, thread_state).await;
|
||||
|
||||
emit_turn_completed_with_status(
|
||||
conversation_id,
|
||||
event_turn_id,
|
||||
turn_snapshot,
|
||||
TurnCompletionMetadata {
|
||||
status: TurnStatus::Interrupted,
|
||||
error: None,
|
||||
@@ -3050,6 +3166,7 @@ mod tests {
|
||||
use codex_protocol::items::build_hook_prompt_message;
|
||||
use codex_protocol::mcp::CallToolResult;
|
||||
use codex_protocol::models::FileSystemPermissions as CoreFileSystemPermissions;
|
||||
use codex_protocol::models::MessagePhase;
|
||||
use codex_protocol::models::NetworkPermissions as CoreNetworkPermissions;
|
||||
use codex_protocol::permissions::FileSystemAccessMode;
|
||||
use codex_protocol::permissions::FileSystemPath;
|
||||
@@ -4097,6 +4214,13 @@ mod tests {
|
||||
collaboration_mode_kind: Default::default(),
|
||||
},
|
||||
));
|
||||
state.track_current_turn_event(&EventMsg::AgentMessage(
|
||||
codex_protocol::protocol::AgentMessageEvent {
|
||||
message: "done".to_string(),
|
||||
phase: Some(MessagePhase::FinalAnswer),
|
||||
memory_citation: None,
|
||||
},
|
||||
));
|
||||
state.track_current_turn_event(&EventMsg::TurnComplete(turn_complete_event(
|
||||
&event_turn_id,
|
||||
)));
|
||||
@@ -4115,12 +4239,23 @@ mod tests {
|
||||
let msg = recv_broadcast_message(&mut rx).await?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => {
|
||||
assert_eq!(n.turn.id, event_turn_id);
|
||||
assert_eq!(n.turn.status, TurnStatus::Completed);
|
||||
assert_eq!(n.turn.error, None);
|
||||
assert_eq!(n.turn.started_at, Some(42));
|
||||
assert_eq!(n.turn.completed_at, Some(TEST_TURN_COMPLETED_AT));
|
||||
assert_eq!(n.turn.duration_ms, Some(TEST_TURN_DURATION_MS));
|
||||
assert_eq!(
|
||||
n.turn,
|
||||
Turn {
|
||||
id: event_turn_id,
|
||||
items: vec![ThreadItem::AgentMessage {
|
||||
id: "item-1".to_string(),
|
||||
text: "done".to_string(),
|
||||
phase: Some(MessagePhase::FinalAnswer),
|
||||
memory_citation: None,
|
||||
}],
|
||||
status: TurnStatus::Completed,
|
||||
error: None,
|
||||
started_at: Some(42),
|
||||
completed_at: Some(TEST_TURN_COMPLETED_AT),
|
||||
duration_ms: Some(TEST_TURN_DURATION_MS),
|
||||
}
|
||||
);
|
||||
}
|
||||
other => bail!("unexpected message: {other:?}"),
|
||||
}
|
||||
@@ -4128,6 +4263,116 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compact_turn_completed_turn_elides_bulky_command_and_tool_outputs() {
|
||||
let turn = Turn {
|
||||
id: "turn-1".to_string(),
|
||||
items: vec![
|
||||
ThreadItem::CommandExecution {
|
||||
id: "cmd-1".to_string(),
|
||||
command: "printf hi".to_string(),
|
||||
cwd: test_path_buf("/tmp").abs(),
|
||||
process_id: Some("1".to_string()),
|
||||
source: CommandExecutionSource::Agent,
|
||||
status: CommandExecutionStatus::Completed,
|
||||
command_actions: vec![V2ParsedCommand::Unknown {
|
||||
command: "printf hi".to_string(),
|
||||
}],
|
||||
aggregated_output: Some("large output".to_string()),
|
||||
exit_code: Some(0),
|
||||
duration_ms: Some(1),
|
||||
},
|
||||
ThreadItem::McpToolCall {
|
||||
id: "mcp-1".to_string(),
|
||||
server: "example".to_string(),
|
||||
tool: "search".to_string(),
|
||||
status: McpToolCallStatus::Completed,
|
||||
arguments: json!({"q":"hi"}),
|
||||
mcp_app_resource_uri: None,
|
||||
result: Some(Box::new(McpToolCallResult {
|
||||
content: vec![json!({"text":"large tool output"})],
|
||||
structured_content: Some(json!({"rows":[{"value":"large"}]})),
|
||||
meta: None,
|
||||
})),
|
||||
error: None,
|
||||
duration_ms: Some(2),
|
||||
},
|
||||
ThreadItem::DynamicToolCall {
|
||||
id: "tool-1".to_string(),
|
||||
namespace: Some("custom-tools".to_string()),
|
||||
tool: "custom".to_string(),
|
||||
arguments: json!({"q":"hi"}),
|
||||
status: DynamicToolCallStatus::Completed,
|
||||
content_items: Some(vec![DynamicToolCallOutputContentItem::InputText {
|
||||
text: "large tool output".to_string(),
|
||||
}]),
|
||||
success: Some(true),
|
||||
duration_ms: Some(3),
|
||||
},
|
||||
ThreadItem::AgentMessage {
|
||||
id: "msg-1".to_string(),
|
||||
text: "final answer".to_string(),
|
||||
phase: Some(MessagePhase::FinalAnswer),
|
||||
memory_citation: None,
|
||||
},
|
||||
],
|
||||
status: TurnStatus::Completed,
|
||||
error: None,
|
||||
started_at: Some(1),
|
||||
completed_at: Some(2),
|
||||
duration_ms: Some(3),
|
||||
};
|
||||
|
||||
let compact = compact_turn_completed_turn(turn);
|
||||
|
||||
assert_eq!(
|
||||
compact.items,
|
||||
vec![
|
||||
ThreadItem::CommandExecution {
|
||||
id: "cmd-1".to_string(),
|
||||
command: "printf hi".to_string(),
|
||||
cwd: test_path_buf("/tmp").abs(),
|
||||
process_id: Some("1".to_string()),
|
||||
source: CommandExecutionSource::Agent,
|
||||
status: CommandExecutionStatus::Completed,
|
||||
command_actions: vec![V2ParsedCommand::Unknown {
|
||||
command: "printf hi".to_string(),
|
||||
}],
|
||||
aggregated_output: None,
|
||||
exit_code: Some(0),
|
||||
duration_ms: Some(1),
|
||||
},
|
||||
ThreadItem::McpToolCall {
|
||||
id: "mcp-1".to_string(),
|
||||
server: "example".to_string(),
|
||||
tool: "search".to_string(),
|
||||
status: McpToolCallStatus::Completed,
|
||||
arguments: json!({"q":"hi"}),
|
||||
mcp_app_resource_uri: None,
|
||||
result: None,
|
||||
error: None,
|
||||
duration_ms: Some(2),
|
||||
},
|
||||
ThreadItem::DynamicToolCall {
|
||||
id: "tool-1".to_string(),
|
||||
namespace: Some("custom-tools".to_string()),
|
||||
tool: "custom".to_string(),
|
||||
arguments: json!({"q":"hi"}),
|
||||
status: DynamicToolCallStatus::Completed,
|
||||
content_items: None,
|
||||
success: Some(true),
|
||||
duration_ms: Some(3),
|
||||
},
|
||||
ThreadItem::AgentMessage {
|
||||
id: "msg-1".to_string(),
|
||||
text: "final answer".to_string(),
|
||||
phase: Some(MessagePhase::FinalAnswer),
|
||||
memory_citation: None,
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handle_turn_interrupted_emits_interrupted_with_error() -> Result<()> {
|
||||
let conversation_id = ThreadId::new();
|
||||
|
||||
@@ -4,6 +4,7 @@ use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ThreadHistoryBuilder;
|
||||
use codex_app_server_protocol::Turn;
|
||||
use codex_app_server_protocol::TurnError;
|
||||
use codex_app_server_protocol::TurnStatus;
|
||||
use codex_core::CodexThread;
|
||||
use codex_core::ThreadConfigSnapshot;
|
||||
use codex_protocol::ThreadId;
|
||||
@@ -64,6 +65,7 @@ pub(crate) struct ThreadState {
|
||||
pub(crate) listener_generation: u64,
|
||||
listener_command_tx: Option<mpsc::UnboundedSender<ThreadListenerCommand>>,
|
||||
current_turn_history: ThreadHistoryBuilder,
|
||||
last_terminal_turn: Option<Turn>,
|
||||
listener_thread: Option<Weak<CodexThread>>,
|
||||
}
|
||||
|
||||
@@ -96,6 +98,7 @@ impl ThreadState {
|
||||
}
|
||||
self.listener_command_tx = None;
|
||||
self.current_turn_history.reset();
|
||||
self.last_terminal_turn = None;
|
||||
self.listener_thread = None;
|
||||
}
|
||||
|
||||
@@ -113,14 +116,26 @@ impl ThreadState {
|
||||
self.current_turn_history.active_turn_snapshot()
|
||||
}
|
||||
|
||||
pub(crate) fn completion_turn_snapshot(&self, turn_id: &str) -> Option<Turn> {
|
||||
self.active_turn_snapshot()
|
||||
.filter(|turn| turn.id == turn_id && !matches!(turn.status, TurnStatus::InProgress))
|
||||
.or_else(|| {
|
||||
self.last_terminal_turn
|
||||
.clone()
|
||||
.filter(|turn| turn.id == turn_id)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn track_current_turn_event(&mut self, event: &EventMsg) {
|
||||
if let EventMsg::TurnStarted(payload) = event {
|
||||
self.turn_summary.started_at = payload.started_at;
|
||||
self.last_terminal_turn = None;
|
||||
}
|
||||
self.current_turn_history.handle_event(event);
|
||||
if matches!(event, EventMsg::TurnAborted(_) | EventMsg::TurnComplete(_))
|
||||
&& !self.current_turn_history.has_active_turn()
|
||||
{
|
||||
self.last_terminal_turn = self.current_turn_history.active_turn_snapshot();
|
||||
self.current_turn_history.reset();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -253,6 +253,7 @@ impl AgentControl {
|
||||
(None, _) => state.spawn_new_thread(config, self.clone()).await?,
|
||||
};
|
||||
agent_metadata.agent_id = Some(new_thread.thread_id);
|
||||
let spawn_cleanup = SpawnAgentCancellationCleanup::new(self.clone(), new_thread.thread_id);
|
||||
reservation.commit(agent_metadata.clone());
|
||||
|
||||
if let Some(SessionSource::SubAgent(
|
||||
@@ -326,6 +327,7 @@ impl AgentControl {
|
||||
);
|
||||
}
|
||||
|
||||
spawn_cleanup.disarm();
|
||||
Ok(LiveAgent {
|
||||
thread_id: new_thread.thread_id,
|
||||
metadata: agent_metadata,
|
||||
@@ -686,6 +688,56 @@ impl AgentControl {
|
||||
result
|
||||
}
|
||||
|
||||
/// Submit a shutdown request for a live agent while keeping the thread tracked until the
|
||||
/// session loop actually terminates.
|
||||
///
|
||||
/// The spawned-agent slot is released only after a background waiter observes shutdown
|
||||
/// completion and removes the thread from [`ThreadManagerState`]. Keeping the thread tracked
|
||||
/// and counted until then ensures later global shutdown paths can still await or abort the
|
||||
/// underlying session loop instead of leaving an orphaned task alive on the runtime.
|
||||
pub(crate) async fn request_live_agent_shutdown_preserving_thread(
|
||||
&self,
|
||||
agent_id: ThreadId,
|
||||
) -> CodexResult<String> {
|
||||
let state = self.upgrade()?;
|
||||
let thread = match state.get_thread(agent_id).await {
|
||||
Ok(thread) => thread,
|
||||
Err(_) => {
|
||||
self.state.release_spawned_thread(agent_id);
|
||||
return Ok(String::new());
|
||||
}
|
||||
};
|
||||
thread.codex.session.ensure_rollout_materialized().await;
|
||||
let _ = thread.codex.session.flush_rollout().await;
|
||||
let result = if matches!(thread.agent_status().await, AgentStatus::Shutdown) {
|
||||
Ok(String::new())
|
||||
} else {
|
||||
state.send_op(agent_id, Op::Shutdown {}).await
|
||||
};
|
||||
if matches!(result, Err(CodexErr::InternalAgentDied)) {
|
||||
let _ = state.remove_thread(&agent_id).await;
|
||||
self.state.release_spawned_thread(agent_id);
|
||||
} else if result.is_ok() {
|
||||
let registry = self.state.clone();
|
||||
tokio::spawn(async move {
|
||||
match thread.shutdown_and_wait().await {
|
||||
Ok(()) | Err(CodexErr::InternalAgentDied) => {
|
||||
let _ = state.remove_thread(&agent_id).await;
|
||||
registry.release_spawned_thread(agent_id);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
thread_id = %agent_id,
|
||||
error = %err,
|
||||
"failed to wait for live agent shutdown; keeping thread tracked"
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Mark `agent_id` as explicitly closed in persisted spawn-edge state, then shut down the
|
||||
/// agent and any live descendants reached from the in-memory tree.
|
||||
pub(crate) async fn close_agent(&self, agent_id: ThreadId) -> CodexResult<String> {
|
||||
@@ -1169,6 +1221,58 @@ impl AgentControl {
|
||||
}
|
||||
}
|
||||
|
||||
struct SpawnAgentCancellationCleanup {
|
||||
control: AgentControl,
|
||||
thread_id: ThreadId,
|
||||
armed: bool,
|
||||
}
|
||||
|
||||
impl SpawnAgentCancellationCleanup {
|
||||
fn new(control: AgentControl, thread_id: ThreadId) -> Self {
|
||||
Self {
|
||||
control,
|
||||
thread_id,
|
||||
armed: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn disarm(mut self) {
|
||||
self.armed = false;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SpawnAgentCancellationCleanup {
|
||||
fn drop(&mut self) {
|
||||
if !self.armed {
|
||||
return;
|
||||
}
|
||||
|
||||
let control = self.control.clone();
|
||||
let thread_id = self.thread_id;
|
||||
let Ok(handle) = tokio::runtime::Handle::try_current() else {
|
||||
tracing::warn!(
|
||||
thread_id = %thread_id,
|
||||
"spawn_agent was cancelled without a Tokio runtime available for cleanup"
|
||||
);
|
||||
return;
|
||||
};
|
||||
handle.spawn(async move {
|
||||
if let Err(err) = control
|
||||
.request_live_agent_shutdown_preserving_thread(thread_id)
|
||||
.await
|
||||
&& let Ok(state) = control.upgrade()
|
||||
{
|
||||
tracing::warn!(
|
||||
thread_id = %thread_id,
|
||||
error = %err,
|
||||
"failed to shut down thread from cancelled spawn_agent; removing tracking entry"
|
||||
);
|
||||
let _ = state.remove_thread(&thread_id).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn thread_spawn_parent_thread_id(session_source: &SessionSource) -> Option<ThreadId> {
|
||||
match session_source {
|
||||
SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
|
||||
@@ -996,6 +996,61 @@ async fn spawn_agent_releases_slot_after_shutdown() {
|
||||
.expect("shutdown agent");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn request_live_agent_shutdown_preserving_thread_cleans_up_after_shutdown() {
|
||||
let max_threads = 1usize;
|
||||
let (_home, config) = test_config_with_cli_overrides(vec![(
|
||||
"agents.max_threads".to_string(),
|
||||
TomlValue::Integer(max_threads as i64),
|
||||
)])
|
||||
.await;
|
||||
let manager = ThreadManager::with_models_provider_and_home_for_tests(
|
||||
CodexAuth::from_api_key("dummy"),
|
||||
config.model_provider.clone(),
|
||||
config.codex_home.clone().to_path_buf(),
|
||||
std::sync::Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()),
|
||||
);
|
||||
let control = manager.agent_control();
|
||||
|
||||
let first_agent_id = control
|
||||
.spawn_agent(
|
||||
config.clone(),
|
||||
text_input("hello"),
|
||||
/*session_source*/ None,
|
||||
)
|
||||
.await
|
||||
.expect("spawn_agent should succeed");
|
||||
let _ = control
|
||||
.request_live_agent_shutdown_preserving_thread(first_agent_id)
|
||||
.await
|
||||
.expect("shutdown request should succeed");
|
||||
|
||||
timeout(Duration::from_secs(5), async {
|
||||
while manager.get_thread(first_agent_id).await.is_ok() {
|
||||
sleep(Duration::from_millis(25)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("thread should be removed after shutdown completes");
|
||||
|
||||
let second_agent_id = control
|
||||
.spawn_agent(
|
||||
config.clone(),
|
||||
text_input("hello again"),
|
||||
/*session_source*/ None,
|
||||
)
|
||||
.await
|
||||
.expect("spawn_agent should succeed after preserving-thread shutdown");
|
||||
|
||||
let report = manager
|
||||
.shutdown_all_threads_bounded(Duration::from_secs(10))
|
||||
.await;
|
||||
assert_eq!(report.completed, vec![second_agent_id]);
|
||||
assert_eq!(report.submit_failed, Vec::<ThreadId>::new());
|
||||
assert_eq!(report.timed_out, Vec::<ThreadId>::new());
|
||||
assert!(manager.list_thread_ids().await.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_agent_limit_shared_across_clones() {
|
||||
let max_threads = 1usize;
|
||||
|
||||
@@ -110,6 +110,22 @@ struct ActiveJobItem {
|
||||
status_rx: Option<Receiver<AgentStatus>>,
|
||||
}
|
||||
|
||||
fn request_live_agent_shutdown(
|
||||
agent_control: crate::agent::control::AgentControl,
|
||||
thread_id: ThreadId,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
tokio::task::yield_now().await;
|
||||
let _ = agent_control
|
||||
.request_live_agent_shutdown_preserving_thread(thread_id)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
fn should_wait_for_scheduler_change(progressed: bool, agent_limit_reached: bool) -> bool {
|
||||
!progressed || agent_limit_reached
|
||||
}
|
||||
|
||||
struct JobProgressEmitter {
|
||||
started_at: Instant,
|
||||
last_emit_at: Instant,
|
||||
@@ -177,6 +193,13 @@ impl JobProgressEmitter {
|
||||
}
|
||||
}
|
||||
|
||||
#[path = "agent_jobs_db.rs"]
|
||||
mod db_ops;
|
||||
#[path = "agent_jobs_slots.rs"]
|
||||
mod slots;
|
||||
#[path = "agent_jobs_startup.rs"]
|
||||
mod startup;
|
||||
|
||||
impl ToolHandler for BatchJobHandler {
|
||||
type Output = FunctionToolOutput;
|
||||
|
||||
@@ -318,44 +341,50 @@ mod spawn_agents_on_csv {
|
||||
args.max_runtime_seconds
|
||||
.or(turn.config.agent_job_max_runtime_seconds),
|
||||
)?;
|
||||
let _job = db
|
||||
.create_agent_job(
|
||||
let _job = db_ops::retry_locked("create_agent_job", || async {
|
||||
db.create_agent_job(
|
||||
&codex_state::AgentJobCreateParams {
|
||||
id: job_id.clone(),
|
||||
name: job_name,
|
||||
instruction: args.instruction,
|
||||
name: job_name.clone(),
|
||||
instruction: args.instruction.clone(),
|
||||
auto_export: true,
|
||||
max_runtime_seconds,
|
||||
output_schema_json: args.output_schema,
|
||||
input_headers: headers,
|
||||
output_schema_json: args.output_schema.clone(),
|
||||
input_headers: headers.clone(),
|
||||
input_csv_path: input_path.display().to_string(),
|
||||
output_csv_path: output_csv_path.display().to_string(),
|
||||
},
|
||||
items.as_slice(),
|
||||
)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("failed to create agent job: {err}"))
|
||||
})?;
|
||||
})
|
||||
.await
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("failed to create agent job: {err}"))
|
||||
})?;
|
||||
|
||||
let requested_concurrency = args.max_concurrency.or(args.max_workers);
|
||||
let options = match build_runner_options(&session, &turn, requested_concurrency).await {
|
||||
Ok(options) => options,
|
||||
Err(err) => {
|
||||
let error_message = err.to_string();
|
||||
let _ = db
|
||||
.mark_agent_job_failed(job_id.as_str(), error_message.as_str())
|
||||
.await;
|
||||
let _ = db_ops::retry_locked("mark_agent_job_failed_after_options", || async {
|
||||
db.mark_agent_job_failed(job_id.as_str(), error_message.as_str())
|
||||
.await
|
||||
})
|
||||
.await;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
db.mark_agent_job_running(job_id.as_str())
|
||||
.await
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to transition agent job {job_id} to running: {err}"
|
||||
))
|
||||
})?;
|
||||
db_ops::retry_locked("mark_agent_job_running", || async {
|
||||
db.mark_agent_job_running(job_id.as_str()).await
|
||||
})
|
||||
.await
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to transition agent job {job_id} to running: {err}"
|
||||
))
|
||||
})?;
|
||||
let max_threads = turn.config.agent_max_threads;
|
||||
let effective_concurrency = options.max_concurrency;
|
||||
let message = format!(
|
||||
@@ -372,25 +401,26 @@ mod spawn_agents_on_csv {
|
||||
.await
|
||||
{
|
||||
let error_message = format!("job runner failed: {err}");
|
||||
let _ = db
|
||||
.mark_agent_job_failed(job_id.as_str(), error_message.as_str())
|
||||
.await;
|
||||
let _ = db_ops::retry_locked("mark_agent_job_failed_after_runner_error", || async {
|
||||
db.mark_agent_job_failed(job_id.as_str(), error_message.as_str())
|
||||
.await
|
||||
})
|
||||
.await;
|
||||
return Err(FunctionCallError::RespondToModel(format!(
|
||||
"agent job {job_id} failed: {err}"
|
||||
)));
|
||||
}
|
||||
|
||||
let job = db
|
||||
.get_agent_job(job_id.as_str())
|
||||
.await
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to load agent job {job_id}: {err}"
|
||||
))
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
FunctionCallError::RespondToModel(format!("agent job {job_id} not found"))
|
||||
})?;
|
||||
let job = db_ops::retry_locked("get_agent_job_after_runner", || async {
|
||||
db.get_agent_job(job_id.as_str()).await
|
||||
})
|
||||
.await
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("failed to load agent job {job_id}: {err}"))
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
FunctionCallError::RespondToModel(format!("agent job {job_id} not found"))
|
||||
})?;
|
||||
let output_path = PathBuf::from(job.output_csv_path.clone());
|
||||
if !tokio::fs::try_exists(&output_path).await.unwrap_or(false) {
|
||||
export_job_csv_snapshot(db.clone(), &job)
|
||||
@@ -401,24 +431,27 @@ mod spawn_agents_on_csv {
|
||||
))
|
||||
})?;
|
||||
}
|
||||
let progress = db
|
||||
.get_agent_job_progress(job_id.as_str())
|
||||
.await
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to load agent job progress {job_id}: {err}"
|
||||
))
|
||||
})?;
|
||||
let progress = db_ops::retry_locked("get_agent_job_progress_after_runner", || async {
|
||||
db.get_agent_job_progress(job_id.as_str()).await
|
||||
})
|
||||
.await
|
||||
.map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to load agent job progress {job_id}: {err}"
|
||||
))
|
||||
})?;
|
||||
let mut job_error = job.last_error.clone().filter(|err| !err.trim().is_empty());
|
||||
let failed_item_errors = if progress.failed_items > 0 {
|
||||
let items = db
|
||||
.list_agent_job_items(
|
||||
let items = db_ops::retry_locked("list_failed_agent_job_items", || async {
|
||||
db.list_agent_job_items(
|
||||
job_id.as_str(),
|
||||
Some(codex_state::AgentJobItemStatus::Failed),
|
||||
Some(5),
|
||||
)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
})
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let summaries: Vec<_> = items
|
||||
.into_iter()
|
||||
.filter_map(|item| {
|
||||
@@ -479,27 +512,50 @@ mod report_agent_job_result {
|
||||
));
|
||||
}
|
||||
let db = required_state_db(&session)?;
|
||||
let reporting_thread_id = session.conversation_id.to_string();
|
||||
let accepted = db
|
||||
.report_agent_job_item_result(
|
||||
args.job_id.as_str(),
|
||||
args.item_id.as_str(),
|
||||
reporting_thread_id.as_str(),
|
||||
&args.result,
|
||||
)
|
||||
let reporting_thread_id = session.conversation_id;
|
||||
let reporting_thread_id_str = reporting_thread_id.to_string();
|
||||
let accepted = if args.stop.unwrap_or(false) {
|
||||
db_ops::retry_locked("report_agent_job_item_result_and_cancel_job", || async {
|
||||
db.report_agent_job_item_result_and_cancel_job(
|
||||
args.job_id.as_str(),
|
||||
args.item_id.as_str(),
|
||||
reporting_thread_id_str.as_str(),
|
||||
&args.result,
|
||||
"cancelled by worker request",
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await
|
||||
.map_err(|err| {
|
||||
let job_id = args.job_id.as_str();
|
||||
let item_id = args.item_id.as_str();
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to record agent job result for {job_id} / {item_id}: {err}"
|
||||
))
|
||||
})?;
|
||||
if accepted && args.stop.unwrap_or(false) {
|
||||
let message = "cancelled by worker request";
|
||||
let _ = db
|
||||
.mark_agent_job_cancelled(args.job_id.as_str(), message)
|
||||
.await;
|
||||
} else {
|
||||
db_ops::retry_locked("report_agent_job_item_result", || async {
|
||||
db.report_agent_job_item_result(
|
||||
args.job_id.as_str(),
|
||||
args.item_id.as_str(),
|
||||
reporting_thread_id_str.as_str(),
|
||||
&args.result,
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
.map_err(|err| {
|
||||
let job_id = args.job_id.as_str();
|
||||
let item_id = args.item_id.as_str();
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to record agent job result for {job_id} / {item_id}: {err}"
|
||||
))
|
||||
})?;
|
||||
if accepted {
|
||||
tracing::debug!(
|
||||
job_id = args.job_id,
|
||||
item_id = args.item_id,
|
||||
thread_id = %reporting_thread_id,
|
||||
"agent job accepted worker result; scheduling worker shutdown"
|
||||
);
|
||||
request_live_agent_shutdown(
|
||||
session.services.agent_control.clone(),
|
||||
reporting_thread_id,
|
||||
);
|
||||
}
|
||||
let content =
|
||||
serde_json::to_string(&ReportAgentJobResultToolResult { accepted }).map_err(|err| {
|
||||
@@ -571,12 +627,14 @@ async fn run_agent_job_loop(
|
||||
job_id: String,
|
||||
options: JobRunnerOptions,
|
||||
) -> anyhow::Result<()> {
|
||||
let job = db
|
||||
.get_agent_job(job_id.as_str())
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("agent job {job_id} was not found"))?;
|
||||
let job = db_ops::retry_locked("get_agent_job_for_runner", || async {
|
||||
db.get_agent_job(job_id.as_str()).await
|
||||
})
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("agent job {job_id} was not found"))?;
|
||||
let runtime_timeout = job_runtime_timeout(&job);
|
||||
let mut active_items: HashMap<ThreadId, ActiveJobItem> = HashMap::new();
|
||||
let mut starting_items = startup::StartupTasks::default();
|
||||
let mut progress_emitter = JobProgressEmitter::new();
|
||||
recover_running_items(
|
||||
session.clone(),
|
||||
@@ -586,7 +644,10 @@ async fn run_agent_job_loop(
|
||||
runtime_timeout,
|
||||
)
|
||||
.await?;
|
||||
let initial_progress = db.get_agent_job_progress(job_id.as_str()).await?;
|
||||
let initial_progress = db_ops::retry_locked("get_initial_agent_job_progress", || async {
|
||||
db.get_agent_job_progress(job_id.as_str()).await
|
||||
})
|
||||
.await?;
|
||||
progress_emitter
|
||||
.maybe_emit(
|
||||
&session,
|
||||
@@ -597,11 +658,20 @@ async fn run_agent_job_loop(
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut cancel_requested = db.is_agent_job_cancelled(job_id.as_str()).await?;
|
||||
let mut cancel_requested = db_ops::retry_locked("is_agent_job_cancelled_initial", || async {
|
||||
db.is_agent_job_cancelled(job_id.as_str()).await
|
||||
})
|
||||
.await?;
|
||||
loop {
|
||||
let mut progressed = false;
|
||||
let mut agent_limit_reached = false;
|
||||
|
||||
if !cancel_requested && db.is_agent_job_cancelled(job_id.as_str()).await? {
|
||||
if !cancel_requested
|
||||
&& db_ops::retry_locked("is_agent_job_cancelled_pre_launch", || async {
|
||||
db.is_agent_job_cancelled(job_id.as_str()).await
|
||||
})
|
||||
.await?
|
||||
{
|
||||
cancel_requested = true;
|
||||
let _ = session
|
||||
.notify_background_event(
|
||||
@@ -611,85 +681,112 @@ async fn run_agent_job_loop(
|
||||
.await;
|
||||
}
|
||||
|
||||
if !cancel_requested && active_items.len() < options.max_concurrency {
|
||||
let slots = options.max_concurrency - active_items.len();
|
||||
let pending_items = db
|
||||
.list_agent_job_items(
|
||||
job_id.as_str(),
|
||||
Some(codex_state::AgentJobItemStatus::Pending),
|
||||
Some(slots),
|
||||
let startup_result = startup::drain_ready_startups(
|
||||
session.clone(),
|
||||
db.clone(),
|
||||
job_id.as_str(),
|
||||
&mut active_items,
|
||||
&mut starting_items,
|
||||
)
|
||||
.await?;
|
||||
progressed |= startup_result.progressed;
|
||||
agent_limit_reached |= startup_result.agent_limit_reached;
|
||||
|
||||
let scheduler_progress =
|
||||
db_ops::retry_locked("get_scheduler_agent_job_progress", || async {
|
||||
db.get_agent_job_progress(job_id.as_str()).await
|
||||
})
|
||||
.await?;
|
||||
if slots::reclaim_inactive_active_items(
|
||||
session.clone(),
|
||||
db.clone(),
|
||||
job_id.as_str(),
|
||||
&mut active_items,
|
||||
scheduler_progress.running_items,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
progressed = true;
|
||||
}
|
||||
|
||||
if !cancel_requested
|
||||
&& db_ops::retry_locked("is_agent_job_cancelled_post_reclaim", || async {
|
||||
db.is_agent_job_cancelled(job_id.as_str()).await
|
||||
})
|
||||
.await?
|
||||
{
|
||||
cancel_requested = true;
|
||||
progressed = true;
|
||||
let _ = session
|
||||
.notify_background_event(
|
||||
&turn,
|
||||
format!("agent job {job_id} cancellation requested; stopping new workers"),
|
||||
)
|
||||
.await?;
|
||||
for item in pending_items {
|
||||
let prompt = build_worker_prompt(&job, &item)?;
|
||||
let items = vec![UserInput::Text {
|
||||
text: prompt,
|
||||
text_elements: Vec::new(),
|
||||
}];
|
||||
let thread_id = match session
|
||||
.services
|
||||
.agent_control
|
||||
.spawn_agent(
|
||||
options.spawn_config.clone(),
|
||||
items.into(),
|
||||
Some(SessionSource::SubAgent(SubAgentSource::Other(format!(
|
||||
"agent_job:{job_id}"
|
||||
)))),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(thread_id) => thread_id,
|
||||
Err(CodexErr::AgentLimitReached { .. }) => {
|
||||
db.mark_agent_job_item_pending(
|
||||
job_id.as_str(),
|
||||
item.item_id.as_str(),
|
||||
/*error_message*/ None,
|
||||
)
|
||||
.await?;
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
let error_message = format!("failed to spawn worker: {err}");
|
||||
db.mark_agent_job_item_failed(
|
||||
job_id.as_str(),
|
||||
item.item_id.as_str(),
|
||||
error_message.as_str(),
|
||||
)
|
||||
.await?;
|
||||
progressed = true;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let assigned = db
|
||||
.mark_agent_job_item_running_with_thread(
|
||||
job_id.as_str(),
|
||||
item.item_id.as_str(),
|
||||
thread_id.to_string().as_str(),
|
||||
)
|
||||
.await?;
|
||||
if !assigned {
|
||||
let _ = session
|
||||
.services
|
||||
.agent_control
|
||||
.shutdown_live_agent(thread_id)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
active_items.insert(
|
||||
thread_id,
|
||||
ActiveJobItem {
|
||||
item_id: item.item_id.clone(),
|
||||
started_at: Instant::now(),
|
||||
status_rx: session
|
||||
.services
|
||||
.agent_control
|
||||
.subscribe_status(thread_id)
|
||||
.await
|
||||
.ok(),
|
||||
},
|
||||
);
|
||||
progressed = true;
|
||||
}
|
||||
.await;
|
||||
}
|
||||
|
||||
let terminal_in_db = if cancel_requested {
|
||||
scheduler_progress.running_items == 0
|
||||
} else {
|
||||
scheduler_progress.pending_items == 0 && scheduler_progress.running_items == 0
|
||||
};
|
||||
if terminal_in_db
|
||||
&& slots::reconcile_terminal_scheduler_state(
|
||||
session.clone(),
|
||||
job_id.as_str(),
|
||||
&scheduler_progress,
|
||||
&mut active_items,
|
||||
&mut starting_items,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
progressed = true;
|
||||
}
|
||||
if terminal_in_db && active_items.is_empty() && starting_items.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
if !cancel_requested
|
||||
&& !agent_limit_reached
|
||||
&& active_items.len() + starting_items.len() < options.max_concurrency
|
||||
&& startup::launch_pending_items(
|
||||
session.clone(),
|
||||
db.clone(),
|
||||
&job,
|
||||
job_id.as_str(),
|
||||
&options,
|
||||
startup::SchedulerOccupancy {
|
||||
active_items: active_items.len(),
|
||||
db_pending_items: scheduler_progress.pending_items,
|
||||
db_running_items: scheduler_progress.running_items,
|
||||
},
|
||||
&mut starting_items,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
progressed = true;
|
||||
}
|
||||
|
||||
let startup_result = startup::drain_ready_startups(
|
||||
session.clone(),
|
||||
db.clone(),
|
||||
job_id.as_str(),
|
||||
&mut active_items,
|
||||
&mut starting_items,
|
||||
)
|
||||
.await?;
|
||||
progressed |= startup_result.progressed;
|
||||
agent_limit_reached |= startup_result.agent_limit_reached;
|
||||
|
||||
if startup::reap_stale_startups(
|
||||
db.clone(),
|
||||
job_id.as_str(),
|
||||
&mut starting_items,
|
||||
runtime_timeout,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
progressed = true;
|
||||
}
|
||||
|
||||
if reap_stale_active_items(
|
||||
@@ -706,19 +803,43 @@ async fn run_agent_job_loop(
|
||||
|
||||
let finished = find_finished_threads(session.clone(), &active_items).await;
|
||||
if finished.is_empty() {
|
||||
let progress = db.get_agent_job_progress(job_id.as_str()).await?;
|
||||
if cancel_requested {
|
||||
if progress.running_items == 0 && active_items.is_empty() {
|
||||
break;
|
||||
}
|
||||
} else if progress.pending_items == 0
|
||||
&& progress.running_items == 0
|
||||
&& active_items.is_empty()
|
||||
let progress = if progressed {
|
||||
db_ops::retry_locked("get_agent_job_progress_after_progress", || async {
|
||||
db.get_agent_job_progress(job_id.as_str()).await
|
||||
})
|
||||
.await?
|
||||
} else {
|
||||
scheduler_progress
|
||||
};
|
||||
let terminal_in_db = if cancel_requested {
|
||||
progress.running_items == 0
|
||||
} else {
|
||||
progress.pending_items == 0 && progress.running_items == 0
|
||||
};
|
||||
if terminal_in_db
|
||||
&& slots::reconcile_terminal_scheduler_state(
|
||||
session.clone(),
|
||||
job_id.as_str(),
|
||||
&progress,
|
||||
&mut active_items,
|
||||
&mut starting_items,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
progressed = true;
|
||||
}
|
||||
if terminal_in_db && active_items.is_empty() && starting_items.is_empty() {
|
||||
break;
|
||||
}
|
||||
if !progressed {
|
||||
wait_for_status_change(&active_items).await;
|
||||
if should_wait_for_scheduler_change(progressed, agent_limit_reached) {
|
||||
startup::wait_for_startup_or_status_change(
|
||||
session.clone(),
|
||||
db.clone(),
|
||||
job_id.as_str(),
|
||||
&mut active_items,
|
||||
&mut starting_items,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -733,7 +854,11 @@ async fn run_agent_job_loop(
|
||||
)
|
||||
.await?;
|
||||
active_items.remove(&thread_id);
|
||||
let progress = db.get_agent_job_progress(job_id.as_str()).await?;
|
||||
let progress =
|
||||
db_ops::retry_locked("get_agent_job_progress_after_finalize", || async {
|
||||
db.get_agent_job_progress(job_id.as_str()).await
|
||||
})
|
||||
.await?;
|
||||
progress_emitter
|
||||
.maybe_emit(
|
||||
&session,
|
||||
@@ -746,14 +871,24 @@ async fn run_agent_job_loop(
|
||||
}
|
||||
}
|
||||
|
||||
let progress = db.get_agent_job_progress(job_id.as_str()).await?;
|
||||
let progress = db_ops::retry_locked("get_agent_job_progress_before_export", || async {
|
||||
db.get_agent_job_progress(job_id.as_str()).await
|
||||
})
|
||||
.await?;
|
||||
if let Err(err) = export_job_csv_snapshot(db.clone(), &job).await {
|
||||
let message = format!("auto-export failed: {err}");
|
||||
db.mark_agent_job_failed(job_id.as_str(), message.as_str())
|
||||
.await?;
|
||||
db_ops::retry_locked("mark_agent_job_failed_after_export", || async {
|
||||
db.mark_agent_job_failed(job_id.as_str(), message.as_str())
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
let cancelled = cancel_requested || db.is_agent_job_cancelled(job_id.as_str()).await?;
|
||||
let cancelled = cancel_requested
|
||||
|| db_ops::retry_locked("is_agent_job_cancelled_before_complete", || async {
|
||||
db.is_agent_job_cancelled(job_id.as_str()).await
|
||||
})
|
||||
.await?;
|
||||
if cancelled {
|
||||
let pending_items = progress.pending_items;
|
||||
let message =
|
||||
@@ -775,8 +910,14 @@ async fn run_agent_job_loop(
|
||||
let message = format!("agent job completed with {failed_items} failed items");
|
||||
let _ = session.notify_background_event(&turn, message).await;
|
||||
}
|
||||
db.mark_agent_job_completed(job_id.as_str()).await?;
|
||||
let progress = db.get_agent_job_progress(job_id.as_str()).await?;
|
||||
db_ops::retry_locked("mark_agent_job_completed", || async {
|
||||
db.mark_agent_job_completed(job_id.as_str()).await
|
||||
})
|
||||
.await?;
|
||||
let progress = db_ops::retry_locked("get_agent_job_progress_after_complete", || async {
|
||||
db.get_agent_job_progress(job_id.as_str()).await
|
||||
})
|
||||
.await?;
|
||||
progress_emitter
|
||||
.maybe_emit(
|
||||
&session,
|
||||
@@ -793,9 +934,11 @@ async fn export_job_csv_snapshot(
|
||||
db: Arc<codex_state::StateRuntime>,
|
||||
job: &codex_state::AgentJob,
|
||||
) -> anyhow::Result<()> {
|
||||
let items = db
|
||||
.list_agent_job_items(job.id.as_str(), /*status*/ None, /*limit*/ None)
|
||||
.await?;
|
||||
let items = db_ops::retry_locked("list_agent_job_items_for_export", || async {
|
||||
db.list_agent_job_items(job.id.as_str(), /*status*/ None, /*limit*/ None)
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
let csv_content = render_job_csv(job.input_headers.as_slice(), items.as_slice())
|
||||
.map_err(|err| anyhow::anyhow!("failed to render job csv for auto-export: {err}"))?;
|
||||
let output_path = PathBuf::from(job.output_csv_path.clone());
|
||||
@@ -813,35 +956,40 @@ async fn recover_running_items(
|
||||
active_items: &mut HashMap<ThreadId, ActiveJobItem>,
|
||||
runtime_timeout: Duration,
|
||||
) -> anyhow::Result<()> {
|
||||
let running_items = db
|
||||
.list_agent_job_items(
|
||||
job_id,
|
||||
Some(codex_state::AgentJobItemStatus::Running),
|
||||
/*limit*/ None,
|
||||
)
|
||||
let running_items =
|
||||
db_ops::retry_locked("list_running_agent_job_items_for_recovery", || async {
|
||||
db.list_agent_job_items(
|
||||
job_id,
|
||||
Some(codex_state::AgentJobItemStatus::Running),
|
||||
/*limit*/ None,
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
for item in running_items {
|
||||
if is_item_stale(&item, runtime_timeout) {
|
||||
let error_message = format!("worker exceeded max runtime of {runtime_timeout:?}");
|
||||
db.mark_agent_job_item_failed(job_id, item.item_id.as_str(), error_message.as_str())
|
||||
.await?;
|
||||
db_ops::retry_locked("mark_stale_agent_job_item_failed_on_recovery", || async {
|
||||
db.mark_agent_job_item_failed(job_id, item.item_id.as_str(), error_message.as_str())
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
if let Some(assigned_thread_id) = item.assigned_thread_id.as_ref()
|
||||
&& let Ok(thread_id) = ThreadId::from_string(assigned_thread_id.as_str())
|
||||
{
|
||||
let _ = session
|
||||
.services
|
||||
.agent_control
|
||||
.shutdown_live_agent(thread_id)
|
||||
.await;
|
||||
request_live_agent_shutdown(session.services.agent_control.clone(), thread_id);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
let Some(assigned_thread_id) = item.assigned_thread_id.clone() else {
|
||||
db.mark_agent_job_item_failed(
|
||||
job_id,
|
||||
item.item_id.as_str(),
|
||||
"running item is missing assigned_thread_id",
|
||||
)
|
||||
db_ops::retry_locked("mark_agent_job_item_pending_on_recovery", || async {
|
||||
db.mark_agent_job_item_pending(
|
||||
job_id,
|
||||
item.item_id.as_str(),
|
||||
Some("worker startup was interrupted before a thread was assigned"),
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
continue;
|
||||
};
|
||||
@@ -849,11 +997,14 @@ async fn recover_running_items(
|
||||
Ok(thread_id) => thread_id,
|
||||
Err(err) => {
|
||||
let error_message = format!("invalid assigned_thread_id: {err:?}");
|
||||
db.mark_agent_job_item_failed(
|
||||
job_id,
|
||||
item.item_id.as_str(),
|
||||
error_message.as_str(),
|
||||
)
|
||||
db_ops::retry_locked("mark_agent_job_item_failed_invalid_thread", || async {
|
||||
db.mark_agent_job_item_failed(
|
||||
job_id,
|
||||
item.item_id.as_str(),
|
||||
error_message.as_str(),
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
continue;
|
||||
}
|
||||
@@ -948,12 +1099,15 @@ async fn reap_stale_active_items(
|
||||
}
|
||||
for (thread_id, item_id) in stale {
|
||||
let error_message = format!("worker exceeded max runtime of {runtime_timeout:?}");
|
||||
db.mark_agent_job_item_failed(job_id, item_id.as_str(), error_message.as_str())
|
||||
.await?;
|
||||
db_ops::retry_locked("mark_stale_active_agent_job_item_failed", || async {
|
||||
db.mark_agent_job_item_failed(job_id, item_id.as_str(), error_message.as_str())
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
let _ = session
|
||||
.services
|
||||
.agent_control
|
||||
.shutdown_live_agent(thread_id)
|
||||
.request_live_agent_shutdown_preserving_thread(thread_id)
|
||||
.await;
|
||||
active_items.remove(&thread_id);
|
||||
}
|
||||
@@ -967,29 +1121,33 @@ async fn finalize_finished_item(
|
||||
item_id: &str,
|
||||
thread_id: ThreadId,
|
||||
) -> anyhow::Result<()> {
|
||||
let item = db
|
||||
.get_agent_job_item(job_id, item_id)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("job item not found for finalization: {job_id}/{item_id}")
|
||||
})?;
|
||||
let item = db_ops::retry_locked("get_agent_job_item_for_finalization", || async {
|
||||
db.get_agent_job_item(job_id, item_id).await
|
||||
})
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("job item not found for finalization: {job_id}/{item_id}"))?;
|
||||
if matches!(item.status, codex_state::AgentJobItemStatus::Running) {
|
||||
if item.result_json.is_some() {
|
||||
let _ = db.mark_agent_job_item_completed(job_id, item_id).await?;
|
||||
let _ = db_ops::retry_locked("mark_agent_job_item_completed", || async {
|
||||
db.mark_agent_job_item_completed(job_id, item_id).await
|
||||
})
|
||||
.await?;
|
||||
} else {
|
||||
let _ = db
|
||||
.mark_agent_job_item_failed(
|
||||
let _ = db_ops::retry_locked("mark_agent_job_item_failed_missing_report", || async {
|
||||
db.mark_agent_job_item_failed(
|
||||
job_id,
|
||||
item_id,
|
||||
"worker finished without calling report_agent_job_result",
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
let _ = session
|
||||
.services
|
||||
.agent_control
|
||||
.shutdown_live_agent(thread_id)
|
||||
.request_live_agent_shutdown_preserving_thread(thread_id)
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
95
codex-rs/core/src/tools/handlers/agent_jobs_db.rs
Normal file
95
codex-rs/core/src/tools/handlers/agent_jobs_db.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
use std::future::Future;
|
||||
use std::time::Duration;
|
||||
|
||||
const SQLITE_LOCK_RETRY_LIMIT: usize = 2;
|
||||
|
||||
fn is_sqlite_lock_error(err: &anyhow::Error) -> bool {
|
||||
err.chain().any(|cause| {
|
||||
let message = cause.to_string();
|
||||
message.contains("database is locked") || message.contains("database table is locked")
|
||||
})
|
||||
}
|
||||
|
||||
fn retry_delay(attempt: usize) -> Duration {
|
||||
Duration::from_millis(250 * (attempt as u64 + 1))
|
||||
}
|
||||
|
||||
pub(super) async fn retry_locked<T, F, Fut>(operation: &'static str, mut op: F) -> anyhow::Result<T>
|
||||
where
|
||||
F: FnMut() -> Fut,
|
||||
Fut: Future<Output = anyhow::Result<T>>,
|
||||
{
|
||||
let mut attempt = 0usize;
|
||||
loop {
|
||||
match op().await {
|
||||
Ok(value) => return Ok(value),
|
||||
Err(err) if is_sqlite_lock_error(&err) && attempt < SQLITE_LOCK_RETRY_LIMIT => {
|
||||
let retry_in = retry_delay(attempt);
|
||||
tracing::warn!(
|
||||
operation,
|
||||
attempt = attempt + 1,
|
||||
max_attempts = SQLITE_LOCK_RETRY_LIMIT + 1,
|
||||
retry_delay_ms = retry_in.as_millis() as u64,
|
||||
error = %err,
|
||||
"agent job DB operation hit sqlite lock; retrying"
|
||||
);
|
||||
tokio::time::sleep(retry_in).await;
|
||||
attempt += 1;
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
#[tokio::test]
|
||||
async fn retry_locked_retries_transient_sqlite_locks() -> anyhow::Result<()> {
|
||||
let attempts = Arc::new(AtomicUsize::new(0));
|
||||
let attempts_for_op = Arc::clone(&attempts);
|
||||
|
||||
let result = retry_locked("test_sqlite_retry", move || {
|
||||
let attempts = Arc::clone(&attempts_for_op);
|
||||
async move {
|
||||
let attempt = attempts.fetch_add(1, Ordering::SeqCst);
|
||||
if attempt < 2 {
|
||||
Err(anyhow::anyhow!(
|
||||
"error returned from database: (code: 5) database is locked"
|
||||
))
|
||||
} else {
|
||||
Ok("ok")
|
||||
}
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert_eq!(result, "ok");
|
||||
assert_eq!(attempts.load(Ordering::SeqCst), 3);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retry_locked_does_not_retry_non_lock_errors() {
|
||||
let attempts = Arc::new(AtomicUsize::new(0));
|
||||
let attempts_for_op = Arc::clone(&attempts);
|
||||
|
||||
let err = retry_locked("test_non_sqlite_retry", move || {
|
||||
let attempts = Arc::clone(&attempts_for_op);
|
||||
async move {
|
||||
attempts.fetch_add(1, Ordering::SeqCst);
|
||||
Err::<(), anyhow::Error>(anyhow::anyhow!("boom"))
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect_err("non-lock errors should not be retried");
|
||||
|
||||
assert_eq!(err.to_string(), "boom");
|
||||
assert_eq!(attempts.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
}
|
||||
125
codex-rs/core/src/tools/handlers/agent_jobs_slots.rs
Normal file
125
codex-rs/core/src/tools/handlers/agent_jobs_slots.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InactiveActiveItem {
|
||||
thread_id: ThreadId,
|
||||
item_id: String,
|
||||
db_state: &'static str,
|
||||
db_assigned_thread_id: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) async fn reclaim_inactive_active_items(
|
||||
session: Arc<Session>,
|
||||
db: Arc<codex_state::StateRuntime>,
|
||||
job_id: &str,
|
||||
active_items: &mut HashMap<ThreadId, ActiveJobItem>,
|
||||
db_running_items: usize,
|
||||
) -> anyhow::Result<bool> {
|
||||
let running_items =
|
||||
db_ops::retry_locked("list_running_agent_job_items_for_reclaim", || async {
|
||||
db.list_agent_job_items(
|
||||
job_id,
|
||||
Some(codex_state::AgentJobItemStatus::Running),
|
||||
/*limit*/ None,
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
let running_by_item_id: HashMap<_, _> = running_items
|
||||
.into_iter()
|
||||
.map(|item| (item.item_id, item.assigned_thread_id))
|
||||
.collect();
|
||||
|
||||
let mut inactive_items = Vec::new();
|
||||
for (thread_id, item) in active_items.iter() {
|
||||
let thread_id_str = thread_id.to_string();
|
||||
let Some(db_assigned_thread_id) = running_by_item_id.get(item.item_id.as_str()) else {
|
||||
inactive_items.push(InactiveActiveItem {
|
||||
thread_id: *thread_id,
|
||||
item_id: item.item_id.clone(),
|
||||
db_state: "missing",
|
||||
db_assigned_thread_id: None,
|
||||
});
|
||||
continue;
|
||||
};
|
||||
let still_running = db_assigned_thread_id.as_deref() == Some(thread_id_str.as_str());
|
||||
if still_running {
|
||||
continue;
|
||||
}
|
||||
inactive_items.push(InactiveActiveItem {
|
||||
thread_id: *thread_id,
|
||||
item_id: item.item_id.clone(),
|
||||
db_state: "running",
|
||||
db_assigned_thread_id: db_assigned_thread_id.clone(),
|
||||
});
|
||||
}
|
||||
if inactive_items.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
job_id,
|
||||
db_running_items,
|
||||
active_items = active_items.len(),
|
||||
reclaimed_items = inactive_items.len(),
|
||||
"agent job reclaiming scheduler slots for items that are no longer running in state"
|
||||
);
|
||||
|
||||
for inactive_item in inactive_items {
|
||||
active_items.remove(&inactive_item.thread_id);
|
||||
request_live_agent_shutdown(
|
||||
session.services.agent_control.clone(),
|
||||
inactive_item.thread_id,
|
||||
);
|
||||
tracing::debug!(
|
||||
job_id,
|
||||
item_id = inactive_item.item_id,
|
||||
thread_id = %inactive_item.thread_id,
|
||||
db_status = inactive_item.db_state,
|
||||
db_assigned_thread_id = inactive_item.db_assigned_thread_id.as_deref().unwrap_or(""),
|
||||
active_items = active_items.len(),
|
||||
"agent job reclaimed scheduler slot"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
pub(super) async fn reconcile_terminal_scheduler_state(
|
||||
session: Arc<Session>,
|
||||
job_id: &str,
|
||||
progress: &codex_state::AgentJobProgress,
|
||||
active_items: &mut HashMap<ThreadId, ActiveJobItem>,
|
||||
startup_tasks: &mut startup::StartupTasks,
|
||||
) -> anyhow::Result<bool> {
|
||||
if active_items.is_empty() && startup_tasks.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let active_count = active_items.len();
|
||||
let starting_count = startup_tasks.len();
|
||||
tracing::info!(
|
||||
job_id,
|
||||
pending_items = progress.pending_items,
|
||||
db_running_items = progress.running_items,
|
||||
active_items = active_count,
|
||||
starting_items = starting_count,
|
||||
"agent job state is terminal in DB; forcing scheduler teardown"
|
||||
);
|
||||
|
||||
let thread_ids: Vec<_> = active_items.keys().copied().collect();
|
||||
for thread_id in thread_ids {
|
||||
active_items.remove(&thread_id);
|
||||
request_live_agent_shutdown(session.services.agent_control.clone(), thread_id);
|
||||
}
|
||||
|
||||
let aborted_startups = startup::abort_all_startups(startup_tasks).await;
|
||||
tracing::debug!(
|
||||
job_id,
|
||||
active_items_reclaimed = active_count,
|
||||
starting_items_aborted = aborted_startups,
|
||||
"agent job terminal scheduler teardown completed"
|
||||
);
|
||||
Ok(active_count > 0 || aborted_startups > 0)
|
||||
}
|
||||
695
codex-rs/core/src/tools/handlers/agent_jobs_startup.rs
Normal file
695
codex-rs/core/src/tools/handlers/agent_jobs_startup.rs
Normal file
@@ -0,0 +1,695 @@
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use tokio::task::AbortHandle;
|
||||
use tokio::task::Id as TaskId;
|
||||
use tokio::task::JoinError;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct WorkerStartup {
|
||||
pub(super) item_id: String,
|
||||
pub(super) started_at: Instant,
|
||||
pub(super) spawn_latency: Duration,
|
||||
pub(super) result: Result<ThreadId, CodexErr>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct LaunchingJobItem {
|
||||
item_id: String,
|
||||
started_at: Instant,
|
||||
abort_handle: AbortHandle,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub(super) struct StartupTasks {
|
||||
starting_items: JoinSet<WorkerStartup>,
|
||||
launching_items: HashMap<TaskId, LaunchingJobItem>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
|
||||
pub(super) struct StartupDrainResult {
|
||||
pub(super) progressed: bool,
|
||||
pub(super) agent_limit_reached: bool,
|
||||
}
|
||||
|
||||
impl StartupDrainResult {
|
||||
fn merge(&mut self, other: Self) {
|
||||
self.progressed |= other.progressed;
|
||||
self.agent_limit_reached |= other.agent_limit_reached;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub(super) struct SchedulerOccupancy {
|
||||
pub(super) active_items: usize,
|
||||
pub(super) db_pending_items: usize,
|
||||
pub(super) db_running_items: usize,
|
||||
}
|
||||
|
||||
impl StartupTasks {
|
||||
pub(super) fn len(&self) -> usize {
|
||||
self.starting_items.len()
|
||||
}
|
||||
|
||||
pub(super) fn is_empty(&self) -> bool {
|
||||
self.starting_items.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_tracked_startup_task<F>(
|
||||
startup_tasks: &mut StartupTasks,
|
||||
item_id: String,
|
||||
started_at: Instant,
|
||||
task: F,
|
||||
) where
|
||||
F: Future<Output = WorkerStartup> + Send + 'static,
|
||||
{
|
||||
let abort_handle = startup_tasks.starting_items.spawn(task);
|
||||
startup_tasks.launching_items.insert(
|
||||
abort_handle.id(),
|
||||
LaunchingJobItem {
|
||||
item_id,
|
||||
started_at,
|
||||
abort_handle,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub(super) async fn launch_pending_items(
|
||||
session: Arc<Session>,
|
||||
db: Arc<codex_state::StateRuntime>,
|
||||
job: &codex_state::AgentJob,
|
||||
job_id: &str,
|
||||
options: &JobRunnerOptions,
|
||||
occupancy: SchedulerOccupancy,
|
||||
startup_tasks: &mut StartupTasks,
|
||||
) -> anyhow::Result<bool> {
|
||||
let slots = options
|
||||
.max_concurrency
|
||||
.saturating_sub(occupancy.active_items + startup_tasks.len());
|
||||
if slots == 0 {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let pending_items = db_ops::retry_locked("list_pending_agent_job_items_for_launch", || async {
|
||||
db.list_agent_job_items(
|
||||
job_id,
|
||||
Some(codex_state::AgentJobItemStatus::Pending),
|
||||
Some(slots),
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut launched = 0usize;
|
||||
let mut progressed = false;
|
||||
for item in pending_items {
|
||||
let claimed = db_ops::retry_locked("mark_agent_job_item_running_for_launch", || async {
|
||||
db.mark_agent_job_item_running(job_id, item.item_id.as_str())
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
if !claimed {
|
||||
continue;
|
||||
}
|
||||
|
||||
let prompt = match build_worker_prompt(job, &item) {
|
||||
Ok(prompt) => prompt,
|
||||
Err(err) => {
|
||||
let error_message = format!("failed to build worker prompt: {err}");
|
||||
db_ops::retry_locked("mark_agent_job_item_failed_for_prompt_build", || async {
|
||||
db.mark_agent_job_item_failed(
|
||||
job_id,
|
||||
item.item_id.as_str(),
|
||||
error_message.as_str(),
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
progressed = true;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let item_id = item.item_id.clone();
|
||||
let session = session.clone();
|
||||
let spawn_config = options.spawn_config.clone();
|
||||
let session_source =
|
||||
SessionSource::SubAgent(SubAgentSource::Other(format!("agent_job:{job_id}")));
|
||||
let started_at = Instant::now();
|
||||
spawn_tracked_startup_task(startup_tasks, item_id.clone(), started_at, async move {
|
||||
let items = vec![UserInput::Text {
|
||||
text: prompt,
|
||||
text_elements: Vec::new(),
|
||||
}];
|
||||
let result = session
|
||||
.services
|
||||
.agent_control
|
||||
.spawn_agent(spawn_config, items.into(), Some(session_source))
|
||||
.await;
|
||||
WorkerStartup {
|
||||
item_id,
|
||||
started_at,
|
||||
spawn_latency: started_at.elapsed(),
|
||||
result,
|
||||
}
|
||||
});
|
||||
launched = launched.saturating_add(1);
|
||||
progressed = true;
|
||||
}
|
||||
|
||||
if launched > 0 {
|
||||
tracing::info!(
|
||||
job_id,
|
||||
launched,
|
||||
db_pending_items = occupancy.db_pending_items,
|
||||
db_running_items = occupancy.db_running_items,
|
||||
active_items = occupancy.active_items,
|
||||
starting_items = startup_tasks.len(),
|
||||
target_concurrency = options.max_concurrency,
|
||||
"agent job queued worker startups"
|
||||
);
|
||||
}
|
||||
Ok(progressed)
|
||||
}
|
||||
|
||||
pub(super) async fn drain_ready_startups(
|
||||
session: Arc<Session>,
|
||||
db: Arc<codex_state::StateRuntime>,
|
||||
job_id: &str,
|
||||
active_items: &mut HashMap<ThreadId, ActiveJobItem>,
|
||||
startup_tasks: &mut StartupTasks,
|
||||
) -> anyhow::Result<StartupDrainResult> {
|
||||
let mut drain_result = StartupDrainResult::default();
|
||||
while let Some(join_result) = startup_tasks.starting_items.try_join_next_with_id() {
|
||||
let starting_items_len = startup_tasks.starting_items.len();
|
||||
let startup_result = handle_worker_startup_result(
|
||||
session.clone(),
|
||||
db.clone(),
|
||||
job_id,
|
||||
active_items,
|
||||
startup_tasks,
|
||||
join_result,
|
||||
starting_items_len,
|
||||
)
|
||||
.await?;
|
||||
drain_result.merge(startup_result);
|
||||
}
|
||||
Ok(drain_result)
|
||||
}
|
||||
|
||||
pub(super) async fn wait_for_startup_or_status_change(
|
||||
session: Arc<Session>,
|
||||
db: Arc<codex_state::StateRuntime>,
|
||||
job_id: &str,
|
||||
active_items: &mut HashMap<ThreadId, ActiveJobItem>,
|
||||
startup_tasks: &mut StartupTasks,
|
||||
) -> anyhow::Result<()> {
|
||||
if startup_tasks.is_empty() {
|
||||
wait_for_status_change(active_items).await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let active_items_ref = &*active_items;
|
||||
if active_items_ref.is_empty() {
|
||||
if let Ok(Some(result)) = timeout(
|
||||
STATUS_POLL_INTERVAL,
|
||||
startup_tasks.starting_items.join_next_with_id(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
let starting_items_len = startup_tasks.starting_items.len();
|
||||
handle_worker_startup_result(
|
||||
session,
|
||||
db,
|
||||
job_id,
|
||||
active_items,
|
||||
startup_tasks,
|
||||
result,
|
||||
starting_items_len,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
startup = startup_tasks.starting_items.join_next_with_id() => {
|
||||
if let Some(result) = startup {
|
||||
let starting_items_len = startup_tasks.starting_items.len();
|
||||
handle_worker_startup_result(
|
||||
session,
|
||||
db,
|
||||
job_id,
|
||||
active_items,
|
||||
startup_tasks,
|
||||
result,
|
||||
starting_items_len,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
_ = wait_for_status_change(active_items_ref) => {}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) async fn abort_all_startups(startup_tasks: &mut StartupTasks) -> usize {
|
||||
let startup_count = startup_tasks.starting_items.len();
|
||||
if startup_count == 0 {
|
||||
startup_tasks.launching_items.clear();
|
||||
return 0;
|
||||
}
|
||||
|
||||
for launching_item in startup_tasks.launching_items.values() {
|
||||
launching_item.abort_handle.abort();
|
||||
}
|
||||
startup_tasks.launching_items.clear();
|
||||
|
||||
while startup_tasks.starting_items.join_next().await.is_some() {}
|
||||
startup_count
|
||||
}
|
||||
|
||||
pub(super) async fn reap_stale_startups(
|
||||
db: Arc<codex_state::StateRuntime>,
|
||||
job_id: &str,
|
||||
startup_tasks: &mut StartupTasks,
|
||||
runtime_timeout: Duration,
|
||||
) -> anyhow::Result<bool> {
|
||||
let stale_task_ids: Vec<_> = startup_tasks
|
||||
.launching_items
|
||||
.iter()
|
||||
.filter_map(|(task_id, item)| {
|
||||
(item.started_at.elapsed() >= runtime_timeout).then_some(*task_id)
|
||||
})
|
||||
.collect();
|
||||
if stale_task_ids.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
for task_id in stale_task_ids {
|
||||
let Some(item) = startup_tasks.launching_items.remove(&task_id) else {
|
||||
continue;
|
||||
};
|
||||
item.abort_handle.abort();
|
||||
let error_message =
|
||||
format!("worker exceeded max runtime of {runtime_timeout:?} before startup completed");
|
||||
db_ops::retry_locked("mark_agent_job_item_failed_for_stale_startup", || async {
|
||||
db.mark_agent_job_item_failed(job_id, item.item_id.as_str(), error_message.as_str())
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
tracing::warn!(
|
||||
job_id,
|
||||
item_id = item.item_id,
|
||||
?task_id,
|
||||
"agent job worker startup timed out"
|
||||
);
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn handle_worker_startup_result(
|
||||
session: Arc<Session>,
|
||||
db: Arc<codex_state::StateRuntime>,
|
||||
job_id: &str,
|
||||
active_items: &mut HashMap<ThreadId, ActiveJobItem>,
|
||||
startup_tasks: &mut StartupTasks,
|
||||
result: Result<(TaskId, WorkerStartup), JoinError>,
|
||||
starting_items_len: usize,
|
||||
) -> anyhow::Result<StartupDrainResult> {
|
||||
match result {
|
||||
Ok((task_id, startup)) => {
|
||||
startup_tasks.launching_items.remove(&task_id);
|
||||
match startup.result {
|
||||
Ok(thread_id) => {
|
||||
let thread_id_str = thread_id.to_string();
|
||||
let assigned =
|
||||
db_ops::retry_locked("set_agent_job_item_thread_after_startup", || async {
|
||||
db.set_agent_job_item_thread(
|
||||
job_id,
|
||||
startup.item_id.as_str(),
|
||||
thread_id_str.as_str(),
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
if !assigned {
|
||||
let _ = session
|
||||
.services
|
||||
.agent_control
|
||||
.request_live_agent_shutdown_preserving_thread(thread_id)
|
||||
.await;
|
||||
tracing::debug!(
|
||||
job_id,
|
||||
item_id = startup.item_id,
|
||||
thread_id = %thread_id,
|
||||
"agent job worker startup finished after item left running state"
|
||||
);
|
||||
return Ok(StartupDrainResult {
|
||||
progressed: true,
|
||||
agent_limit_reached: false,
|
||||
});
|
||||
}
|
||||
|
||||
let item_id = startup.item_id;
|
||||
active_items.insert(
|
||||
thread_id,
|
||||
ActiveJobItem {
|
||||
item_id: item_id.clone(),
|
||||
started_at: startup.started_at,
|
||||
status_rx: session
|
||||
.services
|
||||
.agent_control
|
||||
.subscribe_status(thread_id)
|
||||
.await
|
||||
.ok(),
|
||||
},
|
||||
);
|
||||
tracing::info!(
|
||||
job_id,
|
||||
item_id,
|
||||
thread_id = %thread_id,
|
||||
spawn_latency_ms = startup.spawn_latency.as_millis() as u64,
|
||||
active_items = active_items.len(),
|
||||
starting_items = starting_items_len,
|
||||
"agent job worker startup completed"
|
||||
);
|
||||
Ok(StartupDrainResult {
|
||||
progressed: true,
|
||||
agent_limit_reached: false,
|
||||
})
|
||||
}
|
||||
Err(CodexErr::AgentLimitReached { .. }) => {
|
||||
let _ =
|
||||
db_ops::retry_locked("mark_agent_job_item_pending_after_limit", || async {
|
||||
db.mark_agent_job_item_pending(
|
||||
job_id,
|
||||
startup.item_id.as_str(),
|
||||
/*error_message*/ None,
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
tracing::debug!(
|
||||
job_id,
|
||||
item_id = startup.item_id,
|
||||
starting_items = starting_items_len,
|
||||
"agent job worker startup hit agent limit"
|
||||
);
|
||||
Ok(StartupDrainResult {
|
||||
progressed: true,
|
||||
agent_limit_reached: true,
|
||||
})
|
||||
}
|
||||
Err(err) => {
|
||||
let error_message = format!("failed to spawn worker: {err}");
|
||||
let _ = db_ops::retry_locked(
|
||||
"mark_agent_job_item_failed_after_spawn_error",
|
||||
|| async {
|
||||
db.mark_agent_job_item_failed(
|
||||
job_id,
|
||||
startup.item_id.as_str(),
|
||||
error_message.as_str(),
|
||||
)
|
||||
.await
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
tracing::warn!(
|
||||
job_id,
|
||||
item_id = startup.item_id,
|
||||
error = %err,
|
||||
"agent job worker startup failed"
|
||||
);
|
||||
Ok(StartupDrainResult {
|
||||
progressed: true,
|
||||
agent_limit_reached: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(join_error) => {
|
||||
let task_id = join_error.id();
|
||||
let Some(item) = startup_tasks.launching_items.remove(&task_id) else {
|
||||
return Ok(StartupDrainResult::default());
|
||||
};
|
||||
let error_message = format!("worker startup task failed: {join_error}");
|
||||
let _ = db_ops::retry_locked(
|
||||
"mark_agent_job_item_failed_after_startup_join_error",
|
||||
|| async {
|
||||
db.mark_agent_job_item_failed(
|
||||
job_id,
|
||||
item.item_id.as_str(),
|
||||
error_message.as_str(),
|
||||
)
|
||||
.await
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
tracing::warn!(
|
||||
job_id,
|
||||
item_id = item.item_id,
|
||||
error = %join_error,
|
||||
"agent job worker startup task exited unexpectedly"
|
||||
);
|
||||
Ok(StartupDrainResult {
|
||||
progressed: true,
|
||||
agent_limit_reached: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::ConfigBuilder;
|
||||
use crate::thread_manager::ThreadManager;
|
||||
use codex_exec_server::EnvironmentManager;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_state::AgentJobCreateParams;
|
||||
use codex_state::AgentJobItemCreateParams;
|
||||
use codex_state::AgentJobItemStatus;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tempfile::TempDir;
|
||||
use tokio::sync::Barrier;
|
||||
use tokio::time::timeout;
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_tracked_startup_task_starts_multiple_workers_without_serial_waiting() {
|
||||
let mut startup_tasks = StartupTasks::default();
|
||||
let started = Arc::new(AtomicUsize::new(0));
|
||||
let barrier = Arc::new(Barrier::new(4));
|
||||
|
||||
for idx in 0..3usize {
|
||||
let started = Arc::clone(&started);
|
||||
let barrier = Arc::clone(&barrier);
|
||||
spawn_tracked_startup_task(
|
||||
&mut startup_tasks,
|
||||
format!("item-{idx}"),
|
||||
Instant::now(),
|
||||
async move {
|
||||
started.fetch_add(1, Ordering::SeqCst);
|
||||
barrier.wait().await;
|
||||
WorkerStartup {
|
||||
item_id: format!("item-{idx}"),
|
||||
started_at: Instant::now(),
|
||||
spawn_latency: Duration::ZERO,
|
||||
result: Err(CodexErr::ThreadNotFound(ThreadId::new())),
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
while started.load(Ordering::SeqCst) < 3 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all startup tasks should begin running");
|
||||
|
||||
assert_eq!(startup_tasks.len(), 3);
|
||||
assert_eq!(startup_tasks.launching_items.len(), 3);
|
||||
|
||||
barrier.wait().await;
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
while let Some(result) = startup_tasks.starting_items.join_next().await {
|
||||
outputs.push(result.expect("startup task should complete").item_id);
|
||||
}
|
||||
outputs.sort();
|
||||
assert_eq!(outputs, vec!["item-0", "item-1", "item-2"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wait_for_startup_or_status_change_returns_when_only_startups_are_pending()
|
||||
-> anyhow::Result<()> {
|
||||
let home = TempDir::new()?;
|
||||
let config = ConfigBuilder::without_managed_config_for_tests()
|
||||
.codex_home(home.path().to_path_buf())
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
let manager = ThreadManager::with_models_provider_and_home_for_tests(
|
||||
CodexAuth::from_api_key("dummy"),
|
||||
config.model_provider.clone(),
|
||||
config.codex_home.clone().to_path_buf(),
|
||||
Arc::new(EnvironmentManager::default_for_tests()),
|
||||
);
|
||||
let root = manager.start_thread(config.clone()).await?;
|
||||
let session = root.thread.codex.session.clone();
|
||||
let db = codex_state::StateRuntime::init(
|
||||
config.codex_home.clone().to_path_buf(),
|
||||
"test-provider".to_string(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut startup_tasks = StartupTasks::default();
|
||||
spawn_tracked_startup_task(
|
||||
&mut startup_tasks,
|
||||
"item-1".to_string(),
|
||||
Instant::now(),
|
||||
std::future::pending(),
|
||||
);
|
||||
|
||||
let mut active_items = HashMap::new();
|
||||
timeout(
|
||||
Duration::from_secs(1),
|
||||
wait_for_startup_or_status_change(
|
||||
session,
|
||||
db,
|
||||
"job-1",
|
||||
&mut active_items,
|
||||
&mut startup_tasks,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("wait should return so stale startup reaping can run")?;
|
||||
|
||||
assert!(active_items.is_empty());
|
||||
assert_eq!(startup_tasks.len(), 1);
|
||||
assert_eq!(startup_tasks.launching_items.len(), 1);
|
||||
|
||||
let aborted = abort_all_startups(&mut startup_tasks).await;
|
||||
assert_eq!(aborted, 1);
|
||||
let report = manager
|
||||
.shutdown_all_threads_bounded(Duration::from_secs(10))
|
||||
.await;
|
||||
assert_eq!(report.submit_failed, Vec::new());
|
||||
assert_eq!(report.timed_out, Vec::new());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn drain_ready_startups_reports_agent_limit_and_requeues_item() -> anyhow::Result<()> {
|
||||
let home = TempDir::new()?;
|
||||
let config = ConfigBuilder::without_managed_config_for_tests()
|
||||
.codex_home(home.path().to_path_buf())
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
let manager = ThreadManager::with_models_provider_and_home_for_tests(
|
||||
CodexAuth::from_api_key("dummy"),
|
||||
config.model_provider.clone(),
|
||||
config.codex_home.clone().to_path_buf(),
|
||||
Arc::new(EnvironmentManager::default_for_tests()),
|
||||
);
|
||||
let root = manager.start_thread(config.clone()).await?;
|
||||
let session = root.thread.codex.session.clone();
|
||||
|
||||
let db = codex_state::StateRuntime::init(
|
||||
config.codex_home.clone().to_path_buf(),
|
||||
"test-provider".to_string(),
|
||||
)
|
||||
.await?;
|
||||
let job_id = "job-1".to_string();
|
||||
let item_id = "item-1".to_string();
|
||||
db.create_agent_job(
|
||||
&AgentJobCreateParams {
|
||||
id: job_id.clone(),
|
||||
name: "test-job".to_string(),
|
||||
instruction: "Return a result".to_string(),
|
||||
auto_export: true,
|
||||
max_runtime_seconds: None,
|
||||
output_schema_json: None,
|
||||
input_headers: vec!["path".to_string()],
|
||||
input_csv_path: "/tmp/in.csv".to_string(),
|
||||
output_csv_path: "/tmp/out.csv".to_string(),
|
||||
},
|
||||
&[AgentJobItemCreateParams {
|
||||
item_id: item_id.clone(),
|
||||
row_index: 0,
|
||||
source_id: None,
|
||||
row_json: json!({"path":"file-1"}),
|
||||
}],
|
||||
)
|
||||
.await?;
|
||||
db.mark_agent_job_running(job_id.as_str()).await?;
|
||||
|
||||
let mut startup_tasks = StartupTasks::default();
|
||||
let task_item_id = item_id.clone();
|
||||
spawn_tracked_startup_task(
|
||||
&mut startup_tasks,
|
||||
item_id.clone(),
|
||||
Instant::now(),
|
||||
async move {
|
||||
WorkerStartup {
|
||||
item_id: task_item_id,
|
||||
started_at: Instant::now(),
|
||||
spawn_latency: Duration::ZERO,
|
||||
result: Err(CodexErr::AgentLimitReached { max_threads: 1 }),
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let mut active_items = HashMap::new();
|
||||
let outcome = timeout(Duration::from_secs(1), async {
|
||||
loop {
|
||||
let outcome = drain_ready_startups(
|
||||
session.clone(),
|
||||
db.clone(),
|
||||
job_id.as_str(),
|
||||
&mut active_items,
|
||||
&mut startup_tasks,
|
||||
)
|
||||
.await?;
|
||||
if outcome.progressed {
|
||||
break Ok::<StartupDrainResult, anyhow::Error>(outcome);
|
||||
}
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
|
||||
assert_eq!(
|
||||
outcome,
|
||||
StartupDrainResult {
|
||||
progressed: true,
|
||||
agent_limit_reached: true,
|
||||
}
|
||||
);
|
||||
assert!(active_items.is_empty());
|
||||
assert!(startup_tasks.is_empty());
|
||||
|
||||
let item = db
|
||||
.get_agent_job_item(job_id.as_str(), item_id.as_str())
|
||||
.await?
|
||||
.expect("job item should exist");
|
||||
assert_eq!(item.status, AgentJobItemStatus::Pending);
|
||||
assert_eq!(item.assigned_thread_id, None);
|
||||
|
||||
let report = manager
|
||||
.shutdown_all_threads_bounded(Duration::from_secs(10))
|
||||
.await;
|
||||
assert_eq!(report.submit_failed, Vec::new());
|
||||
assert_eq!(report.timed_out, Vec::new());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -60,3 +60,31 @@ fn ensure_unique_headers_rejects_duplicates() {
|
||||
FunctionCallError::RespondToModel("csv header path is duplicated".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_wait_when_agent_limit_blocks_refill() {
|
||||
assert_eq!(
|
||||
should_wait_for_scheduler_change(
|
||||
/*progressed*/ false, /*agent_limit_reached*/ false,
|
||||
),
|
||||
true
|
||||
);
|
||||
assert_eq!(
|
||||
should_wait_for_scheduler_change(
|
||||
/*progressed*/ true, /*agent_limit_reached*/ false,
|
||||
),
|
||||
false
|
||||
);
|
||||
assert_eq!(
|
||||
should_wait_for_scheduler_change(
|
||||
/*progressed*/ false, /*agent_limit_reached*/ true,
|
||||
),
|
||||
true
|
||||
);
|
||||
assert_eq!(
|
||||
should_wait_for_scheduler_change(
|
||||
/*progressed*/ true, /*agent_limit_reached*/ true,
|
||||
),
|
||||
true
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::task::JoinError;
|
||||
use tokio_util::either::Either;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
@@ -95,6 +96,7 @@ impl ToolCallRuntime {
|
||||
let invocation_cancellation_token = cancellation_token.clone();
|
||||
let started = Instant::now();
|
||||
let display_name = call.tool_name.display();
|
||||
let join_error_call = call.clone();
|
||||
|
||||
let dispatch_span = trace_span!(
|
||||
"dispatch_tool_call_with_code_mode_result",
|
||||
@@ -136,13 +138,38 @@ impl ToolCallRuntime {
|
||||
|
||||
async move {
|
||||
handle.await.map_err(|err| {
|
||||
FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}"))
|
||||
tool_task_join_error_to_function_call_error(
|
||||
err,
|
||||
&join_error_call,
|
||||
started.elapsed(),
|
||||
)
|
||||
})?
|
||||
}
|
||||
.in_current_span()
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_task_join_error_to_function_call_error(
|
||||
err: JoinError,
|
||||
call: &ToolCall,
|
||||
elapsed: std::time::Duration,
|
||||
) -> FunctionCallError {
|
||||
if err.is_cancelled() {
|
||||
let secs = elapsed.as_secs_f32().max(0.1);
|
||||
tracing::warn!(
|
||||
tool_name = %call.tool_name,
|
||||
call_id = call.call_id,
|
||||
elapsed_seconds = secs,
|
||||
"tool task was cancelled before delivering a response"
|
||||
);
|
||||
return FunctionCallError::RespondToModel(format!(
|
||||
"tool execution interrupted after {secs:.1}s"
|
||||
));
|
||||
}
|
||||
|
||||
FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}"))
|
||||
}
|
||||
|
||||
impl ToolCallRuntime {
|
||||
fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem {
|
||||
let message = err.to_string();
|
||||
@@ -194,3 +221,60 @@ impl ToolCallRuntime {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
||||
fn test_tool_call() -> ToolCall {
|
||||
ToolCall {
|
||||
tool_name: "spawn_agents_on_csv".into(),
|
||||
call_id: "call-1".to_string(),
|
||||
payload: ToolPayload::Function {
|
||||
arguments: json!({"csv_path":"in.csv","instruction":"test"}).to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancelled_tool_task_join_error_becomes_model_response() {
|
||||
let handle = tokio::spawn(async {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(60)).await;
|
||||
});
|
||||
handle.abort();
|
||||
let err = handle.await.expect_err("task should be cancelled");
|
||||
|
||||
let mapped = tool_task_join_error_to_function_call_error(
|
||||
err,
|
||||
&test_tool_call(),
|
||||
std::time::Duration::from_millis(200),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
mapped,
|
||||
FunctionCallError::RespondToModel("tool execution interrupted after 0.2s".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn panicked_tool_task_join_error_stays_fatal() {
|
||||
let handle = tokio::spawn(async {
|
||||
panic!("boom");
|
||||
});
|
||||
let err = handle.await.expect_err("task should panic");
|
||||
|
||||
let mapped = tool_task_join_error_to_function_call_error(
|
||||
err,
|
||||
&test_tool_call(),
|
||||
std::time::Duration::from_millis(200),
|
||||
);
|
||||
|
||||
let FunctionCallError::Fatal(message) = mapped else {
|
||||
panic!("panic join errors should remain fatal");
|
||||
};
|
||||
assert!(message.contains("tool task failed to receive"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use anyhow::Result;
|
||||
use codex_features::Feature;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
@@ -15,6 +16,9 @@ use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
use wiremock::Mock;
|
||||
use wiremock::Respond;
|
||||
use wiremock::ResponseTemplate;
|
||||
@@ -53,6 +57,155 @@ impl StopAfterFirstResponder {
|
||||
}
|
||||
}
|
||||
|
||||
struct DelayedWorkerAfterReportResponder {
|
||||
spawn_args_json: String,
|
||||
seen_main: AtomicBool,
|
||||
worker_calls: Arc<AtomicUsize>,
|
||||
delayed_worker_output: AtomicBool,
|
||||
worker_output_delay: Duration,
|
||||
}
|
||||
|
||||
struct WorkerNeverCompletesAfterReportResponder {
|
||||
spawn_args_json: String,
|
||||
seen_main: AtomicBool,
|
||||
worker_calls: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl DelayedWorkerAfterReportResponder {
|
||||
fn new(
|
||||
spawn_args_json: String,
|
||||
worker_calls: Arc<AtomicUsize>,
|
||||
worker_output_delay: Duration,
|
||||
) -> Self {
|
||||
Self {
|
||||
spawn_args_json,
|
||||
seen_main: AtomicBool::new(false),
|
||||
worker_calls,
|
||||
delayed_worker_output: AtomicBool::new(false),
|
||||
worker_output_delay,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WorkerNeverCompletesAfterReportResponder {
|
||||
fn new(spawn_args_json: String, worker_calls: Arc<AtomicUsize>) -> Self {
|
||||
Self {
|
||||
spawn_args_json,
|
||||
seen_main: AtomicBool::new(false),
|
||||
worker_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Respond for WorkerNeverCompletesAfterReportResponder {
|
||||
fn respond(&self, request: &wiremock::Request) -> ResponseTemplate {
|
||||
let body_bytes = decode_body_bytes(request);
|
||||
let body: Value = serde_json::from_slice(&body_bytes).unwrap_or(Value::Null);
|
||||
|
||||
let call_output_ids = function_call_output_call_ids(&body);
|
||||
if !call_output_ids.is_empty() {
|
||||
if call_output_ids
|
||||
.iter()
|
||||
.any(|call_id| call_id.starts_with("call-worker-"))
|
||||
{
|
||||
return sse_response(sse(vec![
|
||||
ev_response_created("resp-worker-post-report"),
|
||||
ev_assistant_message("msg-worker-post-report", "no bugs found"),
|
||||
]));
|
||||
}
|
||||
return sse_response(sse(vec![
|
||||
ev_response_created("resp-tool"),
|
||||
ev_completed("resp-tool"),
|
||||
]));
|
||||
}
|
||||
|
||||
if let Some((job_id, item_id)) = extract_job_and_item(&body) {
|
||||
let call_index = self.worker_calls.fetch_add(1, Ordering::SeqCst);
|
||||
let call_id = format!("call-worker-{call_index}");
|
||||
let args = json!({
|
||||
"job_id": job_id,
|
||||
"item_id": item_id,
|
||||
"result": { "item_id": item_id }
|
||||
});
|
||||
let args_json = serde_json::to_string(&args).unwrap_or_else(|err| {
|
||||
panic!("worker args serialize: {err}");
|
||||
});
|
||||
return sse_response(sse(vec![
|
||||
ev_response_created("resp-worker"),
|
||||
ev_function_call(&call_id, "report_agent_job_result", &args_json),
|
||||
ev_completed("resp-worker"),
|
||||
]));
|
||||
}
|
||||
|
||||
if !self.seen_main.swap(true, Ordering::SeqCst) {
|
||||
return sse_response(sse(vec![
|
||||
ev_response_created("resp-main"),
|
||||
ev_function_call("call-spawn", "spawn_agents_on_csv", &self.spawn_args_json),
|
||||
ev_completed("resp-main"),
|
||||
]));
|
||||
}
|
||||
|
||||
sse_response(sse(vec![
|
||||
ev_response_created("resp-default"),
|
||||
ev_completed("resp-default"),
|
||||
]))
|
||||
}
|
||||
}
|
||||
|
||||
impl Respond for DelayedWorkerAfterReportResponder {
|
||||
fn respond(&self, request: &wiremock::Request) -> ResponseTemplate {
|
||||
let body_bytes = decode_body_bytes(request);
|
||||
let body: Value = serde_json::from_slice(&body_bytes).unwrap_or(Value::Null);
|
||||
|
||||
let call_output_ids = function_call_output_call_ids(&body);
|
||||
if !call_output_ids.is_empty() {
|
||||
let response = sse_response(sse(vec![
|
||||
ev_response_created("resp-tool"),
|
||||
ev_completed("resp-tool"),
|
||||
]));
|
||||
if call_output_ids
|
||||
.iter()
|
||||
.any(|call_id| call_id == "call-worker-0")
|
||||
&& !self.delayed_worker_output.swap(true, Ordering::SeqCst)
|
||||
{
|
||||
return response.set_delay(self.worker_output_delay);
|
||||
}
|
||||
return response;
|
||||
}
|
||||
|
||||
if let Some((job_id, item_id)) = extract_job_and_item(&body) {
|
||||
let call_index = self.worker_calls.fetch_add(1, Ordering::SeqCst);
|
||||
let call_id = format!("call-worker-{call_index}");
|
||||
let args = json!({
|
||||
"job_id": job_id,
|
||||
"item_id": item_id,
|
||||
"result": { "item_id": item_id }
|
||||
});
|
||||
let args_json = serde_json::to_string(&args).unwrap_or_else(|err| {
|
||||
panic!("worker args serialize: {err}");
|
||||
});
|
||||
return sse_response(sse(vec![
|
||||
ev_response_created("resp-worker"),
|
||||
ev_function_call(&call_id, "report_agent_job_result", &args_json),
|
||||
ev_completed("resp-worker"),
|
||||
]));
|
||||
}
|
||||
|
||||
if !self.seen_main.swap(true, Ordering::SeqCst) {
|
||||
return sse_response(sse(vec![
|
||||
ev_response_created("resp-main"),
|
||||
ev_function_call("call-spawn", "spawn_agents_on_csv", &self.spawn_args_json),
|
||||
ev_completed("resp-main"),
|
||||
]));
|
||||
}
|
||||
|
||||
sse_response(sse(vec![
|
||||
ev_response_created("resp-default"),
|
||||
ev_completed("resp-default"),
|
||||
]))
|
||||
}
|
||||
}
|
||||
|
||||
impl Respond for StopAfterFirstResponder {
|
||||
fn respond(&self, request: &wiremock::Request) -> ResponseTemplate {
|
||||
let body_bytes = decode_body_bytes(request);
|
||||
@@ -176,6 +329,17 @@ fn has_function_call_output(body: &Value) -> bool {
|
||||
})
|
||||
}
|
||||
|
||||
fn function_call_output_call_ids(body: &Value) -> Vec<String> {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.filter(|item| item.get("type").and_then(Value::as_str) == Some("function_call_output"))
|
||||
.filter_map(|item| item.get("call_id").and_then(Value::as_str))
|
||||
.map(str::to_string)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extract_job_and_item(body: &Value) -> Option<(String, String)> {
|
||||
let texts = message_input_texts(body);
|
||||
let mut combined = texts.join("\n");
|
||||
@@ -446,3 +610,205 @@ async fn spawn_agents_on_csv_stop_halts_future_items() -> Result<()> {
|
||||
assert_eq!(worker_calls.load(Ordering::SeqCst), 1);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn spawn_agents_on_csv_reclaims_slot_after_report_before_worker_completes() -> Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config
|
||||
.features
|
||||
.enable(Feature::SpawnCsv)
|
||||
.expect("test config should allow feature update");
|
||||
config
|
||||
.features
|
||||
.enable(Feature::Sqlite)
|
||||
.expect("test config should allow feature update");
|
||||
});
|
||||
let test = Arc::new(builder.build(&server).await?);
|
||||
|
||||
let input_path = test.cwd_path().join("agent_jobs_reclaim.csv");
|
||||
let output_path = test.cwd_path().join("agent_jobs_reclaim_out.csv");
|
||||
fs::write(&input_path, "path\nfile-1\nfile-2\n")?;
|
||||
|
||||
let args = json!({
|
||||
"csv_path": input_path.display().to_string(),
|
||||
"instruction": "Return {path}",
|
||||
"output_csv_path": output_path.display().to_string(),
|
||||
"max_concurrency": 1,
|
||||
});
|
||||
let args_json = serde_json::to_string(&args)?;
|
||||
|
||||
let worker_calls = Arc::new(AtomicUsize::new(0));
|
||||
let responder = DelayedWorkerAfterReportResponder::new(
|
||||
args_json,
|
||||
worker_calls.clone(),
|
||||
Duration::from_secs(2),
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path_regex(".*/responses$"))
|
||||
.respond_with(responder)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let submit = tokio::spawn({
|
||||
let test = Arc::clone(&test);
|
||||
async move { test.submit_turn("run batch job").await }
|
||||
});
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
while worker_calls.load(Ordering::SeqCst) < 2 {
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("second worker should start before first worker's delayed post-report response");
|
||||
|
||||
submit.await??;
|
||||
|
||||
let output = fs::read_to_string(&output_path)?;
|
||||
assert_eq!(output.lines().skip(1).count(), 2);
|
||||
|
||||
let requests = worker_calls.load(Ordering::SeqCst);
|
||||
assert_eq!(requests, 2);
|
||||
|
||||
let root_requests = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("wiremock should capture requests");
|
||||
let saw_spawn_output = root_requests.iter().any(|request| {
|
||||
let body_bytes = decode_body_bytes(request);
|
||||
let body: Value = serde_json::from_slice(&body_bytes).unwrap_or(Value::Null);
|
||||
function_call_output_call_ids(&body)
|
||||
.iter()
|
||||
.any(|call_id| call_id == "call-spawn")
|
||||
});
|
||||
assert!(saw_spawn_output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn spawn_agents_on_csv_finishes_after_rows_complete_even_if_worker_exit_lags() -> Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config
|
||||
.features
|
||||
.enable(Feature::SpawnCsv)
|
||||
.expect("test config should allow feature update");
|
||||
config
|
||||
.features
|
||||
.enable(Feature::Sqlite)
|
||||
.expect("test config should allow feature update");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let input_path = test.cwd_path().join("agent_jobs_terminal.csv");
|
||||
let output_path = test.cwd_path().join("agent_jobs_terminal_out.csv");
|
||||
fs::write(&input_path, "path\nfile-1\n")?;
|
||||
|
||||
let args = json!({
|
||||
"csv_path": input_path.display().to_string(),
|
||||
"instruction": "Return {path}",
|
||||
"output_csv_path": output_path.display().to_string(),
|
||||
"max_concurrency": 1,
|
||||
});
|
||||
let args_json = serde_json::to_string(&args)?;
|
||||
|
||||
let worker_calls = Arc::new(AtomicUsize::new(0));
|
||||
let responder = DelayedWorkerAfterReportResponder::new(
|
||||
args_json,
|
||||
worker_calls.clone(),
|
||||
Duration::from_secs(30),
|
||||
);
|
||||
Mock::given(method("POST"))
|
||||
.and(path_regex(".*/responses$"))
|
||||
.respond_with(responder)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
timeout(Duration::from_secs(5), test.submit_turn("run batch job"))
|
||||
.await
|
||||
.expect("root turn should finalize without waiting for the delayed worker exit")?;
|
||||
|
||||
let output = fs::read_to_string(&output_path)?;
|
||||
let rows: Vec<&str> = output.lines().skip(1).collect();
|
||||
assert_eq!(rows.len(), 1);
|
||||
|
||||
let job_id = rows
|
||||
.first()
|
||||
.and_then(|line| {
|
||||
parse_simple_csv_line(line)
|
||||
.iter()
|
||||
.find(|value| value.len() == 36)
|
||||
.cloned()
|
||||
})
|
||||
.expect("job_id from csv");
|
||||
let db = test.codex.state_db().expect("state db");
|
||||
let job = db.get_agent_job(job_id.as_str()).await?.expect("job");
|
||||
assert_eq!(job.status, codex_state::AgentJobStatus::Completed);
|
||||
assert!(job.completed_at.is_some());
|
||||
assert_eq!(worker_calls.load(Ordering::SeqCst), 1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn spawn_agents_on_csv_finishes_when_worker_reports_but_never_completes_turn() -> Result<()> {
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config
|
||||
.features
|
||||
.enable(Feature::SpawnCsv)
|
||||
.expect("test config should allow feature update");
|
||||
config
|
||||
.features
|
||||
.enable(Feature::Sqlite)
|
||||
.expect("test config should allow feature update");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let input_path = test.cwd_path().join("agent_jobs_incomplete_turn.csv");
|
||||
let output_path = test.cwd_path().join("agent_jobs_incomplete_turn_out.csv");
|
||||
fs::write(&input_path, "path\nfile-1\n")?;
|
||||
|
||||
let args = json!({
|
||||
"csv_path": input_path.display().to_string(),
|
||||
"instruction": "Return {path}",
|
||||
"output_csv_path": output_path.display().to_string(),
|
||||
"max_concurrency": 1,
|
||||
});
|
||||
let args_json = serde_json::to_string(&args)?;
|
||||
|
||||
let worker_calls = Arc::new(AtomicUsize::new(0));
|
||||
let responder = WorkerNeverCompletesAfterReportResponder::new(args_json, worker_calls.clone());
|
||||
Mock::given(method("POST"))
|
||||
.and(path_regex(".*/responses$"))
|
||||
.respond_with(responder)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
timeout(Duration::from_secs(5), test.submit_turn("run batch job"))
|
||||
.await
|
||||
.expect("root turn should finalize even if a worker never reaches task_complete")?;
|
||||
|
||||
let output = fs::read_to_string(&output_path)?;
|
||||
let rows: Vec<&str> = output.lines().skip(1).collect();
|
||||
assert_eq!(rows.len(), 1);
|
||||
|
||||
let job_id = rows
|
||||
.first()
|
||||
.and_then(|line| {
|
||||
parse_simple_csv_line(line)
|
||||
.iter()
|
||||
.find(|value| value.len() == 36)
|
||||
.cloned()
|
||||
})
|
||||
.expect("job_id from csv");
|
||||
let db = test.codex.state_db().expect("state db");
|
||||
let job = db.get_agent_job(job_id.as_str()).await?.expect("job");
|
||||
assert_eq!(job.status, codex_state::AgentJobStatus::Completed);
|
||||
assert!(job.completed_at.is_some());
|
||||
assert_eq!(worker_calls.load(Ordering::SeqCst), 1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -129,12 +129,15 @@ pub use exec_events::Usage;
|
||||
pub use exec_events::WebSearchItem;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::io::IsTerminal;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
use supports_color::Stream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::timeout;
|
||||
use tracing::Instrument;
|
||||
use tracing::error;
|
||||
use tracing::field;
|
||||
@@ -149,6 +152,8 @@ use crate::cli::Command as ExecCommand;
|
||||
use crate::event_processor::EventProcessor;
|
||||
|
||||
const DEFAULT_ANALYTICS_ENABLED: bool = true;
|
||||
const THREAD_UNSUBSCRIBE_TIMEOUT: Duration = Duration::from_secs(1);
|
||||
const TURN_COMPLETED_BACKFILL_TIMEOUT: Duration = Duration::from_secs(2);
|
||||
|
||||
enum InitialOperation {
|
||||
UserTurn {
|
||||
@@ -1128,9 +1133,9 @@ async fn maybe_backfill_turn_completed_items(
|
||||
notification: &mut ServerNotification,
|
||||
) {
|
||||
// In-process delivery may drop non-terminal item notifications under backpressure while still
|
||||
// guaranteeing `turn/completed`. Because app-server currently emits that completion with an
|
||||
// empty `turn.items`, exec does one last `thread/read` here so human/json output can recover
|
||||
// the final message and reconcile any still-running items before shutdown.
|
||||
// guaranteeing `turn/completed`. App-server now tries to include terminal turn items directly
|
||||
// on that notification, but exec keeps this bounded `thread/read` fallback for older or
|
||||
// degraded paths that still arrive with an empty `turn.items`.
|
||||
if !should_backfill_turn_completed_items(thread_ephemeral, notification) {
|
||||
return;
|
||||
}
|
||||
@@ -1139,16 +1144,23 @@ async fn maybe_backfill_turn_completed_items(
|
||||
return;
|
||||
};
|
||||
|
||||
let response = send_request_with_response::<ThreadReadResponse>(
|
||||
client,
|
||||
ClientRequest::ThreadRead {
|
||||
request_id: request_ids.next(),
|
||||
params: ThreadReadParams {
|
||||
thread_id: payload.thread_id.clone(),
|
||||
include_turns: true,
|
||||
},
|
||||
},
|
||||
// This runs inline on exec's event loop immediately after `TurnCompleted`.
|
||||
// Bound the request so a backpressured in-process event queue cannot deadlock
|
||||
// shutdown by blocking `thread/read` forever behind unrelated lossless events.
|
||||
let response = await_request_with_timeout(
|
||||
"thread/read",
|
||||
TURN_COMPLETED_BACKFILL_TIMEOUT,
|
||||
send_request_with_response::<ThreadReadResponse>(
|
||||
client,
|
||||
ClientRequest::ThreadRead {
|
||||
request_id: request_ids.next(),
|
||||
params: ThreadReadParams {
|
||||
thread_id: payload.thread_id.clone(),
|
||||
include_turns: true,
|
||||
},
|
||||
},
|
||||
"thread/read",
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1159,7 +1171,11 @@ async fn maybe_backfill_turn_completed_items(
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("thread/read failed while backfilling turn items for turn completion: {err}");
|
||||
warn!(
|
||||
thread_id = %payload.thread_id,
|
||||
turn_id = %payload.turn.id,
|
||||
"thread/read failed while backfilling turn items for turn completion: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1375,9 +1391,30 @@ async fn request_shutdown(
|
||||
thread_id: thread_id.to_string(),
|
||||
},
|
||||
};
|
||||
send_request_with_response::<ThreadUnsubscribeResponse>(client, request, "thread/unsubscribe")
|
||||
.await
|
||||
.map(|_| ())
|
||||
await_request_with_timeout(
|
||||
"thread/unsubscribe",
|
||||
THREAD_UNSUBSCRIBE_TIMEOUT,
|
||||
send_request_with_response::<ThreadUnsubscribeResponse>(
|
||||
client,
|
||||
request,
|
||||
"thread/unsubscribe",
|
||||
),
|
||||
)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
async fn await_request_with_timeout<T>(
|
||||
request_name: &str,
|
||||
timeout_duration: Duration,
|
||||
request: impl Future<Output = Result<T, String>>,
|
||||
) -> Result<T, String> {
|
||||
match timeout(timeout_duration, request).await {
|
||||
Ok(result) => result,
|
||||
Err(_) => Err(format!(
|
||||
"{request_name} timed out after {timeout_duration:?}"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn resolve_server_request(
|
||||
|
||||
@@ -9,6 +9,7 @@ use opentelemetry::trace::TracerProvider as _;
|
||||
use opentelemetry_sdk::trace::SdkTracerProvider;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
use tokio::time::sleep;
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
fn test_tracing_subscriber() -> impl tracing::Subscriber + Send + Sync {
|
||||
@@ -323,6 +324,33 @@ fn should_backfill_turn_completed_items_skips_ephemeral_threads() {
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_backfill_turn_completed_items_skips_notifications_with_items() {
|
||||
let notification =
|
||||
ServerNotification::TurnCompleted(codex_app_server_protocol::TurnCompletedNotification {
|
||||
thread_id: "thread-1".to_string(),
|
||||
turn: codex_app_server_protocol::Turn {
|
||||
id: "turn-1".to_string(),
|
||||
items: vec![AppServerThreadItem::AgentMessage {
|
||||
id: "msg-1".to_string(),
|
||||
text: "finished".to_string(),
|
||||
phase: None,
|
||||
memory_citation: None,
|
||||
}],
|
||||
status: codex_app_server_protocol::TurnStatus::Completed,
|
||||
error: None,
|
||||
started_at: None,
|
||||
completed_at: None,
|
||||
duration_ms: None,
|
||||
},
|
||||
});
|
||||
|
||||
assert!(!should_backfill_turn_completed_items(
|
||||
/*thread_ephemeral*/ false,
|
||||
¬ification
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canceled_mcp_server_elicitation_response_uses_cancel_action() {
|
||||
let value = canceled_mcp_server_elicitation_response()
|
||||
@@ -440,3 +468,45 @@ fn session_configured_from_thread_response_uses_review_policy_from_response() {
|
||||
ApprovalsReviewer::GuardianSubagent
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn await_request_with_timeout_returns_ready_result() {
|
||||
let result = await_request_with_timeout(
|
||||
"test/request",
|
||||
std::time::Duration::from_millis(50),
|
||||
async { Ok::<_, String>(123usize) },
|
||||
)
|
||||
.await
|
||||
.expect("ready request should succeed");
|
||||
|
||||
assert_eq!(result, 123);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn await_request_with_timeout_errors_when_request_stalls() {
|
||||
let err = await_request_with_timeout(
|
||||
"test/request",
|
||||
std::time::Duration::from_millis(10),
|
||||
async {
|
||||
sleep(std::time::Duration::from_millis(50)).await;
|
||||
Ok::<_, String>(())
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect_err("stalled request should time out");
|
||||
|
||||
assert_eq!(err, "test/request timed out after 10ms");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_completed_backfill_times_out_instead_of_hanging_exec() {
|
||||
let err =
|
||||
await_request_with_timeout("thread/read", std::time::Duration::from_millis(10), async {
|
||||
sleep(std::time::Duration::from_millis(50)).await;
|
||||
Ok::<_, String>(())
|
||||
})
|
||||
.await
|
||||
.expect_err("stalled turn/read backfill should time out");
|
||||
|
||||
assert_eq!(err, "thread/read timed out after 10ms");
|
||||
}
|
||||
|
||||
@@ -14,20 +14,71 @@ pub use codex_state::LogEntry;
|
||||
use codex_state::ThreadMetadataBuilder;
|
||||
use codex_utils_path::normalize_for_path_comparison;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::Mutex;
|
||||
use tokio::sync::OnceCell;
|
||||
use tracing::warn;
|
||||
|
||||
/// Core-facing handle to the SQLite-backed state runtime.
|
||||
pub type StateDbHandle = Arc<codex_state::StateRuntime>;
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||
struct StateDbCacheKey {
|
||||
sqlite_home: PathBuf,
|
||||
default_provider: String,
|
||||
}
|
||||
|
||||
type StateDbRuntimeCell = Arc<OnceCell<StateDbHandle>>;
|
||||
|
||||
static STATE_DB_RUNTIME_CACHE: LazyLock<Mutex<HashMap<StateDbCacheKey, StateDbRuntimeCell>>> =
|
||||
LazyLock::new(|| Mutex::new(HashMap::new()));
|
||||
|
||||
fn cache_key(sqlite_home: &Path, default_provider: &str) -> StateDbCacheKey {
|
||||
StateDbCacheKey {
|
||||
sqlite_home: normalize_for_path_comparison(sqlite_home)
|
||||
.unwrap_or_else(|_| sqlite_home.to_path_buf()),
|
||||
default_provider: default_provider.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn cached_runtime_cell(sqlite_home: &Path, default_provider: &str) -> StateDbRuntimeCell {
|
||||
let key = cache_key(sqlite_home, default_provider);
|
||||
let mut cache = STATE_DB_RUNTIME_CACHE
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
if let Some(cell) = cache.get(&key) {
|
||||
return Arc::clone(cell);
|
||||
}
|
||||
|
||||
let cell = Arc::new(OnceCell::new());
|
||||
cache.insert(key, Arc::clone(&cell));
|
||||
cell
|
||||
}
|
||||
|
||||
async fn shared_state_db_runtime(
|
||||
sqlite_home: &Path,
|
||||
default_provider: &str,
|
||||
) -> anyhow::Result<StateDbHandle> {
|
||||
let cell = cached_runtime_cell(sqlite_home, default_provider);
|
||||
let runtime = cell
|
||||
.get_or_try_init(|| async {
|
||||
codex_state::StateRuntime::init(sqlite_home.to_path_buf(), default_provider.to_string())
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
Ok(Arc::clone(runtime))
|
||||
}
|
||||
|
||||
/// Initialize the state runtime for thread state persistence and backfill checks.
|
||||
pub async fn init(config: &impl RolloutConfigView) -> Option<StateDbHandle> {
|
||||
let config = RolloutConfig::from_view(config);
|
||||
let runtime = match codex_state::StateRuntime::init(
|
||||
config.sqlite_home.clone(),
|
||||
config.model_provider_id.clone(),
|
||||
let runtime = match shared_state_db_runtime(
|
||||
config.sqlite_home.as_path(),
|
||||
config.model_provider_id.as_str(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -66,12 +117,9 @@ pub async fn get_state_db(config: &impl RolloutConfigView) -> Option<StateDbHand
|
||||
if !tokio::fs::try_exists(&state_path).await.unwrap_or(false) {
|
||||
return None;
|
||||
}
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
config.sqlite_home().to_path_buf(),
|
||||
config.model_provider_id().to_string(),
|
||||
)
|
||||
.await
|
||||
.ok()?;
|
||||
let runtime = shared_state_db_runtime(config.sqlite_home(), config.model_provider_id())
|
||||
.await
|
||||
.ok()?;
|
||||
require_backfill_complete(runtime, config.sqlite_home()).await
|
||||
}
|
||||
|
||||
@@ -83,10 +131,9 @@ pub async fn open_if_present(codex_home: &Path, default_provider: &str) -> Optio
|
||||
if !tokio::fs::try_exists(&db_path).await.unwrap_or(false) {
|
||||
return None;
|
||||
}
|
||||
let runtime =
|
||||
codex_state::StateRuntime::init(codex_home.to_path_buf(), default_provider.to_string())
|
||||
.await
|
||||
.ok()?;
|
||||
let runtime = shared_state_db_runtime(codex_home, default_provider)
|
||||
.await
|
||||
.ok()?;
|
||||
require_backfill_complete(runtime, codex_home).await
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
#![allow(warnings, clippy::all)]
|
||||
|
||||
use super::*;
|
||||
use crate::config::RolloutConfig;
|
||||
use crate::list::parse_cursor;
|
||||
use chrono::DateTime;
|
||||
use chrono::NaiveDateTime;
|
||||
use chrono::Timelike;
|
||||
use chrono::Utc;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn cursor_to_anchor_normalizes_timestamp_format() {
|
||||
@@ -22,3 +26,82 @@ fn cursor_to_anchor_normalizes_timestamp_format() {
|
||||
|
||||
assert_eq!(anchor.ts, expected_ts);
|
||||
}
|
||||
|
||||
fn test_config(home: &Path) -> RolloutConfig {
|
||||
RolloutConfig {
|
||||
codex_home: home.to_path_buf(),
|
||||
sqlite_home: home.to_path_buf(),
|
||||
cwd: home.to_path_buf(),
|
||||
model_provider_id: "test-provider".to_string(),
|
||||
generate_memories: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn init_reuses_cached_runtime_for_same_home() {
|
||||
let temp = TempDir::new().expect("temp dir");
|
||||
let config = test_config(temp.path());
|
||||
|
||||
let first = init(&config).await.expect("state db init should succeed");
|
||||
first
|
||||
.mark_backfill_complete(/*last_watermark*/ None)
|
||||
.await
|
||||
.expect("backfill should be marked complete");
|
||||
let second = init(&config)
|
||||
.await
|
||||
.expect("cached state db init should succeed");
|
||||
|
||||
assert!(Arc::ptr_eq(&first, &second));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_state_db_reuses_cached_runtime() {
|
||||
let temp = TempDir::new().expect("temp dir");
|
||||
let config = test_config(temp.path());
|
||||
|
||||
let first = init(&config).await.expect("state db init should succeed");
|
||||
first
|
||||
.mark_backfill_complete(/*last_watermark*/ None)
|
||||
.await
|
||||
.expect("backfill should be marked complete");
|
||||
|
||||
let reopened = get_state_db(&config)
|
||||
.await
|
||||
.expect("cached state db should be returned");
|
||||
|
||||
assert!(Arc::ptr_eq(&first, &reopened));
|
||||
assert_eq!(reopened.codex_home(), config.sqlite_home.as_path());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent_init_reuses_single_cached_runtime() {
|
||||
let temp = TempDir::new().expect("temp dir");
|
||||
let config = Arc::new(test_config(temp.path()));
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for _ in 0..8 {
|
||||
let config = Arc::clone(&config);
|
||||
handles.push(tokio::spawn(async move {
|
||||
init(config.as_ref())
|
||||
.await
|
||||
.expect("state db init should succeed")
|
||||
}));
|
||||
}
|
||||
|
||||
let mut runtimes = Vec::new();
|
||||
for handle in handles {
|
||||
runtimes.push(handle.await.expect("task should join"));
|
||||
}
|
||||
let first = runtimes
|
||||
.first()
|
||||
.cloned()
|
||||
.expect("at least one runtime should exist");
|
||||
first
|
||||
.mark_backfill_complete(/*last_watermark*/ None)
|
||||
.await
|
||||
.expect("backfill should be marked complete");
|
||||
|
||||
for runtime in runtimes.iter().skip(1) {
|
||||
assert!(Arc::ptr_eq(&first, runtime));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,15 +319,14 @@ WHERE id = ?
|
||||
let now = Utc::now().timestamp();
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_job_items
|
||||
SET
|
||||
status = ?,
|
||||
assigned_thread_id = NULL,
|
||||
attempt_count = attempt_count + 1,
|
||||
updated_at = ?,
|
||||
last_error = NULL
|
||||
WHERE job_id = ? AND item_id = ? AND status = ?
|
||||
"#,
|
||||
UPDATE agent_job_items
|
||||
SET
|
||||
status = ?,
|
||||
assigned_thread_id = NULL,
|
||||
updated_at = ?,
|
||||
last_error = NULL
|
||||
WHERE job_id = ? AND item_id = ? AND status = ?
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobItemStatus::Running.as_str())
|
||||
.bind(now)
|
||||
@@ -407,10 +406,10 @@ WHERE job_id = ? AND item_id = ? AND status = ?
|
||||
let now = Utc::now().timestamp();
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_job_items
|
||||
SET assigned_thread_id = ?, updated_at = ?
|
||||
WHERE job_id = ? AND item_id = ? AND status = ?
|
||||
"#,
|
||||
UPDATE agent_job_items
|
||||
SET assigned_thread_id = ?, attempt_count = attempt_count + 1, updated_at = ?
|
||||
WHERE job_id = ? AND item_id = ? AND status = ?
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id)
|
||||
.bind(now)
|
||||
@@ -463,6 +462,68 @@ WHERE
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn report_agent_job_item_result_and_cancel_job(
|
||||
&self,
|
||||
job_id: &str,
|
||||
item_id: &str,
|
||||
reporting_thread_id: &str,
|
||||
result_json: &Value,
|
||||
cancel_reason: &str,
|
||||
) -> anyhow::Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
let serialized = serde_json::to_string(result_json)?;
|
||||
let mut tx = self.pool.begin().await?;
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_job_items
|
||||
SET
|
||||
status = ?,
|
||||
result_json = ?,
|
||||
reported_at = ?,
|
||||
completed_at = ?,
|
||||
updated_at = ?,
|
||||
last_error = NULL,
|
||||
assigned_thread_id = NULL
|
||||
WHERE
|
||||
job_id = ?
|
||||
AND item_id = ?
|
||||
AND status = ?
|
||||
AND assigned_thread_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobItemStatus::Completed.as_str())
|
||||
.bind(serialized)
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.bind(job_id)
|
||||
.bind(item_id)
|
||||
.bind(AgentJobItemStatus::Running.as_str())
|
||||
.bind(reporting_thread_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
if result.rows_affected() > 0 {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_jobs
|
||||
SET status = ?, updated_at = ?, completed_at = ?, last_error = ?
|
||||
WHERE id = ? AND status IN (?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobStatus::Cancelled.as_str())
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.bind(cancel_reason)
|
||||
.bind(job_id)
|
||||
.bind(AgentJobStatus::Pending.as_str())
|
||||
.bind(AgentJobStatus::Running.as_str())
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
tx.commit().await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn mark_agent_job_item_completed(
|
||||
&self,
|
||||
job_id: &str,
|
||||
@@ -681,4 +742,99 @@ mod tests {
|
||||
assert_eq!(item.last_error, Some("missing report".to_string()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn report_agent_job_item_result_and_cancel_job_is_atomic() -> anyhow::Result<()> {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home, "test-provider".to_string()).await?;
|
||||
let (job_id, item_id, thread_id) = create_running_single_item_job(runtime.as_ref()).await?;
|
||||
|
||||
let accepted = runtime
|
||||
.report_agent_job_item_result_and_cancel_job(
|
||||
job_id.as_str(),
|
||||
item_id.as_str(),
|
||||
thread_id.as_str(),
|
||||
&json!({"ok": true}),
|
||||
"cancelled by worker request",
|
||||
)
|
||||
.await?;
|
||||
assert!(accepted);
|
||||
|
||||
let item = runtime
|
||||
.get_agent_job_item(job_id.as_str(), item_id.as_str())
|
||||
.await?
|
||||
.expect("job item should exist");
|
||||
assert_eq!(item.status, AgentJobItemStatus::Completed);
|
||||
assert_eq!(item.result_json, Some(json!({"ok": true})));
|
||||
assert_eq!(item.assigned_thread_id, None);
|
||||
|
||||
let job = runtime
|
||||
.get_agent_job(job_id.as_str())
|
||||
.await?
|
||||
.expect("job should exist");
|
||||
assert_eq!(job.status, AgentJobStatus::Cancelled);
|
||||
assert_eq!(
|
||||
job.last_error,
|
||||
Some("cancelled by worker request".to_string())
|
||||
);
|
||||
assert!(job.completed_at.is_some());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn set_agent_job_item_thread_increments_attempt_count_after_claim() -> anyhow::Result<()>
|
||||
{
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home, "test-provider".to_string()).await?;
|
||||
let job_id = "job-1".to_string();
|
||||
let item_id = "item-1".to_string();
|
||||
runtime
|
||||
.create_agent_job(
|
||||
&AgentJobCreateParams {
|
||||
id: job_id.clone(),
|
||||
name: "test-job".to_string(),
|
||||
instruction: "Return a result".to_string(),
|
||||
auto_export: true,
|
||||
max_runtime_seconds: None,
|
||||
output_schema_json: None,
|
||||
input_headers: vec!["path".to_string()],
|
||||
input_csv_path: "/tmp/in.csv".to_string(),
|
||||
output_csv_path: "/tmp/out.csv".to_string(),
|
||||
},
|
||||
&[AgentJobItemCreateParams {
|
||||
item_id: item_id.clone(),
|
||||
row_index: 0,
|
||||
source_id: None,
|
||||
row_json: json!({"path":"file-1"}),
|
||||
}],
|
||||
)
|
||||
.await?;
|
||||
runtime.mark_agent_job_running(job_id.as_str()).await?;
|
||||
|
||||
let claimed = runtime
|
||||
.mark_agent_job_item_running(job_id.as_str(), item_id.as_str())
|
||||
.await?;
|
||||
assert!(claimed);
|
||||
|
||||
let item = runtime
|
||||
.get_agent_job_item(job_id.as_str(), item_id.as_str())
|
||||
.await?
|
||||
.expect("job item should exist");
|
||||
assert_eq!(item.attempt_count, 0);
|
||||
assert_eq!(item.assigned_thread_id, None);
|
||||
|
||||
let assigned = runtime
|
||||
.set_agent_job_item_thread(job_id.as_str(), item_id.as_str(), "thread-1")
|
||||
.await?;
|
||||
assert!(assigned);
|
||||
|
||||
let item = runtime
|
||||
.get_agent_job_item(job_id.as_str(), item_id.as_str())
|
||||
.await?
|
||||
.expect("job item should exist");
|
||||
assert_eq!(item.attempt_count, 1);
|
||||
assert_eq!(item.assigned_thread_id, Some("thread-1".to_string()));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user