Files
codex/codex-rs/state/src/runtime.rs
2026-02-02 20:57:17 +00:00

840 lines
25 KiB
Rust

use crate::DB_ERROR_METRIC;
use crate::LogEntry;
use crate::LogQuery;
use crate::LogRow;
use crate::SortKey;
use crate::ThreadMetadata;
use crate::ThreadMetadataBuilder;
use crate::ThreadsPage;
use crate::apply_rollout_item;
use crate::migrations::MIGRATOR;
use crate::model::MemorySummary;
use crate::model::MemorySummaryLock;
use crate::model::ThreadRow;
use crate::model::anchor_from_item;
use crate::model::datetime_to_epoch_seconds;
use crate::paths::file_modified_time_utc;
use chrono::DateTime;
use chrono::Utc;
use codex_otel::OtelManager;
use codex_protocol::ThreadId;
use codex_protocol::protocol::RolloutItem;
use log::LevelFilter;
use sqlx::ConnectOptions;
use sqlx::QueryBuilder;
use sqlx::Row;
use sqlx::Sqlite;
use sqlx::SqlitePool;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::sqlite::SqliteJournalMode;
use sqlx::sqlite::SqlitePoolOptions;
use sqlx::sqlite::SqliteSynchronous;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tracing::warn;
/// File name of the SQLite state database stored directly under the Codex home directory.
pub const STATE_DB_FILENAME: &str = "state.sqlite";
/// Counter metric emitted while opening/initializing the state database.
const METRIC_DB_INIT: &str = "codex.db.init";
/// Handle to the on-disk Codex state database (threads, logs, memory summaries).
///
/// Cloning is cheap: the connection pool is shared behind an `Arc`.
#[derive(Clone)]
pub struct StateRuntime {
    // Root directory that contains the `state.sqlite` database file.
    codex_home: PathBuf,
    // Provider name handed to metadata builders when applying rollout items.
    default_provider: String,
    // Shared SQLite connection pool used by every query in this runtime.
    pool: Arc<sqlx::SqlitePool>,
}
impl StateRuntime {
    /// Initialize the state runtime using the provided Codex home and default provider.
    ///
    /// This opens (and migrates) the SQLite database at `codex_home/state.sqlite`.
    pub async fn init(
        codex_home: PathBuf,
        default_provider: String,
        otel: Option<OtelManager>,
    ) -> anyhow::Result<Arc<Self>> {
        tokio::fs::create_dir_all(&codex_home).await?;
        let state_path = codex_home.join(STATE_DB_FILENAME);
        // Remember whether the database file pre-existed so a "created" metric
        // can be emitted for first-time initialization below.
        let existed = tokio::fs::try_exists(&state_path).await.unwrap_or(false);
        let pool = match open_sqlite(&state_path).await {
            Ok(db) => Arc::new(db),
            Err(err) => {
                warn!("failed to open state db at {}: {err}", state_path.display());
                if let Some(otel) = otel.as_ref() {
                    otel.counter(METRIC_DB_INIT, 1, &[("status", "open_error")]);
                }
                return Err(err);
            }
        };
        if let Some(otel) = otel.as_ref() {
            otel.counter(METRIC_DB_INIT, 1, &[("status", "opened")]);
        }
        let runtime = Arc::new(Self {
            pool,
            codex_home,
            default_provider,
        });
        if !existed && let Some(otel) = otel.as_ref() {
            otel.counter(METRIC_DB_INIT, 1, &[("status", "created")]);
        }
        Ok(runtime)
    }
    /// Return the configured Codex home directory for this runtime.
    pub fn codex_home(&self) -> &Path {
        self.codex_home.as_path()
    }
    /// Load thread metadata by id using the underlying database.
    ///
    /// Returns `Ok(None)` when no row exists; a row that fails to decode
    /// surfaces as an error rather than being silently dropped.
    pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result<Option<crate::ThreadMetadata>> {
        let row = sqlx::query(
            r#"
            SELECT
                id,
                rollout_path,
                created_at,
                updated_at,
                source,
                model_provider,
                cwd,
                title,
                sandbox_policy,
                approval_mode,
                tokens_used,
                has_user_event,
                archived_at,
                git_sha,
                git_branch,
                git_origin_url
            FROM threads
            WHERE id = ?
            "#,
        )
        .bind(id.to_string())
        .fetch_optional(self.pool.as_ref())
        .await?;
        row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
            .transpose()
    }
    /// Find a rollout path by thread id using the underlying database.
    ///
    /// `archived_only` narrows the lookup: `Some(true)` matches only archived
    /// threads, `Some(false)` only unarchived ones, `None` matches either.
    pub async fn find_rollout_path_by_id(
        &self,
        id: ThreadId,
        archived_only: Option<bool>,
    ) -> anyhow::Result<Option<PathBuf>> {
        let mut builder =
            QueryBuilder::<Sqlite>::new("SELECT rollout_path FROM threads WHERE id = ");
        builder.push_bind(id.to_string());
        match archived_only {
            Some(true) => {
                builder.push(" AND archived = 1");
            }
            Some(false) => {
                builder.push(" AND archived = 0");
            }
            None => {}
        }
        let row = builder.build().fetch_optional(self.pool.as_ref()).await?;
        // NOTE: a decode failure on `rollout_path` is collapsed to `None`
        // here (via `.ok()`) rather than reported as an error.
        Ok(row
            .and_then(|r| r.try_get::<String, _>("rollout_path").ok())
            .map(PathBuf::from))
    }
    /// List threads using the underlying database.
    ///
    /// Uses keyset pagination: fetches `page_size + 1` rows and, when the
    /// extra row is present, pops it and derives `next_anchor` from the last
    /// item actually returned.
    pub async fn list_threads(
        &self,
        page_size: usize,
        anchor: Option<&crate::Anchor>,
        sort_key: crate::SortKey,
        allowed_sources: &[String],
        model_providers: Option<&[String]>,
        archived_only: bool,
    ) -> anyhow::Result<crate::ThreadsPage> {
        // Fetch one extra row to detect whether another page exists.
        let limit = page_size.saturating_add(1);
        let mut builder = QueryBuilder::<Sqlite>::new(
            r#"
            SELECT
                id,
                rollout_path,
                created_at,
                updated_at,
                source,
                model_provider,
                cwd,
                title,
                sandbox_policy,
                approval_mode,
                tokens_used,
                has_user_event,
                archived_at,
                git_sha,
                git_branch,
                git_origin_url
            FROM threads
            "#,
        );
        push_thread_filters(
            &mut builder,
            archived_only,
            allowed_sources,
            model_providers,
            anchor,
            sort_key,
        );
        push_thread_order_and_limit(&mut builder, sort_key, limit);
        let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
        let mut items = rows
            .into_iter()
            .map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
            .collect::<Result<Vec<_>, _>>()?;
        // Counted before popping the sentinel row, so this includes the
        // extra look-ahead row when one was fetched.
        let num_scanned_rows = items.len();
        let next_anchor = if items.len() > page_size {
            items.pop();
            items
                .last()
                .and_then(|item| anchor_from_item(item, sort_key))
        } else {
            None
        };
        Ok(ThreadsPage {
            items,
            next_anchor,
            num_scanned_rows,
        })
    }
    /// Insert one log entry into the logs table.
    pub async fn insert_log(&self, entry: &LogEntry) -> anyhow::Result<()> {
        self.insert_logs(std::slice::from_ref(entry)).await
    }
    /// Insert a batch of log entries into the logs table.
    ///
    /// All entries are written in a single multi-row `INSERT` statement.
    pub async fn insert_logs(&self, entries: &[LogEntry]) -> anyhow::Result<()> {
        if entries.is_empty() {
            return Ok(());
        }
        let mut builder = QueryBuilder::<Sqlite>::new(
            "INSERT INTO logs (ts, ts_nanos, level, target, message, thread_id, module_path, file, line) ",
        );
        builder.push_values(entries, |mut row, entry| {
            row.push_bind(entry.ts)
                .push_bind(entry.ts_nanos)
                .push_bind(&entry.level)
                .push_bind(&entry.target)
                .push_bind(&entry.message)
                .push_bind(&entry.thread_id)
                .push_bind(&entry.module_path)
                .push_bind(&entry.file)
                .push_bind(entry.line);
        });
        builder.build().execute(self.pool.as_ref()).await?;
        Ok(())
    }
    /// Delete log rows with `ts` strictly before `cutoff_ts`; returns the
    /// number of rows removed.
    pub(crate) async fn delete_logs_before(&self, cutoff_ts: i64) -> anyhow::Result<u64> {
        let result = sqlx::query("DELETE FROM logs WHERE ts < ?")
            .bind(cutoff_ts)
            .execute(self.pool.as_ref())
            .await?;
        Ok(result.rows_affected())
    }
    /// Query logs with optional filters.
    ///
    /// Results are ordered by `id` (descending when `query.descending`), and
    /// truncated to `query.limit` when one is set.
    pub async fn query_logs(&self, query: &LogQuery) -> anyhow::Result<Vec<LogRow>> {
        let mut builder = QueryBuilder::<Sqlite>::new(
            "SELECT id, ts, ts_nanos, level, target, message, thread_id, file, line FROM logs WHERE 1 = 1",
        );
        push_log_filters(&mut builder, query);
        if query.descending {
            builder.push(" ORDER BY id DESC");
        } else {
            builder.push(" ORDER BY id ASC");
        }
        if let Some(limit) = query.limit {
            builder.push(" LIMIT ").push_bind(limit as i64);
        }
        let rows = builder
            .build_query_as::<LogRow>()
            .fetch_all(self.pool.as_ref())
            .await?;
        Ok(rows)
    }
    /// Return the max log id matching optional filters.
    ///
    /// Returns 0 when no rows match (SQL `MAX` yields NULL on an empty set).
    pub async fn max_log_id(&self, query: &LogQuery) -> anyhow::Result<i64> {
        let mut builder =
            QueryBuilder::<Sqlite>::new("SELECT MAX(id) AS max_id FROM logs WHERE 1 = 1");
        push_log_filters(&mut builder, query);
        let row = builder.build().fetch_one(self.pool.as_ref()).await?;
        let max_id: Option<i64> = row.try_get("max_id")?;
        Ok(max_id.unwrap_or(0))
    }
    /// Upsert a memory summary for a thread.
    ///
    /// On conflict the existing row keeps its original `created_at`; only
    /// `cwd`, `summary`, and `updated_at` are overwritten.
    pub async fn upsert_memory_summary(&self, summary: &MemorySummary) -> anyhow::Result<()> {
        sqlx::query(
            r#"
            INSERT INTO memory_summaries (
                thread_id,
                cwd,
                summary,
                created_at,
                updated_at
            ) VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(thread_id) DO UPDATE SET
                cwd = excluded.cwd,
                summary = excluded.summary,
                updated_at = excluded.updated_at
            "#,
        )
        .bind(summary.thread_id.to_string())
        .bind(summary.cwd.to_string_lossy().as_ref())
        .bind(&summary.summary)
        .bind(summary.created_at)
        .bind(summary.updated_at)
        .execute(self.pool.as_ref())
        .await?;
        Ok(())
    }
    /// Fetch the most recent memory summaries for a given cwd.
    ///
    /// Ordered by `updated_at` descending, with `thread_id` descending as a
    /// deterministic tie-breaker.
    pub async fn list_memory_summaries_for_cwd(
        &self,
        cwd: &Path,
        limit: usize,
    ) -> anyhow::Result<Vec<MemorySummary>> {
        let rows = sqlx::query(
            r#"
            SELECT thread_id, cwd, summary, created_at, updated_at
            FROM memory_summaries
            WHERE cwd = ?
            ORDER BY updated_at DESC, thread_id DESC
            LIMIT ?
            "#,
        )
        .bind(cwd.to_string_lossy().as_ref())
        .bind(limit as i64)
        .fetch_all(self.pool.as_ref())
        .await?;
        rows.into_iter()
            .map(|row| {
                let thread_id: String = row.try_get("thread_id")?;
                Ok(MemorySummary {
                    thread_id: ThreadId::try_from(thread_id)?,
                    cwd: PathBuf::from(row.try_get::<String, _>("cwd")?),
                    summary: row.try_get("summary")?,
                    created_at: row.try_get("created_at")?,
                    updated_at: row.try_get("updated_at")?,
                })
            })
            .collect()
    }
    /// Fetch memory summaries for a set of thread ids.
    ///
    /// Returns an empty vector without touching the database when
    /// `thread_ids` is empty.
    pub async fn list_memory_summaries_for_threads(
        &self,
        thread_ids: &[ThreadId],
    ) -> anyhow::Result<Vec<MemorySummary>> {
        if thread_ids.is_empty() {
            return Ok(Vec::new());
        }
        let mut builder = QueryBuilder::<Sqlite>::new(
            "SELECT thread_id, cwd, summary, created_at, updated_at FROM memory_summaries WHERE thread_id IN (",
        );
        let mut separated = builder.separated(", ");
        for thread_id in thread_ids {
            separated.push_bind(thread_id.to_string());
        }
        separated.push_unseparated(")");
        let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
        rows.into_iter()
            .map(|row| {
                let thread_id: String = row.try_get("thread_id")?;
                Ok(MemorySummary {
                    thread_id: ThreadId::try_from(thread_id)?,
                    cwd: PathBuf::from(row.try_get::<String, _>("cwd")?),
                    summary: row.try_get("summary")?,
                    created_at: row.try_get("created_at")?,
                    updated_at: row.try_get("updated_at")?,
                })
            })
            .collect()
    }
    /// Acquire a per-cwd summary lock if it is free or expired.
    ///
    /// Returns `Ok(true)` when the lock was acquired. The protocol runs in a
    /// transaction: first try to take over an expired lock row via UPDATE;
    /// if no row was updated, try to INSERT a fresh one. If both fail, a
    /// live lock is held by someone else and the transaction is rolled back.
    pub async fn try_acquire_memory_summary_lock(
        &self,
        cwd: &Path,
        owner_id: &str,
        now_ts: i64,
        ttl_secs: i64,
    ) -> anyhow::Result<bool> {
        let expires_at = now_ts.saturating_add(ttl_secs);
        let mut txn = self.pool.begin().await?;
        // Step 1: steal the lock row if its expiry has passed.
        let updated = sqlx::query(
            r#"
            UPDATE memory_summary_locks
            SET owner_id = ?, acquired_at = ?, expires_at = ?
            WHERE cwd = ? AND expires_at <= ?
            "#,
        )
        .bind(owner_id)
        .bind(now_ts)
        .bind(expires_at)
        .bind(cwd.to_string_lossy().as_ref())
        .bind(now_ts)
        .execute(&mut *txn)
        .await?
        .rows_affected();
        if updated == 0 {
            // Step 2: no expired row to take over — try to create one.
            // `OR IGNORE` makes this a no-op if a (live) row already exists.
            let inserted = sqlx::query(
                r#"
                INSERT OR IGNORE INTO memory_summary_locks (cwd, owner_id, acquired_at, expires_at)
                VALUES (?, ?, ?, ?)
                "#,
            )
            .bind(cwd.to_string_lossy().as_ref())
            .bind(owner_id)
            .bind(now_ts)
            .bind(expires_at)
            .execute(&mut *txn)
            .await?
            .rows_affected();
            if inserted == 0 {
                txn.rollback().await?;
                return Ok(false);
            }
        }
        txn.commit().await?;
        Ok(true)
    }
    /// Release a per-cwd summary lock if owned by the caller.
    ///
    /// A lock held by a different owner is left untouched; this is not an
    /// error.
    pub async fn release_memory_summary_lock(
        &self,
        cwd: &Path,
        owner_id: &str,
    ) -> anyhow::Result<()> {
        sqlx::query(
            r#"
            DELETE FROM memory_summary_locks
            WHERE cwd = ? AND owner_id = ?
            "#,
        )
        .bind(cwd.to_string_lossy().as_ref())
        .bind(owner_id)
        .execute(self.pool.as_ref())
        .await?;
        Ok(())
    }
    /// Fetch the current memory summary lock, if any.
    pub async fn get_memory_summary_lock(
        &self,
        cwd: &Path,
    ) -> anyhow::Result<Option<MemorySummaryLock>> {
        let row = sqlx::query(
            r#"
            SELECT cwd, owner_id, acquired_at, expires_at
            FROM memory_summary_locks
            WHERE cwd = ?
            "#,
        )
        .bind(cwd.to_string_lossy().as_ref())
        .fetch_optional(self.pool.as_ref())
        .await?;
        let Some(row) = row else {
            return Ok(None);
        };
        Ok(Some(MemorySummaryLock {
            cwd: PathBuf::from(row.try_get::<String, _>("cwd")?),
            owner_id: row.try_get("owner_id")?,
            acquired_at: row.try_get("acquired_at")?,
            expires_at: row.try_get("expires_at")?,
        }))
    }
    /// List thread ids using the underlying database (no rollout scanning).
    ///
    /// Same filtering/ordering as `list_threads`, but selects only the id
    /// column and does not implement look-ahead pagination.
    pub async fn list_thread_ids(
        &self,
        limit: usize,
        anchor: Option<&crate::Anchor>,
        sort_key: crate::SortKey,
        allowed_sources: &[String],
        model_providers: Option<&[String]>,
        archived_only: bool,
    ) -> anyhow::Result<Vec<ThreadId>> {
        let mut builder = QueryBuilder::<Sqlite>::new("SELECT id FROM threads");
        push_thread_filters(
            &mut builder,
            archived_only,
            allowed_sources,
            model_providers,
            anchor,
            sort_key,
        );
        push_thread_order_and_limit(&mut builder, sort_key, limit);
        let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
        rows.into_iter()
            .map(|row| {
                let id: String = row.try_get("id")?;
                Ok(ThreadId::try_from(id)?)
            })
            .collect()
    }
    /// List recent threads for a cwd using the underlying database (no rollout scanning).
    pub async fn list_recent_threads_for_cwd(
        &self,
        cwd: &Path,
        limit: usize,
        sort_key: SortKey,
        allowed_sources: &[String],
        model_providers: Option<&[String]>,
        archived_only: bool,
    ) -> anyhow::Result<Vec<ThreadMetadata>> {
        let mut builder = QueryBuilder::<Sqlite>::new(
            r#"
            SELECT
                id,
                rollout_path,
                created_at,
                updated_at,
                source,
                model_provider,
                cwd,
                title,
                sandbox_policy,
                approval_mode,
                tokens_used,
                has_user_event,
                archived_at,
                git_sha,
                git_branch,
                git_origin_url
            FROM threads
            "#,
        );
        push_thread_filters(
            &mut builder,
            archived_only,
            allowed_sources,
            model_providers,
            None,
            sort_key,
        );
        // Safe to append `AND` here: push_thread_filters always emits a
        // `WHERE 1 = 1` prefix.
        let cwd_value = cwd.to_string_lossy();
        builder.push(" AND cwd = ");
        builder.push_bind(cwd_value.as_ref());
        push_thread_order_and_limit(&mut builder, sort_key, limit);
        let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
        rows.into_iter()
            .map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
            .collect()
    }
    /// Insert or replace thread metadata directly.
    ///
    /// The `archived` flag column is derived from `archived_at.is_some()`.
    pub async fn upsert_thread(&self, metadata: &crate::ThreadMetadata) -> anyhow::Result<()> {
        sqlx::query(
            r#"
            INSERT INTO threads (
                id,
                rollout_path,
                created_at,
                updated_at,
                source,
                model_provider,
                cwd,
                title,
                sandbox_policy,
                approval_mode,
                tokens_used,
                has_user_event,
                archived,
                archived_at,
                git_sha,
                git_branch,
                git_origin_url
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(id) DO UPDATE SET
                rollout_path = excluded.rollout_path,
                created_at = excluded.created_at,
                updated_at = excluded.updated_at,
                source = excluded.source,
                model_provider = excluded.model_provider,
                cwd = excluded.cwd,
                title = excluded.title,
                sandbox_policy = excluded.sandbox_policy,
                approval_mode = excluded.approval_mode,
                tokens_used = excluded.tokens_used,
                has_user_event = excluded.has_user_event,
                archived = excluded.archived,
                archived_at = excluded.archived_at,
                git_sha = excluded.git_sha,
                git_branch = excluded.git_branch,
                git_origin_url = excluded.git_origin_url
            "#,
        )
        .bind(metadata.id.to_string())
        .bind(metadata.rollout_path.display().to_string())
        .bind(datetime_to_epoch_seconds(metadata.created_at))
        .bind(datetime_to_epoch_seconds(metadata.updated_at))
        .bind(metadata.source.as_str())
        .bind(metadata.model_provider.as_str())
        .bind(metadata.cwd.display().to_string())
        .bind(metadata.title.as_str())
        .bind(metadata.sandbox_policy.as_str())
        .bind(metadata.approval_mode.as_str())
        .bind(metadata.tokens_used)
        .bind(metadata.has_user_event)
        .bind(metadata.archived_at.is_some())
        .bind(metadata.archived_at.map(datetime_to_epoch_seconds))
        .bind(metadata.git_sha.as_deref())
        .bind(metadata.git_branch.as_deref())
        .bind(metadata.git_origin_url.as_deref())
        .execute(self.pool.as_ref())
        .await?;
        Ok(())
    }
    /// Apply rollout items incrementally using the underlying database.
    ///
    /// Loads existing metadata for the thread (or builds fresh metadata from
    /// `builder`), folds each item into it, refreshes `updated_at` from the
    /// rollout file's mtime when available, then upserts the result.
    pub async fn apply_rollout_items(
        &self,
        builder: &ThreadMetadataBuilder,
        items: &[RolloutItem],
        otel: Option<&OtelManager>,
    ) -> anyhow::Result<()> {
        if items.is_empty() {
            return Ok(());
        }
        let mut metadata = self
            .get_thread(builder.id)
            .await?
            .unwrap_or_else(|| builder.build(&self.default_provider));
        metadata.rollout_path = builder.rollout_path.clone();
        for item in items {
            apply_rollout_item(&mut metadata, item, &self.default_provider);
        }
        if let Some(updated_at) = file_modified_time_utc(builder.rollout_path.as_path()).await {
            metadata.updated_at = updated_at;
        }
        if let Err(err) = self.upsert_thread(&metadata).await {
            if let Some(otel) = otel {
                otel.counter(DB_ERROR_METRIC, 1, &[("stage", "apply_rollout_items")]);
            }
            return Err(err);
        }
        Ok(())
    }
    /// Mark a thread as archived using the underlying database.
    ///
    /// A no-op when the thread is unknown. Also updates the stored rollout
    /// path and refreshes `updated_at` from the file's mtime when available.
    pub async fn mark_archived(
        &self,
        thread_id: ThreadId,
        rollout_path: &Path,
        archived_at: DateTime<Utc>,
    ) -> anyhow::Result<()> {
        let Some(mut metadata) = self.get_thread(thread_id).await? else {
            return Ok(());
        };
        metadata.archived_at = Some(archived_at);
        metadata.rollout_path = rollout_path.to_path_buf();
        if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
            metadata.updated_at = updated_at;
        }
        // Defensive check: a mismatch indicates a corrupted row; the upsert
        // still proceeds, keyed by metadata.id.
        if metadata.id != thread_id {
            warn!(
                "thread id mismatch during archive: expected {thread_id}, got {}",
                metadata.id
            );
        }
        self.upsert_thread(&metadata).await
    }
    /// Mark a thread as unarchived using the underlying database.
    ///
    /// A no-op when the thread is unknown. Also updates the stored rollout
    /// path and refreshes `updated_at` from the file's mtime when available.
    pub async fn mark_unarchived(
        &self,
        thread_id: ThreadId,
        rollout_path: &Path,
    ) -> anyhow::Result<()> {
        let Some(mut metadata) = self.get_thread(thread_id).await? else {
            return Ok(());
        };
        metadata.archived_at = None;
        metadata.rollout_path = rollout_path.to_path_buf();
        if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
            metadata.updated_at = updated_at;
        }
        // Defensive check mirroring mark_archived.
        if metadata.id != thread_id {
            warn!(
                "thread id mismatch during unarchive: expected {thread_id}, got {}",
                metadata.id
            );
        }
        self.upsert_thread(&metadata).await
    }
}
/// Append `AND …` predicates for `query` onto a logs query.
///
/// Callers must have already emitted a `WHERE 1 = 1` prefix so that every
/// predicate here can be appended uniformly with `AND`.
fn push_log_filters<'a>(builder: &mut QueryBuilder<'a, Sqlite>, query: &'a LogQuery) {
    if let Some(level) = query.level_upper.as_ref() {
        builder.push(" AND UPPER(level) = ").push_bind(level.as_str());
    }
    if let Some(from) = query.from_ts {
        builder.push(" AND ts >= ").push_bind(from);
    }
    if let Some(to) = query.to_ts {
        builder.push(" AND ts <= ").push_bind(to);
    }
    push_like_filters(builder, "module_path", &query.module_like);
    push_like_filters(builder, "file", &query.file_like);
    // Thread scoping: match any requested thread id, optionally also rows
    // that carry no thread id at all, OR-ed inside one parenthesized group.
    if !query.thread_ids.is_empty() || query.include_threadless {
        builder.push(" AND (");
        let mut first = true;
        for thread_id in &query.thread_ids {
            if !first {
                builder.push(" OR ");
            }
            builder.push("thread_id = ").push_bind(thread_id.as_str());
            first = false;
        }
        if query.include_threadless {
            if !first {
                builder.push(" OR ");
            }
            builder.push("thread_id IS NULL");
        }
        builder.push(")");
    }
    if let Some(after) = query.after_id {
        builder.push(" AND id > ").push_bind(after);
    }
}
/// Append a parenthesized group of OR-ed `column LIKE '%' || ? || '%'`
/// substring predicates; a no-op when `filters` is empty.
fn push_like_filters<'a>(
    builder: &mut QueryBuilder<'a, Sqlite>,
    column: &str,
    filters: &'a [String],
) {
    let Some((head, tail)) = filters.split_first() else {
        return;
    };
    builder.push(" AND (");
    builder
        .push(column)
        .push(" LIKE '%' || ")
        .push_bind(head.as_str())
        .push(" || '%'");
    for filter in tail {
        builder
            .push(" OR ")
            .push(column)
            .push(" LIKE '%' || ")
            .push_bind(filter.as_str())
            .push(" || '%'");
    }
    builder.push(")");
}
/// Open (creating if missing) the SQLite database at `path` and run any
/// pending migrations, returning a ready-to-use connection pool.
async fn open_sqlite(path: &Path) -> anyhow::Result<SqlitePool> {
    // WAL journaling with NORMAL synchronous mode, a 5s busy timeout to ride
    // out concurrent writers, and statement logging disabled.
    let connect_options = SqliteConnectOptions::new()
        .filename(path)
        .create_if_missing(true)
        .journal_mode(SqliteJournalMode::Wal)
        .synchronous(SqliteSynchronous::Normal)
        .busy_timeout(Duration::from_secs(5))
        .log_statements(LevelFilter::Off);
    let db_pool = SqlitePoolOptions::new()
        .max_connections(5)
        .connect_with(connect_options)
        .await?;
    MIGRATOR.run(&db_pool).await?;
    Ok(db_pool)
}
/// Append the shared `WHERE` clause used by all thread listing queries.
///
/// Filters on archived state, user-visible threads, allowed sources, model
/// providers, and (optionally) a keyset-pagination anchor. Always emits a
/// leading `WHERE 1 = 1` so callers may append further `AND` predicates.
fn push_thread_filters<'a>(
    builder: &mut QueryBuilder<'a, Sqlite>,
    archived_only: bool,
    allowed_sources: &'a [String],
    model_providers: Option<&'a [String]>,
    anchor: Option<&crate::Anchor>,
    sort_key: SortKey,
) {
    // Tautology prefix so every later predicate uses a uniform `AND`.
    builder.push(" WHERE 1 = 1");
    builder.push(if archived_only {
        " AND archived = 1"
    } else {
        " AND archived = 0"
    });
    // Only surface threads containing at least one user event.
    builder.push(" AND has_user_event = 1");
    if !allowed_sources.is_empty() {
        builder.push(" AND source IN (");
        let mut sources = builder.separated(", ");
        for source in allowed_sources {
            sources.push_bind(source);
        }
        sources.push_unseparated(")");
    }
    if let Some(providers) = model_providers
        && !providers.is_empty()
    {
        builder.push(" AND model_provider IN (");
        let mut list = builder.separated(", ");
        for provider in providers {
            list.push_bind(provider);
        }
        list.push_unseparated(")");
    }
    // Keyset pagination: keep rows strictly "before" the anchor in
    // (sort column DESC, id DESC) order.
    if let Some(anchor) = anchor {
        let anchor_ts = datetime_to_epoch_seconds(anchor.ts);
        let column = match sort_key {
            SortKey::CreatedAt => "created_at",
            SortKey::UpdatedAt => "updated_at",
        };
        builder
            .push(" AND (")
            .push(column)
            .push(" < ")
            .push_bind(anchor_ts)
            .push(" OR (")
            .push(column)
            .push(" = ")
            .push_bind(anchor_ts)
            .push(" AND id < ")
            .push_bind(anchor.id.to_string())
            .push("))");
    }
}
/// Append `ORDER BY <sort column> DESC, id DESC LIMIT ?` to a thread query.
fn push_thread_order_and_limit(
    builder: &mut QueryBuilder<'_, Sqlite>,
    sort_key: SortKey,
    limit: usize,
) {
    let column = match sort_key {
        SortKey::CreatedAt => "created_at",
        SortKey::UpdatedAt => "updated_at",
    };
    builder
        .push(" ORDER BY ")
        .push(column)
        .push(" DESC, id DESC")
        .push(" LIMIT ")
        .push_bind(limit as i64);
}