Compare commits

...

1 Commits

Author SHA1 Message Date
efrazer-oai
2892e9e1ba feat: add durable queued turn store 2026-05-29 16:09:36 -07:00
10 changed files with 861 additions and 1 deletions

View File

@@ -8,5 +8,6 @@ codex_rust_crate(
"logs_migrations/**",
"memory_migrations/**",
"migrations/**",
"queue_migrations/**",
]),
)

View File

@@ -0,0 +1,14 @@
CREATE TABLE thread_queued_turns (
queued_turn_id TEXT PRIMARY KEY NOT NULL,
thread_id TEXT NOT NULL,
turn_submission_jsonb BLOB NOT NULL,
queue_order INTEGER NOT NULL,
state TEXT NOT NULL CHECK (state IN ('pending', 'dispatching', 'failed')),
dispatch_turn_id TEXT,
failure_jsonb BLOB,
created_at_ms INTEGER NOT NULL,
updated_at_ms INTEGER NOT NULL
);
CREATE INDEX thread_queued_turns_thread_state_order_idx
ON thread_queued_turns(thread_id, state, queue_order);

View File

@@ -50,12 +50,15 @@ pub use model::ThreadGoal;
pub use model::ThreadGoalStatus;
pub use model::ThreadMetadata;
pub use model::ThreadMetadataBuilder;
pub use model::ThreadQueuedTurn;
pub use model::ThreadQueuedTurnState;
pub use model::ThreadsPage;
pub use runtime::GoalAccountingMode;
pub use runtime::GoalAccountingOutcome;
pub use runtime::GoalStore;
pub use runtime::GoalUpdate;
pub use runtime::MemoryStore;
pub use runtime::QueueStore;
pub use runtime::RemoteControlEnrollmentRecord;
pub use runtime::RuntimeDbPath;
pub use runtime::ThreadFilterOptions;
@@ -65,6 +68,8 @@ pub use runtime::logs_db_filename;
pub use runtime::logs_db_path;
pub use runtime::memories_db_filename;
pub use runtime::memories_db_path;
pub use runtime::queue_db_filename;
pub use runtime::queue_db_path;
pub use runtime::runtime_db_paths;
pub use runtime::sqlite_integrity_check;
pub use runtime::state_db_filename;
@@ -81,6 +86,7 @@ pub const SQLITE_HOME_ENV: &str = "CODEX_SQLITE_HOME";
pub const LOGS_DB_FILENAME: &str = "logs_2.sqlite";
pub const GOALS_DB_FILENAME: &str = "goals_1.sqlite";
pub const MEMORIES_DB_FILENAME: &str = "memories_1.sqlite";
pub const QUEUE_DB_FILENAME: &str = "queue_1.sqlite";
pub const STATE_DB_FILENAME: &str = "state_5.sqlite";
/// Errors encountered during DB operations. Tags: [stage]

View File

