mirror of
https://github.com/openai/codex.git
synced 2026-05-01 09:56:37 +00:00
Compare commits
4 Commits
ice-window
...
etraut/mes
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70b7a3700c | ||
|
|
75b389665a | ||
|
|
09b222272b | ||
|
|
5d482b54e9 |
22
codex-rs/state/migrations/0025_thread_timers.sql
Normal file
22
codex-rs/state/migrations/0025_thread_timers.sql
Normal file
@@ -0,0 +1,22 @@
|
||||
CREATE TABLE thread_timers (
|
||||
id TEXT PRIMARY KEY,
|
||||
thread_id TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
client_id TEXT NOT NULL,
|
||||
trigger_json TEXT NOT NULL,
|
||||
prompt TEXT NOT NULL,
|
||||
delivery TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
next_run_at INTEGER,
|
||||
last_run_at INTEGER,
|
||||
pending_run INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX idx_thread_timers_thread_created
|
||||
ON thread_timers(thread_id, created_at, id);
|
||||
|
||||
CREATE INDEX idx_thread_timers_thread_pending
|
||||
ON thread_timers(thread_id, pending_run, created_at, id);
|
||||
|
||||
CREATE INDEX idx_thread_timers_thread_next_run
|
||||
ON thread_timers(thread_id, next_run_at);
|
||||
77
codex-rs/state/migrations/0026_external_messages.sql
Normal file
77
codex-rs/state/migrations/0026_external_messages.sql
Normal file
@@ -0,0 +1,77 @@
|
||||
ALTER TABLE thread_timers RENAME TO thread_timers_old;
|
||||
|
||||
CREATE TABLE thread_timers (
|
||||
id TEXT PRIMARY KEY,
|
||||
thread_id TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
client_id TEXT NOT NULL,
|
||||
trigger_json TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
instructions TEXT,
|
||||
meta_json TEXT NOT NULL,
|
||||
delivery TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
next_run_at INTEGER,
|
||||
last_run_at INTEGER,
|
||||
pending_run INTEGER NOT NULL
|
||||
);
|
||||
|
||||
INSERT INTO thread_timers (
|
||||
id,
|
||||
thread_id,
|
||||
source,
|
||||
client_id,
|
||||
trigger_json,
|
||||
content,
|
||||
instructions,
|
||||
meta_json,
|
||||
delivery,
|
||||
created_at,
|
||||
next_run_at,
|
||||
last_run_at,
|
||||
pending_run
|
||||
)
|
||||
SELECT
|
||||
id,
|
||||
thread_id,
|
||||
source,
|
||||
client_id,
|
||||
trigger_json,
|
||||
prompt,
|
||||
NULL,
|
||||
'{}',
|
||||
delivery,
|
||||
created_at,
|
||||
next_run_at,
|
||||
last_run_at,
|
||||
pending_run
|
||||
FROM thread_timers_old;
|
||||
|
||||
DROP TABLE thread_timers_old;
|
||||
|
||||
CREATE INDEX idx_thread_timers_thread_created
|
||||
ON thread_timers(thread_id, created_at, id);
|
||||
|
||||
CREATE INDEX idx_thread_timers_thread_pending
|
||||
ON thread_timers(thread_id, pending_run, created_at, id);
|
||||
|
||||
CREATE INDEX idx_thread_timers_thread_next_run
|
||||
ON thread_timers(thread_id, next_run_at);
|
||||
|
||||
CREATE TABLE external_messages (
|
||||
seq INTEGER PRIMARY KEY,
|
||||
id TEXT NOT NULL UNIQUE,
|
||||
thread_id TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
instructions TEXT,
|
||||
meta_json TEXT NOT NULL,
|
||||
delivery TEXT NOT NULL,
|
||||
queued_at INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX external_messages_thread_order_idx
|
||||
ON external_messages(thread_id, queued_at, seq);
|
||||
|
||||
CREATE INDEX external_messages_thread_delivery_order_idx
|
||||
ON external_messages(thread_id, delivery, queued_at, seq);
|
||||
@@ -36,6 +36,9 @@ pub use model::BackfillState;
|
||||
pub use model::BackfillStats;
|
||||
pub use model::BackfillStatus;
|
||||
pub use model::DirectionalThreadSpawnEdgeStatus;
|
||||
pub use model::ExternalMessage;
|
||||
pub use model::ExternalMessageClaim;
|
||||
pub use model::ExternalMessageCreateParams;
|
||||
pub use model::ExtractionOutcome;
|
||||
pub use model::SortKey;
|
||||
pub use model::Stage1JobClaim;
|
||||
@@ -45,8 +48,12 @@ pub use model::Stage1OutputRef;
|
||||
pub use model::Stage1StartupClaimParams;
|
||||
pub use model::ThreadMetadata;
|
||||
pub use model::ThreadMetadataBuilder;
|
||||
pub use model::ThreadTimer;
|
||||
pub use model::ThreadTimerCreateParams;
|
||||
pub use model::ThreadTimerUpdateParams;
|
||||
pub use model::ThreadsPage;
|
||||
pub use runtime::RemoteControlEnrollmentRecord;
|
||||
pub use runtime::TimerDataVersionChecker;
|
||||
pub use runtime::logs_db_filename;
|
||||
pub use runtime::logs_db_path;
|
||||
pub use runtime::state_db_filename;
|
||||
|
||||
85
codex-rs/state/src/model/external_message.rs
Normal file
85
codex-rs/state/src/model/external_message.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use sqlx::FromRow;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct ExternalMessageCreateParams {
|
||||
pub id: String,
|
||||
pub thread_id: String,
|
||||
pub source: String,
|
||||
pub content: String,
|
||||
pub instructions: Option<String>,
|
||||
pub meta_json: String,
|
||||
pub delivery: String,
|
||||
pub queued_at: i64,
|
||||
}
|
||||
|
||||
impl ExternalMessageCreateParams {
|
||||
pub fn new(
|
||||
thread_id: String,
|
||||
source: String,
|
||||
content: String,
|
||||
instructions: Option<String>,
|
||||
meta_json: String,
|
||||
delivery: String,
|
||||
queued_at: i64,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
thread_id,
|
||||
source,
|
||||
content,
|
||||
instructions,
|
||||
meta_json,
|
||||
delivery,
|
||||
queued_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct ExternalMessage {
|
||||
pub seq: i64,
|
||||
pub id: String,
|
||||
pub thread_id: String,
|
||||
pub source: String,
|
||||
pub content: String,
|
||||
pub instructions: Option<String>,
|
||||
pub meta_json: String,
|
||||
pub delivery: String,
|
||||
pub queued_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ExternalMessageClaim {
|
||||
Claimed(ExternalMessage),
|
||||
Invalid { id: String, reason: String },
|
||||
NotReady,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
pub(crate) struct ExternalMessageRow {
|
||||
pub seq: i64,
|
||||
pub id: String,
|
||||
pub thread_id: String,
|
||||
pub source: String,
|
||||
pub content: String,
|
||||
pub instructions: Option<String>,
|
||||
pub meta_json: String,
|
||||
pub delivery: String,
|
||||
pub queued_at: i64,
|
||||
}
|
||||
|
||||
impl From<ExternalMessageRow> for ExternalMessage {
|
||||
fn from(row: ExternalMessageRow) -> Self {
|
||||
Self {
|
||||
seq: row.seq,
|
||||
id: row.id,
|
||||
thread_id: row.thread_id,
|
||||
source: row.source,
|
||||
content: row.content,
|
||||
instructions: row.instructions,
|
||||
meta_json: row.meta_json,
|
||||
delivery: row.delivery,
|
||||
queued_at: row.queued_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
mod agent_job;
|
||||
mod backfill_state;
|
||||
mod external_message;
|
||||
mod graph;
|
||||
mod log;
|
||||
mod memories;
|
||||
mod thread_metadata;
|
||||
mod thread_timer;
|
||||
|
||||
pub use agent_job::AgentJob;
|
||||
pub use agent_job::AgentJobCreateParams;
|
||||
@@ -14,6 +16,9 @@ pub use agent_job::AgentJobProgress;
|
||||
pub use agent_job::AgentJobStatus;
|
||||
pub use backfill_state::BackfillState;
|
||||
pub use backfill_state::BackfillStatus;
|
||||
pub use external_message::ExternalMessage;
|
||||
pub use external_message::ExternalMessageClaim;
|
||||
pub use external_message::ExternalMessageCreateParams;
|
||||
pub use graph::DirectionalThreadSpawnEdgeStatus;
|
||||
pub use log::LogEntry;
|
||||
pub use log::LogQuery;
|
||||
@@ -32,11 +37,16 @@ pub use thread_metadata::SortKey;
|
||||
pub use thread_metadata::ThreadMetadata;
|
||||
pub use thread_metadata::ThreadMetadataBuilder;
|
||||
pub use thread_metadata::ThreadsPage;
|
||||
pub use thread_timer::ThreadTimer;
|
||||
pub use thread_timer::ThreadTimerCreateParams;
|
||||
pub use thread_timer::ThreadTimerUpdateParams;
|
||||
|
||||
pub(crate) use agent_job::AgentJobItemRow;
|
||||
pub(crate) use agent_job::AgentJobRow;
|
||||
pub(crate) use external_message::ExternalMessageRow;
|
||||
pub(crate) use memories::Stage1OutputRow;
|
||||
pub(crate) use memories::stage1_output_ref_from_parts;
|
||||
pub(crate) use thread_metadata::ThreadRow;
|
||||
pub(crate) use thread_metadata::anchor_from_item;
|
||||
pub(crate) use thread_metadata::datetime_to_epoch_seconds;
|
||||
pub(crate) use thread_timer::ThreadTimerRow;
|
||||
|
||||
84
codex-rs/state/src/model/thread_timer.rs
Normal file
84
codex-rs/state/src/model/thread_timer.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use sqlx::FromRow;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct ThreadTimerCreateParams {
|
||||
pub id: String,
|
||||
pub thread_id: String,
|
||||
pub source: String,
|
||||
pub client_id: String,
|
||||
pub trigger_json: String,
|
||||
pub content: String,
|
||||
pub instructions: Option<String>,
|
||||
pub meta_json: String,
|
||||
pub delivery: String,
|
||||
pub created_at: i64,
|
||||
pub next_run_at: Option<i64>,
|
||||
pub last_run_at: Option<i64>,
|
||||
pub pending_run: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct ThreadTimerUpdateParams {
|
||||
pub trigger_json: String,
|
||||
pub content: String,
|
||||
pub instructions: Option<String>,
|
||||
pub meta_json: String,
|
||||
pub delivery: String,
|
||||
pub next_run_at: Option<i64>,
|
||||
pub last_run_at: Option<i64>,
|
||||
pub pending_run: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct ThreadTimer {
|
||||
pub id: String,
|
||||
pub thread_id: String,
|
||||
pub source: String,
|
||||
pub client_id: String,
|
||||
pub trigger_json: String,
|
||||
pub content: String,
|
||||
pub instructions: Option<String>,
|
||||
pub meta_json: String,
|
||||
pub delivery: String,
|
||||
pub created_at: i64,
|
||||
pub next_run_at: Option<i64>,
|
||||
pub last_run_at: Option<i64>,
|
||||
pub pending_run: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
pub(crate) struct ThreadTimerRow {
|
||||
pub id: String,
|
||||
pub thread_id: String,
|
||||
pub source: String,
|
||||
pub client_id: String,
|
||||
pub trigger_json: String,
|
||||
pub content: String,
|
||||
pub instructions: Option<String>,
|
||||
pub meta_json: String,
|
||||
pub delivery: String,
|
||||
pub created_at: i64,
|
||||
pub next_run_at: Option<i64>,
|
||||
pub last_run_at: Option<i64>,
|
||||
pub pending_run: i64,
|
||||
}
|
||||
|
||||
impl From<ThreadTimerRow> for ThreadTimer {
|
||||
fn from(row: ThreadTimerRow) -> Self {
|
||||
Self {
|
||||
id: row.id,
|
||||
thread_id: row.thread_id,
|
||||
source: row.source,
|
||||
client_id: row.client_id,
|
||||
trigger_json: row.trigger_json,
|
||||
content: row.content,
|
||||
instructions: row.instructions,
|
||||
meta_json: row.meta_json,
|
||||
delivery: row.delivery,
|
||||
created_at: row.created_at,
|
||||
next_run_at: row.next_run_at,
|
||||
last_run_at: row.last_run_at,
|
||||
pending_run: row.pending_run != 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,9 @@ use crate::AgentJobItemCreateParams;
|
||||
use crate::AgentJobItemStatus;
|
||||
use crate::AgentJobProgress;
|
||||
use crate::AgentJobStatus;
|
||||
use crate::ExternalMessage;
|
||||
use crate::ExternalMessageClaim;
|
||||
use crate::ExternalMessageCreateParams;
|
||||
use crate::LOGS_DB_FILENAME;
|
||||
use crate::LOGS_DB_VERSION;
|
||||
use crate::LogEntry;
|
||||
@@ -15,6 +18,9 @@ use crate::STATE_DB_VERSION;
|
||||
use crate::SortKey;
|
||||
use crate::ThreadMetadata;
|
||||
use crate::ThreadMetadataBuilder;
|
||||
use crate::ThreadTimer;
|
||||
use crate::ThreadTimerCreateParams;
|
||||
use crate::ThreadTimerUpdateParams;
|
||||
use crate::ThreadsPage;
|
||||
use crate::apply_rollout_item;
|
||||
use crate::migrations::runtime_logs_migrator;
|
||||
@@ -52,14 +58,18 @@ use tracing::warn;
|
||||
|
||||
mod agent_jobs;
|
||||
mod backfill;
|
||||
mod delivery_state;
|
||||
mod external_messages;
|
||||
mod logs;
|
||||
mod memories;
|
||||
mod remote_control;
|
||||
#[cfg(test)]
|
||||
mod test_support;
|
||||
mod threads;
|
||||
mod timers;
|
||||
|
||||
pub use remote_control::RemoteControlEnrollmentRecord;
|
||||
pub use timers::TimerDataVersionChecker;
|
||||
|
||||
// "Partition" is the retained-log-content bucket we cap at 10 MiB:
|
||||
// - one bucket per non-null thread_id
|
||||
|
||||
131
codex-rs/state/src/runtime/delivery_state.rs
Normal file
131
codex-rs/state/src/runtime/delivery_state.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
//! Cleanup operations for per-thread delivery state.
|
||||
//!
|
||||
//! Timers and queued external messages are stored independently because they have
|
||||
//! different runtime behavior, but thread lifecycle operations need to treat
|
||||
//! them as one unit. This module owns that cross-table cleanup.
|
||||
|
||||
use super::*;
|
||||
|
||||
impl StateRuntime {
|
||||
/// Delete all queued external messages and timers associated with `thread_id`.
|
||||
pub async fn delete_thread_delivery_state(&self, thread_id: &str) -> anyhow::Result<()> {
|
||||
let mut tx = self.pool.begin().await?;
|
||||
sqlx::query("DELETE FROM external_messages WHERE thread_id = ?")
|
||||
.bind(thread_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
sqlx::query("DELETE FROM thread_timers WHERE thread_id = ?")
|
||||
.bind(thread_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
tx.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::StateRuntime;
|
||||
use super::test_support::unique_temp_dir;
|
||||
use crate::ExternalMessageCreateParams;
|
||||
use crate::ThreadTimerCreateParams;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn message_params(id: &str, thread_id: &str) -> ExternalMessageCreateParams {
|
||||
ExternalMessageCreateParams {
|
||||
id: id.to_string(),
|
||||
thread_id: thread_id.to_string(),
|
||||
source: "external".to_string(),
|
||||
content: "do something".to_string(),
|
||||
instructions: None,
|
||||
meta_json: "{}".to_string(),
|
||||
delivery: "after-turn".to_string(),
|
||||
queued_at: 100,
|
||||
}
|
||||
}
|
||||
|
||||
fn timer_params(id: &str, thread_id: &str) -> ThreadTimerCreateParams {
|
||||
ThreadTimerCreateParams {
|
||||
id: id.to_string(),
|
||||
thread_id: thread_id.to_string(),
|
||||
source: "agent".to_string(),
|
||||
client_id: "codex-tui".to_string(),
|
||||
trigger_json: r#"{"kind":"delay","seconds":10,"repeat":false}"#.to_string(),
|
||||
content: "run tests".to_string(),
|
||||
instructions: None,
|
||||
meta_json: "{}".to_string(),
|
||||
delivery: "after-turn".to_string(),
|
||||
created_at: 100,
|
||||
next_run_at: Some(110),
|
||||
last_run_at: None,
|
||||
pending_run: false,
|
||||
}
|
||||
}
|
||||
|
||||
async fn test_runtime() -> std::sync::Arc<StateRuntime> {
|
||||
StateRuntime::init(unique_temp_dir(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delete_thread_delivery_state_removes_messages_and_timers_for_thread() {
|
||||
let runtime = test_runtime().await;
|
||||
runtime
|
||||
.create_external_message(&message_params("message-1", "thread-1"))
|
||||
.await
|
||||
.expect("create thread-1 message");
|
||||
runtime
|
||||
.create_external_message(&message_params("message-2", "thread-2"))
|
||||
.await
|
||||
.expect("create thread-2 message");
|
||||
runtime
|
||||
.create_thread_timer(&timer_params("timer-1", "thread-1"))
|
||||
.await
|
||||
.expect("create thread-1 timer");
|
||||
runtime
|
||||
.create_thread_timer(&timer_params("timer-2", "thread-2"))
|
||||
.await
|
||||
.expect("create thread-2 timer");
|
||||
|
||||
runtime
|
||||
.delete_thread_delivery_state("thread-1")
|
||||
.await
|
||||
.expect("delete delivery state");
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_external_messages("thread-1")
|
||||
.await
|
||||
.expect("list thread-1 messages"),
|
||||
Vec::new()
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_thread_timers("thread-1")
|
||||
.await
|
||||
.expect("list thread-1 timers"),
|
||||
Vec::new()
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_external_messages("thread-2")
|
||||
.await
|
||||
.expect("list thread-2 messages")
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec!["message-2".to_string()]
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_thread_timers("thread-2")
|
||||
.await
|
||||
.expect("list thread-2 timers")
|
||||
.into_iter()
|
||||
.map(|timer| timer.id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec!["timer-2".to_string()]
|
||||
);
|
||||
}
|
||||
}
|
||||
443
codex-rs/state/src/runtime/external_messages.rs
Normal file
443
codex-rs/state/src/runtime/external_messages.rs
Normal file
@@ -0,0 +1,443 @@
|
||||
//! SQLite-backed state operations for queued external messages.
|
||||
//!
|
||||
//! This module extends [`StateRuntime`] with the storage APIs used by message
|
||||
//! producers and active threads. Claiming a message deletes the row inside the
|
||||
//! same transaction, so competing runtimes deliver each queued message at most
|
||||
//! once.
|
||||
|
||||
use super::*;
|
||||
use crate::model::ExternalMessageRow;
|
||||
|
||||
const DELIVERY_AFTER_TURN: &str = "after-turn";
|
||||
const DELIVERY_STEER_CURRENT_TURN: &str = "steer-current-turn";
|
||||
|
||||
impl StateRuntime {
|
||||
pub async fn create_external_message(
|
||||
&self,
|
||||
params: &ExternalMessageCreateParams,
|
||||
) -> anyhow::Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO external_messages (
|
||||
id,
|
||||
thread_id,
|
||||
source,
|
||||
content,
|
||||
instructions,
|
||||
meta_json,
|
||||
delivery,
|
||||
queued_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(params.id.as_str())
|
||||
.bind(params.thread_id.as_str())
|
||||
.bind(params.source.as_str())
|
||||
.bind(params.content.as_str())
|
||||
.bind(params.instructions.as_deref())
|
||||
.bind(params.meta_json.as_str())
|
||||
.bind(params.delivery.as_str())
|
||||
.bind(params.queued_at)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_external_messages(
|
||||
&self,
|
||||
thread_id: &str,
|
||||
) -> anyhow::Result<Vec<ExternalMessage>> {
|
||||
let rows = sqlx::query_as::<_, ExternalMessageRow>(
|
||||
r#"
|
||||
SELECT
|
||||
seq,
|
||||
id,
|
||||
thread_id,
|
||||
source,
|
||||
content,
|
||||
instructions,
|
||||
meta_json,
|
||||
delivery,
|
||||
queued_at
|
||||
FROM external_messages
|
||||
WHERE thread_id = ?
|
||||
ORDER BY queued_at ASC, seq ASC
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id)
|
||||
.fetch_all(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(rows.into_iter().map(ExternalMessage::from).collect())
|
||||
}
|
||||
|
||||
pub async fn delete_external_message(&self, thread_id: &str, id: &str) -> anyhow::Result<bool> {
|
||||
let result = sqlx::query("DELETE FROM external_messages WHERE thread_id = ? AND id = ?")
|
||||
.bind(thread_id)
|
||||
.bind(id)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn claim_next_external_message(
|
||||
&self,
|
||||
thread_id: &str,
|
||||
can_after_turn: bool,
|
||||
can_steer_current_turn: bool,
|
||||
) -> anyhow::Result<Option<ExternalMessageClaim>> {
|
||||
let row = sqlx::query_as::<_, ExternalMessageRow>(
|
||||
r#"
|
||||
DELETE FROM external_messages
|
||||
WHERE seq = (
|
||||
SELECT seq
|
||||
FROM external_messages
|
||||
WHERE thread_id = ?
|
||||
ORDER BY queued_at ASC, seq ASC
|
||||
LIMIT 1
|
||||
)
|
||||
AND (
|
||||
delivery NOT IN (?, ?)
|
||||
OR (delivery = ? AND ?)
|
||||
OR (delivery = ? AND ?)
|
||||
)
|
||||
RETURNING
|
||||
seq,
|
||||
id,
|
||||
thread_id,
|
||||
source,
|
||||
content,
|
||||
instructions,
|
||||
meta_json,
|
||||
delivery,
|
||||
queued_at
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id)
|
||||
.bind(DELIVERY_AFTER_TURN)
|
||||
.bind(DELIVERY_STEER_CURRENT_TURN)
|
||||
.bind(DELIVERY_AFTER_TURN)
|
||||
.bind(can_after_turn)
|
||||
.bind(DELIVERY_STEER_CURRENT_TURN)
|
||||
.bind(can_steer_current_turn || can_after_turn)
|
||||
.fetch_optional(self.pool.as_ref())
|
||||
.await?;
|
||||
|
||||
if let Some(row) = row {
|
||||
return match row.delivery.as_str() {
|
||||
DELIVERY_AFTER_TURN | DELIVERY_STEER_CURRENT_TURN => Ok(Some(
|
||||
ExternalMessageClaim::Claimed(ExternalMessage::from(row)),
|
||||
)),
|
||||
delivery => Ok(Some(ExternalMessageClaim::Invalid {
|
||||
id: row.id,
|
||||
reason: format!("invalid delivery `{delivery}`"),
|
||||
})),
|
||||
};
|
||||
}
|
||||
|
||||
let oldest_delivery = sqlx::query_scalar::<_, String>(
|
||||
r#"
|
||||
SELECT delivery
|
||||
FROM external_messages
|
||||
WHERE thread_id = ?
|
||||
ORDER BY queued_at ASC, seq ASC
|
||||
LIMIT 1
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id)
|
||||
.fetch_optional(self.pool.as_ref())
|
||||
.await?;
|
||||
|
||||
match oldest_delivery.as_deref() {
|
||||
Some(DELIVERY_AFTER_TURN) if !can_after_turn => {
|
||||
Ok(Some(ExternalMessageClaim::NotReady))
|
||||
}
|
||||
Some(DELIVERY_STEER_CURRENT_TURN) if !(can_steer_current_turn || can_after_turn) => {
|
||||
Ok(Some(ExternalMessageClaim::NotReady))
|
||||
}
|
||||
None | Some(_) => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::StateRuntime;
|
||||
use super::test_support::unique_temp_dir;
|
||||
use crate::ExternalMessageClaim;
|
||||
use crate::ExternalMessageCreateParams;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn message_params(id: &str, thread_id: &str, queued_at: i64) -> ExternalMessageCreateParams {
|
||||
ExternalMessageCreateParams {
|
||||
id: id.to_string(),
|
||||
thread_id: thread_id.to_string(),
|
||||
source: "external".to_string(),
|
||||
content: "do something".to_string(),
|
||||
instructions: Some("be concise".to_string()),
|
||||
meta_json: r#"{"ticket":"ABC_123"}"#.to_string(),
|
||||
delivery: "after-turn".to_string(),
|
||||
queued_at,
|
||||
}
|
||||
}
|
||||
|
||||
async fn test_runtime() -> std::sync::Arc<StateRuntime> {
|
||||
StateRuntime::init(unique_temp_dir(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn external_messages_table_and_indexes_exist() {
|
||||
let runtime = test_runtime().await;
|
||||
let names = sqlx::query_scalar::<_, String>(
|
||||
r#"
|
||||
SELECT name
|
||||
FROM sqlite_master
|
||||
WHERE tbl_name = 'external_messages'
|
||||
AND name NOT LIKE 'sqlite_autoindex_%'
|
||||
ORDER BY name
|
||||
"#,
|
||||
)
|
||||
.fetch_all(runtime.pool.as_ref())
|
||||
.await
|
||||
.expect("query schema objects");
|
||||
|
||||
assert_eq!(
|
||||
names,
|
||||
vec![
|
||||
"external_messages",
|
||||
"external_messages_thread_delivery_order_idx",
|
||||
"external_messages_thread_order_idx",
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn external_message_rows_round_trip() {
|
||||
let runtime = test_runtime().await;
|
||||
let params = message_params("message-1", "thread-1", /*queued_at*/ 100);
|
||||
|
||||
runtime
|
||||
.create_external_message(¶ms)
|
||||
.await
|
||||
.expect("create message");
|
||||
let messages = runtime
|
||||
.list_external_messages("thread-1")
|
||||
.await
|
||||
.expect("list messages");
|
||||
|
||||
assert_eq!(messages.len(), 1);
|
||||
let message = &messages[0];
|
||||
assert_eq!(message.id, params.id);
|
||||
assert_eq!(message.thread_id, params.thread_id);
|
||||
assert_eq!(message.source, params.source);
|
||||
assert_eq!(message.content, params.content);
|
||||
assert_eq!(message.instructions, params.instructions);
|
||||
assert_eq!(message.meta_json, params.meta_json);
|
||||
assert_eq!(message.delivery, params.delivery);
|
||||
assert_eq!(message.queued_at, params.queued_at);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delete_external_message_is_scoped_to_thread_id() {
|
||||
let runtime = test_runtime().await;
|
||||
runtime
|
||||
.create_external_message(&message_params(
|
||||
"message-1",
|
||||
"thread-1",
|
||||
/*queued_at*/ 100,
|
||||
))
|
||||
.await
|
||||
.expect("create thread-1 message");
|
||||
runtime
|
||||
.create_external_message(&message_params(
|
||||
"message-2",
|
||||
"thread-2",
|
||||
/*queued_at*/ 100,
|
||||
))
|
||||
.await
|
||||
.expect("create thread-2 message");
|
||||
|
||||
let deleted_wrong_thread = runtime
|
||||
.delete_external_message("thread-2", "message-1")
|
||||
.await
|
||||
.expect("delete wrong-external message");
|
||||
assert!(!deleted_wrong_thread);
|
||||
let deleted = runtime
|
||||
.delete_external_message("thread-1", "message-1")
|
||||
.await
|
||||
.expect("delete thread-1 message");
|
||||
assert!(deleted);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_external_messages("thread-1")
|
||||
.await
|
||||
.expect("list thread-1 messages"),
|
||||
Vec::new()
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_external_messages("thread-2")
|
||||
.await
|
||||
.expect("list thread-2 messages")
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec!["message-2".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn claim_is_scoped_to_thread_id_and_ordered() {
|
||||
let runtime = test_runtime().await;
|
||||
runtime
|
||||
.create_external_message(&message_params("newer", "thread-1", /*queued_at*/ 200))
|
||||
.await
|
||||
.expect("create newer message");
|
||||
runtime
|
||||
.create_external_message(&message_params(
|
||||
"other-thread",
|
||||
"thread-2",
|
||||
/*queued_at*/ 50,
|
||||
))
|
||||
.await
|
||||
.expect("create other external message");
|
||||
runtime
|
||||
.create_external_message(&message_params("older", "thread-1", /*queued_at*/ 100))
|
||||
.await
|
||||
.expect("create older message");
|
||||
|
||||
let claim = runtime
|
||||
.claim_next_external_message(
|
||||
"thread-1", /*can_after_turn*/ true, /*can_steer_current_turn*/ true,
|
||||
)
|
||||
.await
|
||||
.expect("claim message");
|
||||
|
||||
let Some(ExternalMessageClaim::Claimed(claimed)) = claim else {
|
||||
panic!("expected claimed message");
|
||||
};
|
||||
assert_eq!(claimed.id, "older");
|
||||
assert_eq!(claimed.thread_id, "thread-1");
|
||||
assert_eq!(claimed.queued_at, 100);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_external_messages("thread-1")
|
||||
.await
|
||||
.expect("list remaining thread-1 messages")
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec!["newer".to_string()]
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_external_messages("thread-2")
|
||||
.await
|
||||
.expect("list thread-2 messages")
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec!["other-thread".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn claim_consumes_message_once() {
|
||||
let runtime = test_runtime().await;
|
||||
runtime
|
||||
.create_external_message(&message_params(
|
||||
"message-1",
|
||||
"thread-1",
|
||||
/*queued_at*/ 100,
|
||||
))
|
||||
.await
|
||||
.expect("create message");
|
||||
|
||||
assert!(matches!(
|
||||
runtime
|
||||
.claim_next_external_message(
|
||||
"thread-1", /*can_after_turn*/ true, /*can_steer_current_turn*/ true,
|
||||
)
|
||||
.await
|
||||
.expect("claim message"),
|
||||
Some(ExternalMessageClaim::Claimed(_))
|
||||
));
|
||||
assert_eq!(
|
||||
runtime
|
||||
.claim_next_external_message(
|
||||
"thread-1", /*can_after_turn*/ true, /*can_steer_current_turn*/ true,
|
||||
)
|
||||
.await
|
||||
.expect("claim message again"),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn oldest_unclaimable_message_blocks_later_messages() {
|
||||
let runtime = test_runtime().await;
|
||||
let mut steer = message_params("steer", "thread-1", /*queued_at*/ 100);
|
||||
steer.delivery = "steer-current-turn".to_string();
|
||||
runtime
|
||||
.create_external_message(&steer)
|
||||
.await
|
||||
.expect("create steer message");
|
||||
runtime
|
||||
.create_external_message(&message_params("after", "thread-1", /*queued_at*/ 200))
|
||||
.await
|
||||
.expect("create after-turn message");
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.claim_next_external_message(
|
||||
"thread-1", /*can_after_turn*/ false,
|
||||
/*can_steer_current_turn*/ false,
|
||||
)
|
||||
.await
|
||||
.expect("claim message"),
|
||||
Some(ExternalMessageClaim::NotReady)
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_external_messages("thread-1")
|
||||
.await
|
||||
.expect("list messages")
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec!["steer".to_string(), "after".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_delivery_is_deleted_without_claiming() {
|
||||
let runtime = test_runtime().await;
|
||||
let mut params = message_params("bad", "thread-1", /*queued_at*/ 100);
|
||||
params.delivery = "bad-delivery".to_string();
|
||||
runtime
|
||||
.create_external_message(¶ms)
|
||||
.await
|
||||
.expect("create message");
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.claim_next_external_message(
|
||||
"thread-1", /*can_after_turn*/ true, /*can_steer_current_turn*/ true,
|
||||
)
|
||||
.await
|
||||
.expect("claim message"),
|
||||
Some(ExternalMessageClaim::Invalid {
|
||||
id: "bad".to_string(),
|
||||
reason: "invalid delivery `bad-delivery`".to_string(),
|
||||
})
|
||||
);
|
||||
assert!(
|
||||
runtime
|
||||
.list_external_messages("thread-1")
|
||||
.await
|
||||
.expect("list messages")
|
||||
.is_empty()
|
||||
);
|
||||
}
|
||||
}
|
||||
730
codex-rs/state/src/runtime/timers.rs
Normal file
730
codex-rs/state/src/runtime/timers.rs
Normal file
@@ -0,0 +1,730 @@
|
||||
//! SQLite-backed state operations for per-thread timers.
|
||||
//!
|
||||
//! This module extends [`StateRuntime`] with timer CRUD, due-state updates, and
|
||||
//! atomic pending-run claims. It also exposes a lightweight `PRAGMA
|
||||
//! data_version` checker so active threads can notice cross-process timer
|
||||
//! changes without constantly reconciling full timer rows.
|
||||
|
||||
use super::*;
|
||||
use crate::model::ThreadTimerRow;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub struct TimerDataVersionChecker {
|
||||
conn: Mutex<SqliteConnection>,
|
||||
}
|
||||
|
||||
impl TimerDataVersionChecker {
|
||||
pub async fn data_version(&self) -> anyhow::Result<i64> {
|
||||
let mut conn = self.conn.lock().await;
|
||||
let version = sqlx::query_scalar::<_, i64>("PRAGMA data_version")
|
||||
.fetch_one(&mut *conn)
|
||||
.await?;
|
||||
Ok(version)
|
||||
}
|
||||
}
|
||||
|
||||
impl StateRuntime {
|
||||
pub async fn timer_data_version_checker(&self) -> anyhow::Result<TimerDataVersionChecker> {
|
||||
let state_path = state_db_path(self.codex_home());
|
||||
let options = base_sqlite_options(state_path.as_path());
|
||||
let conn = options.connect().await?;
|
||||
Ok(TimerDataVersionChecker {
|
||||
conn: Mutex::new(conn),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn create_thread_timer(
|
||||
&self,
|
||||
params: &ThreadTimerCreateParams,
|
||||
) -> anyhow::Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO thread_timers (
|
||||
id,
|
||||
thread_id,
|
||||
source,
|
||||
client_id,
|
||||
trigger_json,
|
||||
content,
|
||||
instructions,
|
||||
meta_json,
|
||||
delivery,
|
||||
created_at,
|
||||
next_run_at,
|
||||
last_run_at,
|
||||
pending_run
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(params.id.as_str())
|
||||
.bind(params.thread_id.as_str())
|
||||
.bind(params.source.as_str())
|
||||
.bind(params.client_id.as_str())
|
||||
.bind(params.trigger_json.as_str())
|
||||
.bind(params.content.as_str())
|
||||
.bind(params.instructions.as_deref())
|
||||
.bind(params.meta_json.as_str())
|
||||
.bind(params.delivery.as_str())
|
||||
.bind(params.created_at)
|
||||
.bind(params.next_run_at)
|
||||
.bind(params.last_run_at)
|
||||
.bind(i64::from(params.pending_run))
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_thread_timer_if_below_limit(
|
||||
&self,
|
||||
params: &ThreadTimerCreateParams,
|
||||
max_thread_timers: usize,
|
||||
) -> anyhow::Result<bool> {
|
||||
let max_thread_timers = i64::try_from(max_thread_timers)?;
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO thread_timers (
|
||||
id,
|
||||
thread_id,
|
||||
source,
|
||||
client_id,
|
||||
trigger_json,
|
||||
content,
|
||||
instructions,
|
||||
meta_json,
|
||||
delivery,
|
||||
created_at,
|
||||
next_run_at,
|
||||
last_run_at,
|
||||
pending_run
|
||||
)
|
||||
SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||
WHERE (
|
||||
SELECT COUNT(*)
|
||||
FROM thread_timers
|
||||
WHERE thread_id = ?
|
||||
) < ?
|
||||
"#,
|
||||
)
|
||||
.bind(params.id.as_str())
|
||||
.bind(params.thread_id.as_str())
|
||||
.bind(params.source.as_str())
|
||||
.bind(params.client_id.as_str())
|
||||
.bind(params.trigger_json.as_str())
|
||||
.bind(params.content.as_str())
|
||||
.bind(params.instructions.as_deref())
|
||||
.bind(params.meta_json.as_str())
|
||||
.bind(params.delivery.as_str())
|
||||
.bind(params.created_at)
|
||||
.bind(params.next_run_at)
|
||||
.bind(params.last_run_at)
|
||||
.bind(i64::from(params.pending_run))
|
||||
.bind(params.thread_id.as_str())
|
||||
.bind(max_thread_timers)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn list_thread_timers(&self, thread_id: &str) -> anyhow::Result<Vec<ThreadTimer>> {
|
||||
let rows = sqlx::query_as::<_, ThreadTimerRow>(
|
||||
r#"
|
||||
SELECT
|
||||
id,
|
||||
thread_id,
|
||||
source,
|
||||
client_id,
|
||||
trigger_json,
|
||||
content,
|
||||
instructions,
|
||||
meta_json,
|
||||
delivery,
|
||||
created_at,
|
||||
next_run_at,
|
||||
last_run_at,
|
||||
pending_run
|
||||
FROM thread_timers
|
||||
WHERE thread_id = ?
|
||||
ORDER BY created_at ASC, id ASC
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id)
|
||||
.fetch_all(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(rows.into_iter().map(ThreadTimer::from).collect())
|
||||
}
|
||||
|
||||
pub async fn delete_thread_timer(&self, thread_id: &str, id: &str) -> anyhow::Result<bool> {
|
||||
let result = sqlx::query("DELETE FROM thread_timers WHERE thread_id = ? AND id = ?")
|
||||
.bind(thread_id)
|
||||
.bind(id)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn update_thread_timer_due(
|
||||
&self,
|
||||
thread_id: &str,
|
||||
id: &str,
|
||||
due_at: i64,
|
||||
next_run_at: Option<i64>,
|
||||
) -> anyhow::Result<bool> {
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE thread_timers
|
||||
SET pending_run = 1,
|
||||
next_run_at = ?
|
||||
WHERE thread_id = ?
|
||||
AND id = ?
|
||||
AND pending_run = 0
|
||||
AND next_run_at IS NOT NULL
|
||||
AND next_run_at <= ?
|
||||
"#,
|
||||
)
|
||||
.bind(next_run_at)
|
||||
.bind(thread_id)
|
||||
.bind(id)
|
||||
.bind(due_at)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn claim_one_shot_thread_timer(
|
||||
&self,
|
||||
thread_id: &str,
|
||||
id: &str,
|
||||
due_at: i64,
|
||||
) -> anyhow::Result<bool> {
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
DELETE FROM thread_timers
|
||||
WHERE thread_id = ?
|
||||
AND id = ?
|
||||
AND (
|
||||
pending_run = 1
|
||||
OR (
|
||||
pending_run = 0
|
||||
AND next_run_at IS NOT NULL
|
||||
AND next_run_at <= ?
|
||||
)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id)
|
||||
.bind(id)
|
||||
.bind(due_at)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn claim_recurring_thread_timer(
|
||||
&self,
|
||||
thread_id: &str,
|
||||
id: &str,
|
||||
due_at: i64,
|
||||
expected_last_run_at: Option<i64>,
|
||||
params: &ThreadTimerUpdateParams,
|
||||
) -> anyhow::Result<bool> {
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE thread_timers
|
||||
SET trigger_json = ?,
|
||||
content = ?,
|
||||
instructions = ?,
|
||||
meta_json = ?,
|
||||
delivery = ?,
|
||||
next_run_at = ?,
|
||||
last_run_at = ?,
|
||||
pending_run = ?
|
||||
WHERE thread_id = ?
|
||||
AND id = ?
|
||||
AND (
|
||||
pending_run = 1
|
||||
OR (
|
||||
pending_run = 0
|
||||
AND next_run_at IS NOT NULL
|
||||
AND next_run_at <= ?
|
||||
)
|
||||
)
|
||||
AND (
|
||||
(last_run_at IS NULL AND ? IS NULL)
|
||||
OR last_run_at = ?
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(params.trigger_json.as_str())
|
||||
.bind(params.content.as_str())
|
||||
.bind(params.instructions.as_deref())
|
||||
.bind(params.meta_json.as_str())
|
||||
.bind(params.delivery.as_str())
|
||||
.bind(params.next_run_at)
|
||||
.bind(params.last_run_at)
|
||||
.bind(i64::from(params.pending_run))
|
||||
.bind(thread_id)
|
||||
.bind(id)
|
||||
.bind(due_at)
|
||||
.bind(expected_last_run_at)
|
||||
.bind(expected_last_run_at)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::StateRuntime;
|
||||
use super::test_support::unique_temp_dir;
|
||||
use crate::ThreadTimer;
|
||||
use crate::ThreadTimerCreateParams;
|
||||
use crate::ThreadTimerUpdateParams;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn timer_params(id: &str, thread_id: &str) -> ThreadTimerCreateParams {
|
||||
ThreadTimerCreateParams {
|
||||
id: id.to_string(),
|
||||
thread_id: thread_id.to_string(),
|
||||
source: "agent".to_string(),
|
||||
client_id: "codex-tui".to_string(),
|
||||
trigger_json: r#"{"kind":"delay","seconds":10,"repeat":true}"#.to_string(),
|
||||
content: "run tests".to_string(),
|
||||
instructions: Some("keep output brief".to_string()),
|
||||
meta_json: r#"{"ticket":"ABC_123"}"#.to_string(),
|
||||
delivery: "after-turn".to_string(),
|
||||
created_at: 100,
|
||||
next_run_at: Some(110),
|
||||
last_run_at: None,
|
||||
pending_run: false,
|
||||
}
|
||||
}
|
||||
|
||||
async fn test_runtime() -> std::sync::Arc<StateRuntime> {
|
||||
StateRuntime::init(unique_temp_dir(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_timers_table_and_indexes_exist() {
|
||||
let runtime = test_runtime().await;
|
||||
let names = sqlx::query_scalar::<_, String>(
|
||||
r#"
|
||||
SELECT name
|
||||
FROM sqlite_master
|
||||
WHERE tbl_name = 'thread_timers'
|
||||
AND name NOT LIKE 'sqlite_autoindex_%'
|
||||
ORDER BY name
|
||||
"#,
|
||||
)
|
||||
.fetch_all(runtime.pool.as_ref())
|
||||
.await
|
||||
.expect("query schema objects");
|
||||
|
||||
assert_eq!(
|
||||
names,
|
||||
vec![
|
||||
"idx_thread_timers_thread_created",
|
||||
"idx_thread_timers_thread_next_run",
|
||||
"idx_thread_timers_thread_pending",
|
||||
"thread_timers",
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_timer_rows_round_trip_source_and_client_metadata() {
|
||||
let runtime = test_runtime().await;
|
||||
let mut params = timer_params("timer-1", "thread-1");
|
||||
params.pending_run = true;
|
||||
params.last_run_at = Some(105);
|
||||
|
||||
runtime
|
||||
.create_thread_timer(¶ms)
|
||||
.await
|
||||
.expect("create timer");
|
||||
let timers = runtime
|
||||
.list_thread_timers("thread-1")
|
||||
.await
|
||||
.expect("list timers");
|
||||
|
||||
assert_eq!(timers.len(), 1);
|
||||
let timer = &timers[0];
|
||||
assert_eq!(timer.id, params.id);
|
||||
assert_eq!(timer.thread_id, params.thread_id);
|
||||
assert_eq!(timer.source, params.source);
|
||||
assert_eq!(timer.client_id, params.client_id);
|
||||
assert_eq!(timer.trigger_json, params.trigger_json);
|
||||
assert_eq!(timer.content, params.content);
|
||||
assert_eq!(timer.instructions, params.instructions);
|
||||
assert_eq!(timer.meta_json, params.meta_json);
|
||||
assert_eq!(timer.delivery, params.delivery);
|
||||
assert_eq!(timer.created_at, params.created_at);
|
||||
assert_eq!(timer.next_run_at, params.next_run_at);
|
||||
assert_eq!(timer.last_run_at, params.last_run_at);
|
||||
assert_eq!(timer.pending_run, params.pending_run);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_timer_crud_is_scoped_to_thread_id() {
|
||||
let runtime = test_runtime().await;
|
||||
runtime
|
||||
.create_thread_timer(&timer_params("timer-1", "thread-1"))
|
||||
.await
|
||||
.expect("create thread-1 timer");
|
||||
runtime
|
||||
.create_thread_timer(&timer_params("timer-2", "thread-2"))
|
||||
.await
|
||||
.expect("create thread-2 timer");
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_thread_timers("thread-1")
|
||||
.await
|
||||
.expect("list thread-1 timers")
|
||||
.into_iter()
|
||||
.map(|timer| timer.id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec!["timer-1".to_string()]
|
||||
);
|
||||
assert!(
|
||||
!runtime
|
||||
.delete_thread_timer("thread-1", "timer-2")
|
||||
.await
|
||||
.expect("delete wrong thread timer")
|
||||
);
|
||||
assert!(
|
||||
runtime
|
||||
.delete_thread_timer("thread-2", "timer-2")
|
||||
.await
|
||||
.expect("delete correct thread timer")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_thread_timer_if_below_limit_rejects_full_thread() {
|
||||
let runtime = test_runtime().await;
|
||||
assert!(
|
||||
runtime
|
||||
.create_thread_timer_if_below_limit(
|
||||
&timer_params("timer-1", "thread-1"),
|
||||
/*max_thread_timers*/ 2,
|
||||
)
|
||||
.await
|
||||
.expect("create first timer")
|
||||
);
|
||||
assert!(
|
||||
runtime
|
||||
.create_thread_timer_if_below_limit(
|
||||
&timer_params("timer-2", "thread-1"),
|
||||
/*max_thread_timers*/ 2,
|
||||
)
|
||||
.await
|
||||
.expect("create second timer")
|
||||
);
|
||||
assert!(
|
||||
!runtime
|
||||
.create_thread_timer_if_below_limit(
|
||||
&timer_params("timer-3", "thread-1"),
|
||||
/*max_thread_timers*/ 2,
|
||||
)
|
||||
.await
|
||||
.expect("reject third timer")
|
||||
);
|
||||
assert!(
|
||||
runtime
|
||||
.create_thread_timer_if_below_limit(
|
||||
&timer_params("timer-4", "thread-2"),
|
||||
/*max_thread_timers*/ 2,
|
||||
)
|
||||
.await
|
||||
.expect("create timer for different thread")
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_thread_timers("thread-1")
|
||||
.await
|
||||
.expect("list thread-1 timers")
|
||||
.into_iter()
|
||||
.map(|timer| timer.id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec!["timer-1".to_string(), "timer-2".to_string()]
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_thread_timers("thread-2")
|
||||
.await
|
||||
.expect("list thread-2 timers")
|
||||
.into_iter()
|
||||
.map(|timer| timer.id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec!["timer-4".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn one_shot_claim_consumes_pending_timer_once() {
|
||||
let runtime = test_runtime().await;
|
||||
let mut params = timer_params("timer-1", "thread-1");
|
||||
params.pending_run = true;
|
||||
params.next_run_at = None;
|
||||
runtime
|
||||
.create_thread_timer(¶ms)
|
||||
.await
|
||||
.expect("create pending timer");
|
||||
|
||||
assert!(
|
||||
runtime
|
||||
.claim_one_shot_thread_timer("thread-1", "timer-1", /*due_at*/ 110)
|
||||
.await
|
||||
.expect("claim timer")
|
||||
);
|
||||
assert!(
|
||||
!runtime
|
||||
.claim_one_shot_thread_timer("thread-1", "timer-1", /*due_at*/ 110)
|
||||
.await
|
||||
.expect("claim timer again")
|
||||
);
|
||||
assert!(
|
||||
runtime
|
||||
.list_thread_timers("thread-1")
|
||||
.await
|
||||
.expect("list timers")
|
||||
.is_empty()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recurring_claim_updates_pending_timer_once() {
|
||||
let runtime = test_runtime().await;
|
||||
let mut params = timer_params("timer-1", "thread-1");
|
||||
params.pending_run = true;
|
||||
runtime
|
||||
.create_thread_timer(¶ms)
|
||||
.await
|
||||
.expect("create pending timer");
|
||||
let update = ThreadTimerUpdateParams {
|
||||
trigger_json: params.trigger_json.clone(),
|
||||
content: "updated content".to_string(),
|
||||
instructions: None,
|
||||
meta_json: "{}".to_string(),
|
||||
delivery: "steer-current-turn".to_string(),
|
||||
next_run_at: Some(120),
|
||||
last_run_at: Some(110),
|
||||
pending_run: false,
|
||||
};
|
||||
|
||||
assert!(
|
||||
runtime
|
||||
.claim_recurring_thread_timer(
|
||||
"thread-1", "timer-1", /*due_at*/ 110, /*expected_last_run_at*/ None,
|
||||
&update,
|
||||
)
|
||||
.await
|
||||
.expect("claim recurring timer")
|
||||
);
|
||||
assert!(
|
||||
!runtime
|
||||
.claim_recurring_thread_timer(
|
||||
"thread-1", "timer-1", /*due_at*/ 110, /*expected_last_run_at*/ None,
|
||||
&update,
|
||||
)
|
||||
.await
|
||||
.expect("claim recurring timer again")
|
||||
);
|
||||
let timers = runtime
|
||||
.list_thread_timers("thread-1")
|
||||
.await
|
||||
.expect("list timers");
|
||||
assert_eq!(timers.len(), 1);
|
||||
assert_eq!(timers[0].delivery, "steer-current-turn");
|
||||
assert_eq!(timers[0].content, "updated content");
|
||||
assert_eq!(timers[0].instructions, None);
|
||||
assert_eq!(timers[0].meta_json, "{}");
|
||||
assert_eq!(timers[0].next_run_at, Some(120));
|
||||
assert_eq!(timers[0].last_run_at, Some(110));
|
||||
assert!(!timers[0].pending_run);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn one_shot_claim_consumes_overdue_timer_after_restart() {
|
||||
let runtime = test_runtime().await;
|
||||
let mut params = timer_params("timer-1", "thread-1");
|
||||
params.trigger_json = r#"{"kind":"delay","seconds":10,"repeat":false}"#.to_string();
|
||||
params.next_run_at = Some(110);
|
||||
params.pending_run = false;
|
||||
runtime
|
||||
.create_thread_timer(¶ms)
|
||||
.await
|
||||
.expect("create overdue one-shot timer");
|
||||
|
||||
assert!(
|
||||
runtime
|
||||
.claim_one_shot_thread_timer("thread-1", "timer-1", /*due_at*/ 110)
|
||||
.await
|
||||
.expect("claim overdue one-shot timer")
|
||||
);
|
||||
assert!(
|
||||
runtime
|
||||
.list_thread_timers("thread-1")
|
||||
.await
|
||||
.expect("list timers")
|
||||
.is_empty()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recurring_claim_consumes_overdue_timer_after_restart() {
|
||||
let runtime = test_runtime().await;
|
||||
let mut params = timer_params("timer-1", "thread-1");
|
||||
params.next_run_at = Some(110);
|
||||
params.pending_run = false;
|
||||
runtime
|
||||
.create_thread_timer(¶ms)
|
||||
.await
|
||||
.expect("create overdue recurring timer");
|
||||
let update = ThreadTimerUpdateParams {
|
||||
trigger_json: params.trigger_json.clone(),
|
||||
content: params.content.clone(),
|
||||
instructions: params.instructions.clone(),
|
||||
meta_json: params.meta_json.clone(),
|
||||
delivery: params.delivery.clone(),
|
||||
next_run_at: Some(120),
|
||||
last_run_at: Some(110),
|
||||
pending_run: false,
|
||||
};
|
||||
|
||||
assert!(
|
||||
runtime
|
||||
.claim_recurring_thread_timer(
|
||||
"thread-1", "timer-1", /*due_at*/ 110, /*expected_last_run_at*/ None,
|
||||
&update,
|
||||
)
|
||||
.await
|
||||
.expect("claim overdue recurring timer")
|
||||
);
|
||||
let timers = runtime
|
||||
.list_thread_timers("thread-1")
|
||||
.await
|
||||
.expect("list timers");
|
||||
assert_eq!(timers.len(), 1);
|
||||
assert_eq!(timers[0].next_run_at, Some(120));
|
||||
assert_eq!(timers[0].last_run_at, Some(110));
|
||||
assert!(!timers[0].pending_run);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn due_update_rejects_stale_timer_row_after_claim() {
|
||||
let runtime = test_runtime().await;
|
||||
let mut params = timer_params("timer-1", "thread-1");
|
||||
params.next_run_at = Some(110);
|
||||
params.pending_run = false;
|
||||
runtime
|
||||
.create_thread_timer(¶ms)
|
||||
.await
|
||||
.expect("create overdue recurring timer");
|
||||
let update = ThreadTimerUpdateParams {
|
||||
trigger_json: params.trigger_json.clone(),
|
||||
content: params.content.clone(),
|
||||
instructions: params.instructions.clone(),
|
||||
meta_json: params.meta_json.clone(),
|
||||
delivery: params.delivery.clone(),
|
||||
next_run_at: Some(120),
|
||||
last_run_at: Some(110),
|
||||
pending_run: false,
|
||||
};
|
||||
assert!(
|
||||
runtime
|
||||
.claim_recurring_thread_timer(
|
||||
"thread-1", "timer-1", /*due_at*/ 110, /*expected_last_run_at*/ None,
|
||||
&update,
|
||||
)
|
||||
.await
|
||||
.expect("claim overdue recurring timer")
|
||||
);
|
||||
|
||||
assert!(
|
||||
!runtime
|
||||
.update_thread_timer_due("thread-1", "timer-1", /*due_at*/ 110, Some(130))
|
||||
.await
|
||||
.expect("stale due update should be rejected")
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.list_thread_timers("thread-1")
|
||||
.await
|
||||
.expect("list timers"),
|
||||
vec![ThreadTimer {
|
||||
id: params.id,
|
||||
thread_id: params.thread_id,
|
||||
source: params.source,
|
||||
client_id: params.client_id,
|
||||
trigger_json: params.trigger_json,
|
||||
content: params.content,
|
||||
instructions: params.instructions,
|
||||
meta_json: params.meta_json,
|
||||
delivery: params.delivery,
|
||||
created_at: params.created_at,
|
||||
next_run_at: Some(120),
|
||||
last_run_at: Some(110),
|
||||
pending_run: false,
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recurring_idle_claim_rejects_stale_last_run_at_even_when_pending_stays_true() {
|
||||
let runtime = test_runtime().await;
|
||||
let mut params = timer_params("timer-1", "thread-1");
|
||||
params.pending_run = true;
|
||||
params.last_run_at = Some(100);
|
||||
runtime
|
||||
.create_thread_timer(¶ms)
|
||||
.await
|
||||
.expect("create pending timer");
|
||||
let update = ThreadTimerUpdateParams {
|
||||
trigger_json: params.trigger_json.clone(),
|
||||
content: params.content.clone(),
|
||||
instructions: params.instructions.clone(),
|
||||
meta_json: params.meta_json.clone(),
|
||||
delivery: params.delivery.clone(),
|
||||
next_run_at: Some(120),
|
||||
last_run_at: Some(110),
|
||||
pending_run: true,
|
||||
};
|
||||
|
||||
assert!(
|
||||
runtime
|
||||
.claim_recurring_thread_timer(
|
||||
"thread-1",
|
||||
"timer-1",
|
||||
/*due_at*/ 110,
|
||||
/*expected_last_run_at*/ Some(100),
|
||||
&update,
|
||||
)
|
||||
.await
|
||||
.expect("claim recurring idle timer")
|
||||
);
|
||||
assert!(
|
||||
!runtime
|
||||
.claim_recurring_thread_timer(
|
||||
"thread-1",
|
||||
"timer-1",
|
||||
/*due_at*/ 110,
|
||||
/*expected_last_run_at*/ Some(100),
|
||||
&update,
|
||||
)
|
||||
.await
|
||||
.expect("claim recurring idle timer again")
|
||||
);
|
||||
let timers = runtime
|
||||
.list_thread_timers("thread-1")
|
||||
.await
|
||||
.expect("list timers");
|
||||
assert_eq!(timers.len(), 1);
|
||||
assert_eq!(timers[0].last_run_at, Some(110));
|
||||
assert!(timers[0].pending_run);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user