@@ -6,6 +6,7 @@ pub(crate) static STATE_MIGRATOR: Migrator = sqlx::migrate!("./migrations");
pub(crate) static LOGS_MIGRATOR: Migrator = sqlx::migrate!("./logs_migrations");
pub(crate) static GOALS_MIGRATOR: Migrator = sqlx::migrate!("./goals_migrations");
pub(crate) static MEMORIES_MIGRATOR: Migrator = sqlx::migrate!("./memory_migrations");
pub(crate) static QUEUE_MIGRATOR: Migrator = sqlx::migrate!("./queue_migrations");
/// Allow an older Codex binary to open a database that has already been
/// migrated by a newer binary running in parallel.
@@ -39,3 +40,7 @@ pub(crate) fn runtime_goals_migrator() -> Migrator {
pub(crate) fn runtime_memories_migrator() -> Migrator {
runtime_migrator(&MEMORIES_MIGRATOR)
}
pub(crate) fn runtime_queue_migrator() -> Migrator {
runtime_migrator(&QUEUE_MIGRATOR)
}

View File

@@ -5,6 +5,7 @@ mod log;
mod memories;
mod thread_goal;
mod thread_metadata;
mod thread_queued_turn;
pub use agent_job::AgentJob;
pub use agent_job::AgentJobCreateParams;
@@ -34,6 +35,8 @@ pub use thread_metadata::SortKey;
pub use thread_metadata::ThreadMetadata;
pub use thread_metadata::ThreadMetadataBuilder;
pub use thread_metadata::ThreadsPage;
pub use thread_queued_turn::ThreadQueuedTurn;
pub use thread_queued_turn::ThreadQueuedTurnState;
pub(crate) use agent_job::AgentJobItemRow;
pub(crate) use agent_job::AgentJobRow;
@@ -43,3 +46,4 @@ pub(crate) use thread_metadata::anchor_from_item;
pub(crate) use thread_metadata::datetime_to_epoch_millis;
pub(crate) use thread_metadata::datetime_to_epoch_seconds;
pub(crate) use thread_metadata::epoch_millis_to_datetime;
pub(crate) use thread_queued_turn::ThreadQueuedTurnRow;

View File

@@ -0,0 +1,98 @@
use anyhow::Result;
use anyhow::anyhow;
use chrono::DateTime;
use chrono::Utc;
use codex_protocol::ThreadId;
use sqlx::Row;
use sqlx::sqlite::SqliteRow;
use super::epoch_millis_to_datetime;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadQueuedTurnState {
Pending,
Dispatching,
Failed,
}
impl ThreadQueuedTurnState {
pub fn as_str(self) -> &'static str {
match self {
Self::Pending => "pending",
Self::Dispatching => "dispatching",
Self::Failed => "failed",
}
}
}
impl TryFrom<&str> for ThreadQueuedTurnState {
type Error = anyhow::Error;
fn try_from(value: &str) -> Result<Self> {
match value {
"pending" => Ok(Self::Pending),
"dispatching" => Ok(Self::Dispatching),
"failed" => Ok(Self::Failed),
other => Err(anyhow!("unknown thread queued turn state `{other}`")),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThreadQueuedTurn {
pub queued_turn_id: String,
pub thread_id: ThreadId,
pub turn_submission_jsonb: Vec<u8>,
pub queue_order: i64,
pub state: ThreadQueuedTurnState,
pub dispatch_turn_id: Option<String>,
pub failure_jsonb: Option<Vec<u8>>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
pub(crate) struct ThreadQueuedTurnRow {
pub queued_turn_id: String,
pub thread_id: String,
pub turn_submission_jsonb: Vec<u8>,
pub queue_order: i64,
pub state: String,
pub dispatch_turn_id: Option<String>,
pub failure_jsonb: Option<Vec<u8>>,
pub created_at_ms: i64,
pub updated_at_ms: i64,
}
impl ThreadQueuedTurnRow {
pub(crate) fn try_from_row(row: &SqliteRow) -> Result<Self> {
Ok(Self {
queued_turn_id: row.try_get("queued_turn_id")?,
thread_id: row.try_get("thread_id")?,
turn_submission_jsonb: row.try_get("turn_submission_jsonb")?,
queue_order: row.try_get("queue_order")?,
state: row.try_get("state")?,
dispatch_turn_id: row.try_get("dispatch_turn_id")?,
failure_jsonb: row.try_get("failure_jsonb")?,
created_at_ms: row.try_get("created_at_ms")?,
updated_at_ms: row.try_get("updated_at_ms")?,
})
}
}
impl TryFrom<ThreadQueuedTurnRow> for ThreadQueuedTurn {
type Error = anyhow::Error;
fn try_from(row: ThreadQueuedTurnRow) -> Result<Self> {
Ok(Self {
queued_turn_id: row.queued_turn_id,
thread_id: ThreadId::try_from(row.thread_id)?,
turn_submission_jsonb: row.turn_submission_jsonb,
queue_order: row.queue_order,
state: ThreadQueuedTurnState::try_from(row.state.as_str())?,
dispatch_turn_id: row.dispatch_turn_id,
failure_jsonb: row.failure_jsonb,
created_at: epoch_millis_to_datetime(row.created_at_ms)?,
updated_at: epoch_millis_to_datetime(row.updated_at_ms)?,
})
}
}

View File

@@ -11,6 +11,7 @@ use crate::LogEntry;
use crate::LogQuery;
use crate::LogRow;
use crate::MEMORIES_DB_FILENAME;
use crate::QUEUE_DB_FILENAME;
use crate::STATE_DB_FILENAME;
use crate::SortKey;
use crate::ThreadMetadata;
@@ -20,6 +21,7 @@ use crate::apply_rollout_item;
use crate::migrations::runtime_goals_migrator;
use crate::migrations::runtime_logs_migrator;
use crate::migrations::runtime_memories_migrator;
use crate::migrations::runtime_queue_migrator;
use crate::migrations::runtime_state_migrator;
use crate::model::AgentJobRow;
use crate::model::ThreadRow;
@@ -62,6 +64,7 @@ mod backfill;
mod goals;
mod logs;
mod memories;
mod queued_turns;
mod remote_control;
#[cfg(test)]
mod test_support;
@@ -72,6 +75,7 @@ pub use goals::GoalAccountingOutcome;
pub use goals::GoalStore;
pub use goals::GoalUpdate;
pub use memories::MemoryStore;
pub use queued_turns::QueueStore;
pub use remote_control::RemoteControlEnrollmentRecord;
pub use threads::ThreadFilterOptions;
@@ -131,7 +135,15 @@ const MEMORIES_DB: RuntimeDbSpec = RuntimeDbSpec {
migrate_phase: "migrate_memories",
};
const RUNTIME_DBS: [RuntimeDbSpec; 4] = [STATE_DB, LOGS_DB, GOALS_DB, MEMORIES_DB];
const QUEUE_DB: RuntimeDbSpec = RuntimeDbSpec {
label: "queue DB",
filename: QUEUE_DB_FILENAME,
kind: DbKind::Queue,
open_phase: "open_queue",
migrate_phase: "migrate_queue",
};
const RUNTIME_DBS: [RuntimeDbSpec; 5] = [STATE_DB, LOGS_DB, GOALS_DB, MEMORIES_DB, QUEUE_DB];
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RuntimeDbPath {
@@ -147,6 +159,7 @@ pub struct StateRuntime {
logs_pool: Arc<sqlx::SqlitePool>,
thread_goals: GoalStore,
memories: MemoryStore,
thread_queue: QueueStore,
thread_updated_at_millis: Arc<AtomicI64>,
}
@@ -184,10 +197,12 @@ impl StateRuntime {
let logs_migrator = runtime_logs_migrator();
let goals_migrator = runtime_goals_migrator();
let memories_migrator = runtime_memories_migrator();
let queue_migrator = runtime_queue_migrator();
let state_path = STATE_DB.path(codex_home.as_path());
let logs_path = LOGS_DB.path(codex_home.as_path());
let goals_path = GOALS_DB.path(codex_home.as_path());
let memories_path = MEMORIES_DB.path(codex_home.as_path());
let queue_path = QUEUE_DB.path(codex_home.as_path());
let pool = match open_state_sqlite(&state_path, &state_migrator, telemetry_override).await {
Ok(db) => Arc::new(db),
Err(err) => {
@@ -227,6 +242,14 @@ impl StateRuntime {
return Err(err);
}
};
let queue_pool =
match open_queue_sqlite(&queue_path, &queue_migrator, telemetry_override).await {
Ok(db) => Arc::new(db),
Err(err) => {
warn!("failed to open queue db at {}: {err}", queue_path.display());
return Err(err);
}
};
let started = Instant::now();
let backfill_state_result = ensure_backfill_state_row_in_pool(pool.as_ref()).await;
crate::telemetry::record_init_result(
@@ -255,6 +278,7 @@ impl StateRuntime {
let runtime = Arc::new(Self {
thread_goals: GoalStore::new(Arc::clone(&goals_pool)),
memories: MemoryStore::new(Arc::clone(&memories_pool), Arc::clone(&pool)),
thread_queue: QueueStore::new(Arc::clone(&queue_pool)),
pool,
logs_pool,
codex_home,
@@ -283,6 +307,10 @@ impl StateRuntime {
&self.memories
}
pub fn thread_queue(&self) -> &QueueStore {
&self.thread_queue
}
pub async fn clear_memory_data_in_sqlite_home(sqlite_home: &Path) -> anyhow::Result<bool> {
let memories_path = MEMORIES_DB.path(sqlite_home);
if !tokio::fs::try_exists(&memories_path).await? {
@@ -347,6 +375,14 @@ async fn open_memories_sqlite(
open_sqlite(path, migrator, MEMORIES_DB, telemetry_override).await
}
async fn open_queue_sqlite(
path: &Path,
migrator: &Migrator,
telemetry_override: Option<&dyn DbTelemetry>,
) -> anyhow::Result<SqlitePool> {
open_sqlite(path, migrator, QUEUE_DB, telemetry_override).await
}
async fn open_sqlite(
path: &Path,
migrator: &Migrator,
@@ -431,6 +467,14 @@ pub fn memories_db_path(codex_home: &Path) -> PathBuf {
MEMORIES_DB.path(codex_home)
}
pub fn queue_db_filename() -> String {
QUEUE_DB.filename.to_string()
}
pub fn queue_db_path(codex_home: &Path) -> PathBuf {
QUEUE_DB.path(codex_home)
}
pub fn runtime_db_paths(codex_home: &Path) -> Vec<RuntimeDbPath> {
RUNTIME_DBS
.iter()
@@ -649,6 +693,8 @@ mod tests {
"migrate_goals",
"open_memories",
"migrate_memories",
"open_queue",
"migrate_queue",
"ensure_backfill_state",
"post_init_query",
]

View File

@@ -0,0 +1,683 @@
use super::*;
use uuid::Uuid;
#[derive(Clone)]
pub struct QueueStore {
pool: Arc<SqlitePool>,
}
impl QueueStore {
pub(crate) fn new(pool: Arc<SqlitePool>) -> Self {
Self { pool }
}
}
impl QueueStore {
pub async fn append_thread_queued_turn(
&self,
thread_id: ThreadId,
turn_submission_json: &[u8],
) -> anyhow::Result<crate::ThreadQueuedTurn> {
let queued_turn_id = Uuid::now_v7().to_string();
let now_ms = datetime_to_epoch_millis(Utc::now());
let row = sqlx::query(
r#"
INSERT INTO thread_queued_turns (
queued_turn_id,
thread_id,
turn_submission_jsonb,
queue_order,
state,
dispatch_turn_id,
failure_jsonb,
created_at_ms,
updated_at_ms
)
SELECT
?,
?,
jsonb(?),
COALESCE(MAX(queue_order), -1) + 1,
'pending',
NULL,
NULL,
?,
?
FROM thread_queued_turns
WHERE thread_id = ?
RETURNING
queued_turn_id,
thread_id,
CAST(json(turn_submission_jsonb) AS BLOB) AS turn_submission_jsonb,
queue_order,
state,
dispatch_turn_id,
CASE
WHEN failure_jsonb IS NULL THEN NULL
ELSE CAST(json(failure_jsonb) AS BLOB)
END AS failure_jsonb,
created_at_ms,
updated_at_ms
"#,
)
.bind(queued_turn_id)
.bind(thread_id.to_string())
.bind(turn_submission_json)
.bind(now_ms)
.bind(now_ms)
.bind(thread_id.to_string())
.fetch_one(self.pool.as_ref())
.await?;
thread_queued_turn_from_row(&row)
}
pub async fn list_visible_thread_queued_turns(
&self,
thread_id: ThreadId,
) -> anyhow::Result<Vec<crate::ThreadQueuedTurn>> {
self.list_visible_thread_queued_turns_page(thread_id, /*offset*/ 0, i64::MAX as usize)
.await
}
pub async fn list_visible_thread_queued_turns_page(
&self,
thread_id: ThreadId,
offset: usize,
limit: usize,
) -> anyhow::Result<Vec<crate::ThreadQueuedTurn>> {
let rows = sqlx::query(
r#"
SELECT
queued_turn_id,
thread_id,
CAST(json(turn_submission_jsonb) AS BLOB) AS turn_submission_jsonb,
queue_order,
state,
dispatch_turn_id,
CASE
WHEN failure_jsonb IS NULL THEN NULL
ELSE CAST(json(failure_jsonb) AS BLOB)
END AS failure_jsonb,
created_at_ms,
updated_at_ms
FROM thread_queued_turns
WHERE thread_id = ?
AND state IN ('pending', 'failed')
ORDER BY queue_order ASC
LIMIT ?
OFFSET ?
"#,
)
.bind(thread_id.to_string())
.bind(i64::try_from(limit)?)
.bind(i64::try_from(offset)?)
.fetch_all(self.pool.as_ref())
.await?;
rows.iter().map(thread_queued_turn_from_row).collect()
}
pub async fn delete_thread_queued_turn(
&self,
thread_id: ThreadId,
queued_turn_id: &str,
) -> anyhow::Result<bool> {
let result = sqlx::query(
r#"
DELETE FROM thread_queued_turns
WHERE thread_id = ?
AND queued_turn_id = ?
AND state IN ('pending', 'failed')
"#,
)
.bind(thread_id.to_string())
.bind(queued_turn_id)
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected() > 0)
}
pub async fn reorder_thread_queued_turns(
&self,
thread_id: ThreadId,
ordered_ids: &[String],
) -> anyhow::Result<Vec<crate::ThreadQueuedTurn>> {
let mut transaction = self.pool.begin().await?;
let visible_rows: Vec<(String, i64)> = sqlx::query_as(
r#"
SELECT queued_turn_id, queue_order
FROM thread_queued_turns
WHERE thread_id = ?
AND state IN ('pending', 'failed')
ORDER BY queue_order ASC
"#,
)
.bind(thread_id.to_string())
.fetch_all(transaction.as_mut())
.await?;
let visible_ids = visible_rows
.iter()
.map(|(queued_turn_id, _)| queued_turn_id.clone())
.collect::<Vec<_>>();
let visible_queue_orders = visible_rows
.into_iter()
.map(|(_, queue_order)| queue_order)
.collect::<Vec<_>>();
let mut expected_ids = visible_ids.clone();
expected_ids.sort();
let mut requested_ids = ordered_ids.to_vec();
requested_ids.sort();
if expected_ids != requested_ids {
anyhow::bail!("queue reorder must include every visible queued turn exactly once");
}
let now_ms = datetime_to_epoch_millis(Utc::now());
for (temporary_order, queued_turn_id) in ordered_ids.iter().enumerate() {
sqlx::query(
r#"
UPDATE thread_queued_turns
SET queue_order = ?, updated_at_ms = ?
WHERE thread_id = ?
AND queued_turn_id = ?
AND state IN ('pending', 'failed')
"#,
)
.bind(-((temporary_order as i64) + 1))
.bind(now_ms)
.bind(thread_id.to_string())
.bind(queued_turn_id)
.execute(transaction.as_mut())
.await?;
}
for (queue_order, queued_turn_id) in visible_queue_orders.into_iter().zip(ordered_ids) {
sqlx::query(
r#"
UPDATE thread_queued_turns
SET queue_order = ?, updated_at_ms = ?
WHERE thread_id = ?
AND queued_turn_id = ?
AND state IN ('pending', 'failed')
"#,
)
.bind(queue_order)
.bind(now_ms)
.bind(thread_id.to_string())
.bind(queued_turn_id)
.execute(transaction.as_mut())
.await?;
}
transaction.commit().await?;
self.list_visible_thread_queued_turns(thread_id).await
}
pub async fn claim_head_thread_queued_turn(
&self,
thread_id: ThreadId,
) -> anyhow::Result<Option<crate::ThreadQueuedTurn>> {
let now_ms = datetime_to_epoch_millis(Utc::now());
let row = sqlx::query(
r#"
UPDATE thread_queued_turns
SET state = 'dispatching', updated_at_ms = ?
WHERE queued_turn_id = (
SELECT head.queued_turn_id
FROM thread_queued_turns AS head
WHERE head.thread_id = ?
AND head.state IN ('pending', 'failed')
AND NOT EXISTS (
SELECT 1
FROM thread_queued_turns AS active
WHERE active.thread_id = head.thread_id
AND active.state = 'dispatching'
)
ORDER BY head.queue_order ASC
LIMIT 1
)
AND state = 'pending'
RETURNING
queued_turn_id,
thread_id,
CAST(json(turn_submission_jsonb) AS BLOB) AS turn_submission_jsonb,
queue_order,
state,
dispatch_turn_id,
CASE
WHEN failure_jsonb IS NULL THEN NULL
ELSE CAST(json(failure_jsonb) AS BLOB)
END AS failure_jsonb,
created_at_ms,
updated_at_ms
"#,
)
.bind(now_ms)
.bind(thread_id.to_string())
.fetch_optional(self.pool.as_ref())
.await?;
row.map(|row| thread_queued_turn_from_row(&row)).transpose()
}
pub async fn set_dispatching_thread_queued_turn_turn_id(
&self,
queued_turn_id: &str,
turn_id: &str,
) -> anyhow::Result<bool> {
let now_ms = datetime_to_epoch_millis(Utc::now());
let result = sqlx::query(
r#"
UPDATE thread_queued_turns
SET dispatch_turn_id = ?, updated_at_ms = ?
WHERE queued_turn_id = ?
AND state = 'dispatching'
"#,
)
.bind(turn_id)
.bind(now_ms)
.bind(queued_turn_id)
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected() > 0)
}
pub async fn remove_dispatching_thread_queued_turn(
&self,
thread_id: ThreadId,
turn_id: &str,
) -> anyhow::Result<bool> {
let result = sqlx::query(
r#"
DELETE FROM thread_queued_turns
WHERE thread_id = ?
AND state = 'dispatching'
AND dispatch_turn_id = ?
"#,
)
.bind(thread_id.to_string())
.bind(turn_id)
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected() > 0)
}
pub async fn mark_thread_queued_turn_failed(
&self,
queued_turn_id: &str,
failure_json: &[u8],
) -> anyhow::Result<bool> {
let now_ms = datetime_to_epoch_millis(Utc::now());
let result = sqlx::query(
r#"
UPDATE thread_queued_turns
SET
state = 'failed',
failure_jsonb = jsonb(?),
updated_at_ms = ?
WHERE queued_turn_id = ?
AND state = 'dispatching'
"#,
)
.bind(failure_json)
.bind(now_ms)
.bind(queued_turn_id)
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected() > 0)
}
pub async fn recover_dispatching_thread_queued_turns(
&self,
thread_id: ThreadId,
failure_json: &[u8],
) -> anyhow::Result<u64> {
let now_ms = datetime_to_epoch_millis(Utc::now());
let result = sqlx::query(
r#"
UPDATE thread_queued_turns
SET
state = 'failed',
failure_jsonb = jsonb(?),
updated_at_ms = ?
WHERE thread_id = ?
AND state = 'dispatching'
"#,
)
.bind(failure_json)
.bind(now_ms)
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected())
}
pub async fn delete_thread_queue(&self, thread_id: ThreadId) -> anyhow::Result<bool> {
let result = sqlx::query(
r#"
DELETE FROM thread_queued_turns
WHERE thread_id = ?
"#,
)
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected() > 0)
}
}
fn thread_queued_turn_from_row(
row: &sqlx::sqlite::SqliteRow,
) -> anyhow::Result<crate::ThreadQueuedTurn> {
crate::model::ThreadQueuedTurnRow::try_from_row(row)?.try_into()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::runtime::test_support::test_thread_metadata;
use crate::runtime::test_support::unique_temp_dir;
use pretty_assertions::assert_eq;
async fn runtime_with_thread() -> (Arc<StateRuntime>, ThreadId) {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string())
.await
.expect("state runtime");
let thread_id = ThreadId::new();
runtime
.upsert_thread(&test_thread_metadata(
codex_home.as_path(),
thread_id,
codex_home.clone(),
))
.await
.expect("insert thread");
(runtime, thread_id)
}
#[tokio::test]
async fn queued_turn_claim_is_single_winner_and_hides_dispatching_row() {
let (runtime, thread_id) = runtime_with_thread().await;
let queue = runtime.thread_queue();
queue
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append queued turn");
let (first, second) = tokio::join!(
queue.claim_head_thread_queued_turn(thread_id),
queue.claim_head_thread_queued_turn(thread_id),
);
let claimed = [first.expect("first claim"), second.expect("second claim")]
.into_iter()
.flatten()
.count();
assert_eq!(claimed, 1);
assert_eq!(
queue
.list_visible_thread_queued_turns(thread_id)
.await
.expect("list visible queued turns"),
Vec::new()
);
}
#[tokio::test]
async fn queued_turn_added_during_dispatch_claim_waits_for_existing_claim() {
let (runtime, thread_id) = runtime_with_thread().await;
let queue = runtime.thread_queue();
queue
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append first");
queue
.claim_head_thread_queued_turn(thread_id)
.await
.expect("claim first")
.expect("claimed row");
let second = queue
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append second");
assert_eq!(
queue
.claim_head_thread_queued_turn(thread_id)
.await
.expect("claim blocked by dispatch"),
None
);
assert_eq!(
queue
.list_visible_thread_queued_turns(thread_id)
.await
.expect("list visible queued turns"),
vec![second]
);
}
#[tokio::test]
async fn dispatch_claim_rejects_stale_mutations_and_keeps_later_rows_reorderable() {
let (runtime, thread_id) = runtime_with_thread().await;
let queue = runtime.thread_queue();
let first = queue
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append first");
let second = queue
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append second");
let third = queue
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append third");
queue
.claim_head_thread_queued_turn(thread_id)
.await
.expect("claim first")
.expect("claimed row");
assert!(
!queue
.delete_thread_queued_turn(thread_id, &first.queued_turn_id)
.await
.expect("dispatching row is not deletable")
);
assert!(
queue
.reorder_thread_queued_turns(
thread_id,
&[
first.queued_turn_id.clone(),
third.queued_turn_id.clone(),
second.queued_turn_id.clone(),
],
)
.await
.is_err()
);
let reordered_ids = queue
.reorder_thread_queued_turns(
thread_id,
&[third.queued_turn_id.clone(), second.queued_turn_id.clone()],
)
.await
.expect("reorder visible rows")
.into_iter()
.map(|queued_turn| queued_turn.queued_turn_id)
.collect::<Vec<_>>();
assert_eq!(
reordered_ids,
vec![third.queued_turn_id, second.queued_turn_id]
);
}
#[tokio::test]
async fn abandoned_dispatch_claim_recovers_as_failed_and_blocks_fifo() {
let (runtime, thread_id) = runtime_with_thread().await;
let queue = runtime.thread_queue();
let first = queue
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append first");
queue
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append second");
queue
.claim_head_thread_queued_turn(thread_id)
.await
.expect("claim first")
.expect("claimed row");
assert_eq!(
queue
.recover_dispatching_thread_queued_turns(
thread_id,
br#"{"message":"dispatch interrupted"}"#,
)
.await
.expect("recover dispatching rows"),
1
);
let visible = queue
.list_visible_thread_queued_turns(thread_id)
.await
.expect("list recovered queue");
assert_eq!(visible[0].queued_turn_id, first.queued_turn_id);
assert_eq!(visible[0].state, crate::ThreadQueuedTurnState::Failed);
assert_eq!(
queue
.claim_head_thread_queued_turn(thread_id)
.await
.expect("failed head blocks claim"),
None
);
}
#[tokio::test]
async fn failed_head_blocks_later_pending_work_until_removed() {
let (runtime, thread_id) = runtime_with_thread().await;
let queue = runtime.thread_queue();
let first = queue
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append first");
queue
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append second");
let claimed = queue
.claim_head_thread_queued_turn(thread_id)
.await
.expect("claim first")
.expect("claimed row");
assert_eq!(claimed.queued_turn_id, first.queued_turn_id);
queue
.mark_thread_queued_turn_failed(&claimed.queued_turn_id, br#"{"message":"nope"}"#)
.await
.expect("mark failed");
assert_eq!(
queue
.claim_head_thread_queued_turn(thread_id)
.await
.expect("blocked claim"),
None
);
assert!(
queue
.delete_thread_queued_turn(thread_id, &first.queued_turn_id)
.await
.expect("delete failed head")
);
assert!(
queue
.claim_head_thread_queued_turn(thread_id)
.await
.expect("claim next")
.is_some()
);
}
#[tokio::test]
async fn dispatch_claim_clears_only_for_its_submitted_turn() {
let (runtime, thread_id) = runtime_with_thread().await;
let queue = runtime.thread_queue();
let queued_turn = queue
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append queued turn");
queue
.claim_head_thread_queued_turn(thread_id)
.await
.expect("claim queued turn")
.expect("claimed row");
assert!(
!queue
.remove_dispatching_thread_queued_turn(thread_id, "regular-turn")
.await
.expect("unmatched turn must not clear claim")
);
assert!(
queue
.set_dispatching_thread_queued_turn_turn_id(
&queued_turn.queued_turn_id,
"queued-turn",
)
.await
.expect("record submitted queued turn id")
);
assert!(
!queue
.remove_dispatching_thread_queued_turn(thread_id, "regular-turn")
.await
.expect("different started turn must not clear claim")
);
assert!(
queue
.remove_dispatching_thread_queued_turn(thread_id, "queued-turn")
.await
.expect("matching queued turn clears claim")
);
}
#[tokio::test]
async fn deleting_thread_metadata_deletes_queued_turns() {
let (runtime, thread_id) = runtime_with_thread().await;
runtime
.thread_queue()
.append_thread_queued_turn(thread_id, br#"{"threadId":"t","input":[]}"#)
.await
.expect("append queued turn");
assert_eq!(
runtime
.delete_thread(thread_id)
.await
.expect("delete thread"),
1
);
assert_eq!(
runtime
.thread_queue()
.list_visible_thread_queued_turns(thread_id)
.await
.expect("list queued turns"),
Vec::new()
);
}
}

View File

@@ -890,6 +890,7 @@ ON CONFLICT(id) DO UPDATE SET
self.memories.delete_thread_memory(thread_id).await?;
if rows_affected > 0 {
self.thread_goals.delete_thread_goal(thread_id).await?;
self.thread_queue.delete_thread_queue(thread_id).await?;
}
Ok(rows_affected)
}

View File

@@ -41,6 +41,7 @@ pub(crate) enum DbKind {
Logs,
Goals,
Memories,
Queue,
}
impl DbKind {
@@ -50,6 +51,7 @@ impl DbKind {
Self::Logs => "logs",
Self::Goals => "goals",
Self::Memories => "memories",
Self::Queue => "queue",
}
}
}