Compare commits

...

2 Commits

Author SHA1 Message Date
Owen Lin
b168272203 Add SQLite operation telemetry 2026-05-07 19:25:18 -07:00
Owen Lin
9ddf828d4c Add SQLite init and fallback telemetry 2026-05-07 19:24:28 -07:00
24 changed files with 1345 additions and 452 deletions

View File

@@ -80,7 +80,7 @@ pub fn build_provider(
let service_name = service_name_override.unwrap_or(originator.value.as_str());
let runtime_metrics = config.features.enabled(Feature::RuntimeMetrics);
OtelProvider::from(&OtelSettings {
let provider = OtelProvider::from(&OtelSettings {
service_name: service_name.to_string(),
service_version: service_version.to_string(),
codex_home: config.codex_home.to_path_buf(),
@@ -91,7 +91,15 @@ pub fn build_provider(
runtime_metrics,
span_attributes: config.otel.span_attributes.clone(),
tracestate: config.otel.tracestate.clone(),
})
})?;
if let Some(provider) = provider.as_ref()
&& let Some(metrics) = provider.metrics()
{
let _ = codex_otel::record_process_start_once(metrics, originator.value.as_str());
}
Ok(provider)
}
/// Filter predicate for exporting only Codex-owned events via OTEL.

View File

@@ -200,10 +200,15 @@ pub async fn run_main(
mod tests {
use super::*;
use codex_config::types::OtelExporterKind;
use codex_config::types::OtelHttpProtocol;
use codex_core::config::ConfigBuilder;
use pretty_assertions::assert_eq;
use std::collections::HashMap;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
#[test]
fn mcp_server_defaults_analytics_to_enabled() {
@@ -212,14 +217,21 @@ mod tests {
#[tokio::test]
async fn mcp_server_builds_otel_provider_with_logs_traces_and_metrics() -> anyhow::Result<()> {
let collector = MockServer::start().await;
Mock::given(method("POST"))
.respond_with(ResponseTemplate::new(200))
.mount(&collector)
.await;
let codex_home = TempDir::new()?;
let mut config = ConfigBuilder::default()
.codex_home(codex_home.path().to_path_buf())
.build()
.await?;
let exporter = OtelExporterKind::OtlpGrpc {
endpoint: "http://localhost:4317".to_string(),
let exporter = OtelExporterKind::OtlpHttp {
endpoint: collector.uri(),
headers: HashMap::new(),
protocol: OtelHttpProtocol::Binary,
tls: None,
};
config.otel.exporter = exporter.clone();

View File

@@ -41,6 +41,7 @@ use std::time::Duration;
use tracing::debug;
const ENV_ATTRIBUTE: &str = "env";
const ARCH_ATTRIBUTE: &str = "arch";
const METER_NAME: &str = "codex";
const DURATION_UNIT: &str = "ms";
const DURATION_DESCRIPTION: &str = "Duration in milliseconds.";
@@ -198,13 +199,13 @@ impl MetricsClient {
validate_tags(&default_tags)?;
let mut resource_attributes = Vec::with_capacity(4);
let mut resource_attributes = Vec::with_capacity(5);
resource_attributes.push(KeyValue::new(
semconv::attribute::SERVICE_VERSION,
service_version,
));
resource_attributes.push(KeyValue::new(ENV_ATTRIBUTE, environment));
resource_attributes.extend(os_resource_attributes());
resource_attributes.extend(platform_resource_attributes());
let resource = Resource::builder()
.with_service_name(service_name)
@@ -290,12 +291,13 @@ impl MetricsClient {
}
}
fn os_resource_attributes() -> Vec<KeyValue> {
fn platform_resource_attributes() -> Vec<KeyValue> {
let os_info = os_info::get();
let os_type_raw = os_info.os_type().to_string();
let os_type = sanitize_metric_tag_value(os_type_raw.as_str());
let os_version_raw = os_info.version().to_string();
let os_version = sanitize_metric_tag_value(os_version_raw.as_str());
let arch = sanitize_metric_tag_value(std::env::consts::ARCH);
let mut attributes = Vec::new();
if os_type != "unspecified" {
attributes.push(KeyValue::new("os", os_type));
@@ -303,6 +305,9 @@ fn os_resource_attributes() -> Vec<KeyValue> {
if os_version != "unspecified" {
attributes.push(KeyValue::new("os_version", os_version));
}
if arch != "unspecified" {
attributes.push(KeyValue::new(ARCH_ATTRIBUTE, arch));
}
attributes
}

View File

@@ -2,6 +2,7 @@ mod client;
mod config;
mod error;
pub(crate) mod names;
mod process;
pub(crate) mod runtime_metrics;
pub(crate) mod tags;
pub(crate) mod timer;
@@ -13,9 +14,12 @@ pub use crate::metrics::config::MetricsConfig;
pub use crate::metrics::config::MetricsExporter;
pub use crate::metrics::error::MetricsError;
pub use crate::metrics::error::Result;
pub use crate::metrics::process::record_process_start_once;
pub use names::*;
use std::sync::OnceLock;
pub use tags::ORIGINATOR_TAG;
pub use tags::SessionMetricTagValues;
pub use tags::bounded_originator_tag_value;
static GLOBAL_METRICS: OnceLock<MetricsClient> = OnceLock::new();
static GLOBAL_STATSIG_METRICS_SETTINGS: OnceLock<StatsigMetricsSettings> = OnceLock::new();

View File

@@ -1,6 +1,7 @@
pub const TOOL_CALL_COUNT_METRIC: &str = "codex.tool.call";
pub const TOOL_CALL_DURATION_METRIC: &str = "codex.tool.call.duration_ms";
pub const TOOL_CALL_UNIFIED_EXEC_METRIC: &str = "codex.tool.unified_exec";
pub const PROCESS_START_METRIC: &str = "codex.process.start";
pub const API_CALL_COUNT_METRIC: &str = "codex.api_request";
pub const API_CALL_DURATION_METRIC: &str = "codex.api_request.duration_ms";
pub const SSE_EVENT_COUNT_METRIC: &str = "codex.sse_event";

View File

@@ -0,0 +1,27 @@
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use super::client::MetricsClient;
use super::error::Result;
use super::names::PROCESS_START_METRIC;
use super::tags::ORIGINATOR_TAG;
use super::tags::bounded_originator_tag_value;
static PROCESS_START_RECORDED: AtomicBool = AtomicBool::new(false);

/// Record the process-start counter (`PROCESS_START_METRIC`) at most once for
/// this process.
///
/// Returns `Ok(true)` when this call performed the recording, `Ok(false)` when
/// an earlier call already did. A failure from the metrics client surfaces as
/// `Err` only on the single recording attempt.
pub fn record_process_start_once(metrics: &MetricsClient, originator: &str) -> Result<bool> {
    // Atomically flip the flag; exactly one caller wins the race, all later
    // callers see the exchange fail and return without emitting anything.
    let won_the_race = PROCESS_START_RECORDED
        .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
        .is_ok();
    if !won_the_race {
        return Ok(false);
    }
    let tags = [(ORIGINATOR_TAG, bounded_originator_tag_value(originator))];
    metrics.counter(PROCESS_START_METRIC, /*inc*/ 1, &tags)?;
    Ok(true)
}

View File

@@ -1,6 +1,7 @@
use crate::metrics::Result;
use crate::metrics::validation::validate_tag_key;
use crate::metrics::validation::validate_tag_value;
use codex_utils_string::sanitize_metric_tag_value;
pub const APP_VERSION_TAG: &str = "app.version";
pub const AUTH_MODE_TAG: &str = "auth_mode";
@@ -9,6 +10,24 @@ pub const ORIGINATOR_TAG: &str = "originator";
pub const SERVICE_NAME_TAG: &str = "service_name";
pub const SESSION_SOURCE_TAG: &str = "session_source";
const OTHER_ORIGINATOR_TAG_VALUE: &str = "other";

/// Returns a sanitized, low-cardinality originator value that is safe to use as a metric tag.
pub fn bounded_originator_tag_value(originator: &str) -> &'static str {
    // Allow-list of originator values that may appear verbatim as a tag;
    // anything else collapses to the shared "other" bucket so cardinality
    // stays bounded.
    const KNOWN_ORIGINATORS: [&str; 9] = [
        "codex_desktop",
        "codex_cli_rs",
        "codex-tui",
        "codex_vscode",
        "none",
        "codex_exec",
        "codex-cli",
        "codex_sdk_ts",
        "codex-app-server-sdk",
    ];
    let sanitized = sanitize_metric_tag_value(originator);
    KNOWN_ORIGINATORS
        .iter()
        .copied()
        .find(|known| *known == sanitized.as_str())
        .unwrap_or(OTHER_ORIGINATOR_TAG_VALUE)
}
pub struct SessionMetricTagValues<'a> {
pub auth_mode: Option<&'a str>,
pub session_source: &'a str,

View File

@@ -10,6 +10,7 @@ pub(crate) mod metadata;
pub(crate) mod policy;
pub(crate) mod recorder;
pub(crate) mod session_index;
pub(crate) mod sqlite_metrics;
pub mod state_db;
pub(crate) mod default_client {

View File

@@ -1279,6 +1279,7 @@ async fn find_thread_path_by_id_str_in_subdir(
tracing::warn!(
"state db discrepancy during find_thread_path_by_id_str_in_subdir: mismatched_db_path"
);
crate::sqlite_metrics::record_fallback("find_thread_path", "mismatch");
}
Err(err) => {
tracing::debug!(
@@ -1296,6 +1297,7 @@ async fn find_thread_path_by_id_str_in_subdir(
tracing::warn!(
"state db discrepancy during find_thread_path_by_id_str_in_subdir: stale_db_path"
);
crate::sqlite_metrics::record_fallback("find_thread_path", "stale_path");
}
}
@@ -1323,6 +1325,12 @@ async fn find_thread_path_by_id_str_in_subdir(
tracing::warn!(
"state db discrepancy during find_thread_path_by_id_str_in_subdir: falling_back"
);
let reason = if state_db_ctx.is_some() {
"missing_row"
} else {
"db_unavailable"
};
crate::sqlite_metrics::record_fallback("find_thread_path", reason);
state_db::read_repair_rollout_path(
state_db_ctx,
thread_id,

View File

@@ -450,6 +450,7 @@ impl RolloutRecorder {
if state_db_ctx.is_none() {
// Keep legacy behavior when SQLite is unavailable: return filesystem results
// at the requested page size.
crate::sqlite_metrics::record_fallback("list_threads", "db_unavailable");
return Ok(page_from_filesystem_scan(
fs_page,
sort_direction,
@@ -569,6 +570,7 @@ impl RolloutRecorder {
}
if listing_has_metadata_filters {
let page = page_from_filesystem_scan(fs_page, sort_direction, page_size, sort_key);
crate::sqlite_metrics::record_fallback("list_threads", "db_error");
return Ok(fill_missing_thread_item_metadata_from_state_db(
state_db_ctx.as_deref(),
page,
@@ -578,6 +580,7 @@ impl RolloutRecorder {
// If SQLite listing still fails, return the filesystem page rather than failing the list.
tracing::error!("Falling back on rollout system");
tracing::warn!("state db discrepancy during list_threads_with_db_fallback: falling_back");
crate::sqlite_metrics::record_fallback("list_threads", "db_error");
Ok(page_from_filesystem_scan(
fs_page,
sort_direction,

View File

@@ -0,0 +1,49 @@
use std::sync::Arc;
use std::time::Duration;
use codex_otel::ORIGINATOR_TAG;
use codex_otel::bounded_originator_tag_value;
use codex_state::DbMetricsRecorder;
use codex_state::DbMetricsRecorderHandle;
use crate::default_client::originator;
// Adapter that exposes the OTEL metrics client through the state crate's
// `DbMetricsRecorder` trait, so SQLite-layer telemetry can be emitted without
// the state crate depending on OTEL directly.
struct OtelDbMetrics {
// Underlying OTEL metrics sink every counter/duration is forwarded to.
metrics: codex_otel::MetricsClient,
// Originator tag value, bounded to a low-cardinality set up front (see
// `bounded_originator_tag_value`) and reused for every emitted metric.
originator: &'static str,
}
impl DbMetricsRecorder for OtelDbMetrics {
// Forward a counter to the OTEL client with the originator tag appended.
// The result is discarded (`let _`) so a telemetry failure can never
// propagate into the database operation being measured.
fn counter(&self, name: &str, inc: i64, tags: &[(&str, &str)]) {
let tags = sqlite_originator_tags(tags, self.originator);
let _ = self.metrics.counter(name, inc, &tags);
}
// Forward a duration metric to the OTEL client with the originator tag
// appended; errors are likewise discarded.
fn record_duration(&self, name: &str, duration: Duration, tags: &[(&str, &str)]) {
let tags = sqlite_originator_tags(tags, self.originator);
let _ = self.metrics.record_duration(name, duration, &tags);
}
}
// Returns a state-db metrics recorder backed by the process-global OTEL
// metrics client, or `None` when no global client has been installed
// (`codex_otel::global()` returns `None`).
pub(crate) fn global() -> Option<DbMetricsRecorderHandle> {
codex_otel::global().map(|metrics| {
Arc::new(OtelDbMetrics {
metrics,
// Bound the originator once at construction so each per-metric call
// avoids re-sanitizing the originator string.
originator: bounded_originator_tag_value(originator().value.as_str()),
}) as DbMetricsRecorderHandle
})
}
/// Emit a filesystem-fallback metric for `caller`/`reason` through the global
/// OTEL-backed recorder; a no-op when no global metrics client is installed.
pub(crate) fn record_fallback(caller: &'static str, reason: &'static str) {
    codex_state::record_db_fallback_metric(global().as_deref(), caller, reason);
}
/// Clone `tags` with the originator tag appended.
///
/// Preallocates room for the extra element so the returned vector is built in
/// a single allocation; the previous `to_vec()` + `push` pattern allocated at
/// exactly `tags.len()` capacity and then reallocated on every call to fit the
/// appended originator pair.
fn sqlite_originator_tags<'a>(
    tags: &[(&'a str, &'a str)],
    originator: &'static str,
) -> Vec<(&'a str, &'a str)> {
    let mut with_originator = Vec::with_capacity(tags.len() + 1);
    with_originator.extend_from_slice(tags);
    // `ORIGINATOR_TAG` and `originator` are both `'static`, so they coerce to
    // the caller's `'a` lifetime.
    with_originator.push((ORIGINATOR_TAG, originator));
    with_originator
}

View File

@@ -4,6 +4,7 @@ use crate::list::Cursor;
use crate::list::SortDirection;
use crate::list::ThreadSortKey;
use crate::metadata;
use crate::sqlite_metrics;
use chrono::DateTime;
use chrono::Utc;
use codex_protocol::ThreadId;
@@ -106,52 +107,80 @@ async fn try_init_with_roots_inner(
default_model_provider_id: String,
backfill_lease_seconds: Option<i64>,
) -> anyhow::Result<StateDbHandle> {
let runtime =
codex_state::StateRuntime::init(sqlite_home.clone(), default_model_provider_id.clone())
.await
.map_err(|err| {
anyhow::anyhow!(
"failed to initialize state runtime at {}: {err}",
sqlite_home.display()
)
})?;
let metrics = sqlite_metrics::global();
let runtime = codex_state::StateRuntime::init_with_metrics(
sqlite_home.clone(),
default_model_provider_id.clone(),
metrics.clone(),
)
.await
.map_err(|err| {
anyhow::anyhow!(
"failed to initialize state runtime at {}: {err}",
sqlite_home.display()
)
})?;
let backfill_gate_started = Instant::now();
let backfill_gate_result = wait_for_startup_backfill(
runtime.as_ref(),
codex_home.as_path(),
default_model_provider_id.as_str(),
backfill_lease_seconds,
)
.await;
codex_state::record_db_init_backfill_gate_metric(
metrics.as_deref(),
backfill_gate_started.elapsed(),
&backfill_gate_result,
);
backfill_gate_result?;
Ok(runtime)
}
async fn wait_for_startup_backfill(
runtime: &codex_state::StateRuntime,
codex_home: &Path,
default_model_provider_id: &str,
backfill_lease_seconds: Option<i64>,
) -> anyhow::Result<()> {
let wait_started = Instant::now();
let mut reported_wait = false;
loop {
let backfill_state = runtime.get_backfill_state().await.map_err(|err| {
anyhow::anyhow!(
"failed to read backfill state at {}: {err}",
codex_home.display()
)
})?;
let backfill_state = match runtime.get_backfill_state().await {
Ok(state) => state,
Err(err) => {
return Err(anyhow::anyhow!(
"failed to read backfill state at {}: {err}",
codex_home.display()
));
}
};
if backfill_state.status == codex_state::BackfillStatus::Complete {
return Ok(runtime);
return Ok(());
}
if let Some(backfill_lease_seconds) = backfill_lease_seconds {
metadata::backfill_sessions_with_lease(
runtime.as_ref(),
codex_home.as_path(),
default_model_provider_id.as_str(),
runtime,
codex_home,
default_model_provider_id,
backfill_lease_seconds,
)
.await;
} else {
metadata::backfill_sessions(
runtime.as_ref(),
codex_home.as_path(),
default_model_provider_id.as_str(),
)
.await;
metadata::backfill_sessions(runtime, codex_home, default_model_provider_id).await;
}
let backfill_state = runtime.get_backfill_state().await.map_err(|err| {
anyhow::anyhow!(
"failed to read backfill state at {} after startup backfill: {err}",
codex_home.display()
)
})?;
let backfill_state = match runtime.get_backfill_state().await {
Ok(state) => state,
Err(err) => {
return Err(anyhow::anyhow!(
"failed to read backfill state at {} after startup backfill: {err}",
codex_home.display()
));
}
};
if backfill_state.status == codex_state::BackfillStatus::Complete {
return Ok(runtime);
return Ok(());
}
if wait_started.elapsed() >= STARTUP_BACKFILL_WAIT_TIMEOUT {
return Err(anyhow::anyhow!(
@@ -193,22 +222,36 @@ fn emit_startup_warning(message: &str) {
/// Unlike [`init`], this helper does not run rollout backfill. It is for
/// optional local reads from non-owning contexts such as remote app-server mode.
pub async fn get_state_db(config: &impl RolloutConfigView) -> Option<StateDbHandle> {
let metrics = sqlite_metrics::global();
let state_path = codex_state::state_db_path(config.sqlite_home());
if !tokio::fs::try_exists(&state_path).await.unwrap_or(false) {
codex_state::record_db_fallback_metric(
metrics.as_deref(),
"get_state_db",
"db_unavailable",
);
return None;
}
let runtime = codex_state::StateRuntime::init(
let runtime = match codex_state::StateRuntime::init_with_metrics(
config.sqlite_home().to_path_buf(),
config.model_provider_id().to_string(),
metrics.clone(),
)
.await
.ok()?;
require_backfill_complete(runtime, config.sqlite_home()).await
{
Ok(runtime) => runtime,
Err(_) => {
codex_state::record_db_fallback_metric(metrics.as_deref(), "get_state_db", "db_error");
return None;
}
};
require_backfill_complete(runtime, config.sqlite_home(), metrics.as_deref()).await
}
async fn require_backfill_complete(
runtime: StateDbHandle,
codex_home: &Path,
metrics: Option<&dyn codex_state::DbMetricsRecorder>,
) -> Option<StateDbHandle> {
match runtime.get_backfill_state().await {
Ok(state) if state.status == codex_state::BackfillStatus::Complete => Some(runtime),
@@ -218,6 +261,7 @@ async fn require_backfill_complete(
codex_home.display(),
state.status.as_str()
);
codex_state::record_db_fallback_metric(metrics, "get_state_db", "backfill_incomplete");
None
}
Err(err) => {
@@ -225,6 +269,7 @@ async fn require_backfill_complete(
"failed to read backfill state at {}: {err}",
codex_home.display()
);
codex_state::record_db_fallback_metric(metrics, "get_state_db", "db_error");
None
}
}

View File

@@ -10,12 +10,13 @@ mod migrations;
mod model;
mod paths;
mod runtime;
mod telemetry;
pub use model::LogEntry;
pub use model::LogQuery;
pub use model::LogRow;
pub use model::Phase2JobClaimOutcome;
/// Preferred entrypoint: owns configuration and metrics.
/// Preferred entrypoint: owns SQLite configuration and optional metrics injection.
pub use runtime::StateRuntime;
/// Low-level storage engine: useful for focused tests.
@@ -56,6 +57,8 @@ pub use runtime::logs_db_filename;
pub use runtime::logs_db_path;
pub use runtime::state_db_filename;
pub use runtime::state_db_path;
pub use telemetry::DbMetricsRecorder;
pub use telemetry::DbMetricsRecorderHandle;
/// Environment variable for overriding the SQLite state database home directory.
pub const SQLITE_HOME_ENV: &str = "CODEX_SQLITE_HOME";
@@ -71,3 +74,31 @@ pub const DB_ERROR_METRIC: &str = "codex.db.error";
pub const DB_METRIC_BACKFILL: &str = "codex.db.backfill";
/// Metrics on backfill duration. Tags: [status]
pub const DB_METRIC_BACKFILL_DURATION_MS: &str = "codex.db.backfill.duration_ms";
/// SQLite startup initialization attempts. Tags: [status, phase, db, error]
pub const DB_INIT_METRIC: &str = "codex.sqlite.init.count";
/// SQLite startup initialization duration. Tags: [status, phase, db, error]
pub const DB_INIT_DURATION_METRIC: &str = "codex.sqlite.init.duration_ms";
/// SQLite logical operation attempts. Tags: [status, db, operation, access, error]
pub const DB_OPERATION_METRIC: &str = "codex.sqlite.operation.count";
/// SQLite logical operation duration. Tags: [status, db, operation, access, error]
pub const DB_OPERATION_DURATION_METRIC: &str = "codex.sqlite.operation.duration_ms";
/// Filesystem fallback after SQLite could not serve a request. Tags: [caller, reason]
pub const DB_FALLBACK_METRIC: &str = "codex.sqlite.fallback.count";
/// SQLite log queue loss or flush failure. Tags: [event, reason]
pub const DB_LOG_QUEUE_METRIC: &str = "codex.sqlite.log_queue.count";
pub fn record_db_fallback_metric(
metrics: Option<&dyn DbMetricsRecorder>,
caller: &'static str,
reason: &'static str,
) {
telemetry::record_fallback(metrics, caller, reason);
}
pub fn record_db_init_backfill_gate_metric(
metrics: Option<&dyn DbMetricsRecorder>,
duration: std::time::Duration,
result: &anyhow::Result<()>,
) {
telemetry::record_init_backfill_gate(metrics, duration, result);
}

View File

@@ -43,6 +43,7 @@ use uuid::Uuid;
use crate::LogEntry;
use crate::StateRuntime;
use crate::telemetry;
const LOG_QUEUE_CAPACITY: usize = 512;
const LOG_BATCH_SIZE: usize = 128;
@@ -94,6 +95,7 @@ where
pub struct LogDbLayer {
sender: mpsc::Sender<LogDbCommand>,
process_uuid: String,
metrics: Option<crate::DbMetricsRecorderHandle>,
}
pub fn start(state_db: std::sync::Arc<StateRuntime>) -> LogDbLayer {
@@ -105,6 +107,7 @@ impl Clone for LogDbLayer {
Self {
sender: self.sender.clone(),
process_uuid: self.process_uuid.clone(),
metrics: self.metrics.clone(),
}
}
}
@@ -120,22 +123,34 @@ impl LogDbLayer {
) -> Self {
let config = config.normalized();
let (sender, receiver) = mpsc::channel(config.queue_capacity);
let metrics = state_db.metrics_handle();
tokio::spawn(run_inserter(state_db, receiver, config));
Self {
sender,
process_uuid: current_process_log_uuid().to_string(),
metrics,
}
}
pub async fn flush(&self) {
let (tx, rx) = oneshot::channel();
if self.sender.send(LogDbCommand::Flush(tx)).await.is_ok() {
let _ = rx.await;
if self.sender.send(LogDbCommand::Flush(tx)).await.is_err() {
telemetry::record_log_queue(self.metrics.as_deref(), "flush_failed", "closed");
return;
}
if rx.await.is_err() {
telemetry::record_log_queue(self.metrics.as_deref(), "flush_failed", "closed");
}
}
fn try_send(&self, entry: LogEntry) {
let _ = self.sender.try_send(LogDbCommand::Entry(Box::new(entry)));
if let Err(err) = self.sender.try_send(LogDbCommand::Entry(Box::new(entry))) {
let reason = match err {
mpsc::error::TrySendError::Full(_) => "full",
mpsc::error::TrySendError::Closed(_) => "closed",
};
telemetry::record_log_queue(self.metrics.as_deref(), "dropped", reason);
}
}
}
@@ -401,7 +416,9 @@ async fn flush(state_db: &StateRuntime, buffer: &mut Vec<LogEntry>) {
return;
}
let entries = buffer.split_off(0);
let _ = state_db.insert_logs(entries.as_slice()).await;
if state_db.insert_logs(entries.as_slice()).await.is_err() {
telemetry::record_log_queue(state_db.metrics(), "flush_failed", "insert_failed");
}
}
#[derive(Default)]
@@ -721,6 +738,7 @@ mod tests {
let layer = LogDbLayer {
sender,
process_uuid: "process-1".to_string(),
metrics: None,
};
layer.try_send(test_entry("first-queued-log"));
@@ -741,6 +759,7 @@ mod tests {
let layer = LogDbLayer {
sender,
process_uuid: "process-1".to_string(),
metrics: None,
};
layer.try_send(test_entry("queued-before-flush"));

View File

@@ -27,6 +27,10 @@ use crate::model::datetime_to_epoch_millis;
use crate::model::datetime_to_epoch_seconds;
use crate::model::epoch_millis_to_datetime;
use crate::paths::file_modified_time_utc;
use crate::telemetry::DbAccess;
use crate::telemetry::DbKind;
use crate::telemetry::DbMetricsRecorder;
use crate::telemetry::DbMetricsRecorderHandle;
use chrono::DateTime;
use chrono::Utc;
use codex_protocol::ThreadId;
@@ -52,10 +56,12 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicI64;
use std::time::Duration;
use std::time::Instant;
use tracing::warn;
mod agent_jobs;
mod backfill;
mod db;
mod goals;
mod logs;
mod memories;
@@ -70,6 +76,9 @@ pub use goals::ThreadGoalUpdate;
pub use remote_control::RemoteControlEnrollmentRecord;
pub use threads::ThreadFilterOptions;
use db::DbOperation;
use db::InstrumentedDb;
// "Partition" is the retained-log-content bucket we cap at 10 MiB:
// - one bucket per non-null thread_id
// - one bucket per threadless (thread_id IS NULL) non-null process_uuid
@@ -83,8 +92,8 @@ const LOG_PARTITION_ROW_LIMIT: i64 = 1_000;
pub struct StateRuntime {
codex_home: PathBuf,
default_provider: String,
pool: Arc<sqlx::SqlitePool>,
logs_pool: Arc<sqlx::SqlitePool>,
state_db: InstrumentedDb,
logs_db: InstrumentedDb,
thread_updated_at_millis: Arc<AtomicI64>,
}
@@ -93,8 +102,18 @@ impl StateRuntime {
///
/// This opens (and migrates) the SQLite databases under `codex_home`,
/// keeping logs in a dedicated file to reduce lock contention with the
/// rest of the state store.
/// rest of the state store. Use [`Self::init_with_metrics`] when the caller
/// has a metrics sink to attach.
pub async fn init(codex_home: PathBuf, default_provider: String) -> anyhow::Result<Arc<Self>> {
Self::init_with_metrics(codex_home, default_provider, /*metrics*/ None).await
}
/// Initialize the state runtime with an explicit metrics client.
pub async fn init_with_metrics(
codex_home: PathBuf,
default_provider: String,
metrics: Option<DbMetricsRecorderHandle>,
) -> anyhow::Result<Arc<Self>> {
tokio::fs::create_dir_all(&codex_home).await?;
let state_migrator = runtime_state_migrator();
let logs_migrator = runtime_logs_migrator();
@@ -116,28 +135,49 @@ impl StateRuntime {
.await;
let state_path = state_db_path(codex_home.as_path());
let logs_path = logs_db_path(codex_home.as_path());
let pool = match open_state_sqlite(&state_path, &state_migrator).await {
let pool = match open_state_sqlite(&state_path, &state_migrator, metrics.as_deref()).await {
Ok(db) => Arc::new(db),
Err(err) => {
warn!("failed to open state db at {}: {err}", state_path.display());
return Err(err);
}
};
let logs_pool = match open_logs_sqlite(&logs_path, &logs_migrator).await {
let logs_pool = match open_logs_sqlite(&logs_path, &logs_migrator, metrics.as_deref()).await
{
Ok(db) => Arc::new(db),
Err(err) => {
warn!("failed to open logs db at {}: {err}", logs_path.display());
return Err(err);
}
};
let thread_updated_at_millis: Option<i64> =
let started = Instant::now();
let backfill_state_result = ensure_backfill_state_row_in_pool(pool.as_ref()).await;
crate::telemetry::record_init_result(
metrics.as_deref(),
DbKind::State,
"ensure_backfill_state",
started.elapsed(),
&backfill_state_result,
);
backfill_state_result?;
let started = Instant::now();
let thread_updated_at_millis_result: anyhow::Result<Option<i64>> =
sqlx::query_scalar("SELECT MAX(threads.updated_at_ms) FROM threads")
.fetch_one(pool.as_ref())
.await?;
.await
.map_err(anyhow::Error::from);
crate::telemetry::record_init_result(
metrics.as_deref(),
DbKind::State,
"post_init_query",
started.elapsed(),
&thread_updated_at_millis_result,
);
let thread_updated_at_millis = thread_updated_at_millis_result?;
let thread_updated_at_millis = thread_updated_at_millis.unwrap_or(0);
let runtime = Arc::new(Self {
pool,
logs_pool,
state_db: InstrumentedDb::new(pool, DbKind::State, metrics.clone()),
logs_db: InstrumentedDb::new(logs_pool, DbKind::Logs, metrics),
codex_home,
default_provider,
thread_updated_at_millis: Arc::new(AtomicI64::new(thread_updated_at_millis)),
@@ -155,6 +195,14 @@ impl StateRuntime {
pub fn codex_home(&self) -> &Path {
self.codex_home.as_path()
}
pub(crate) fn metrics(&self) -> Option<&dyn DbMetricsRecorder> {
self.state_db.metrics()
}
pub(crate) fn metrics_handle(&self) -> Option<DbMetricsRecorderHandle> {
self.state_db.metrics_handle()
}
}
fn base_sqlite_options(path: &Path) -> SqliteConnectOptions {
@@ -165,29 +213,90 @@ fn base_sqlite_options(path: &Path) -> SqliteConnectOptions {
.synchronous(SqliteSynchronous::Normal)
.busy_timeout(Duration::from_secs(5))
.log_statements(LevelFilter::Off)
.log_slow_statements(LevelFilter::Warn, Duration::from_millis(250))
}
async fn open_state_sqlite(path: &Path, migrator: &Migrator) -> anyhow::Result<SqlitePool> {
async fn open_state_sqlite(
path: &Path,
migrator: &Migrator,
metrics: Option<&dyn DbMetricsRecorder>,
) -> anyhow::Result<SqlitePool> {
// New state DBs should use incremental auto-vacuum, but retrofitting an
// existing DB requires a full VACUUM. Do not attempt that during process
// startup: it is maintenance work that can contend with foreground writers.
open_sqlite(
path,
migrator,
metrics,
DbKind::State,
"open_state",
"migrate_state",
)
.await
}
async fn open_logs_sqlite(
path: &Path,
migrator: &Migrator,
metrics: Option<&dyn DbMetricsRecorder>,
) -> anyhow::Result<SqlitePool> {
open_sqlite(
path,
migrator,
metrics,
DbKind::Logs,
"open_logs",
"migrate_logs",
)
.await
}
async fn open_sqlite(
path: &Path,
migrator: &Migrator,
metrics: Option<&dyn DbMetricsRecorder>,
db: DbKind,
open_phase: &'static str,
migrate_phase: &'static str,
) -> anyhow::Result<SqlitePool> {
let options = base_sqlite_options(path).auto_vacuum(SqliteAutoVacuum::Incremental);
let pool = SqlitePoolOptions::new()
let started = Instant::now();
let pool_result = SqlitePoolOptions::new()
.max_connections(5)
.acquire_slow_level(LevelFilter::Warn)
.acquire_slow_threshold(Duration::from_millis(250))
.connect_with(options)
.await?;
migrator.run(&pool).await?;
.await
.map_err(anyhow::Error::from);
crate::telemetry::record_init_result(metrics, db, open_phase, started.elapsed(), &pool_result);
let pool = pool_result?;
let started = Instant::now();
let migrate_result = migrator.run(&pool).await.map_err(anyhow::Error::from);
crate::telemetry::record_init_result(
metrics,
db,
migrate_phase,
started.elapsed(),
&migrate_result,
);
migrate_result?;
Ok(pool)
}
async fn open_logs_sqlite(path: &Path, migrator: &Migrator) -> anyhow::Result<SqlitePool> {
let options = base_sqlite_options(path).auto_vacuum(SqliteAutoVacuum::Incremental);
let pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(options)
.await?;
migrator.run(&pool).await?;
Ok(pool)
async fn ensure_backfill_state_row_in_pool(pool: &sqlx::SqlitePool) -> anyhow::Result<()> {
sqlx::query(
r#"
INSERT INTO backfill_state (id, status, last_watermark, last_success_at, updated_at)
VALUES (?, ?, NULL, NULL, ?)
ON CONFLICT(id) DO NOTHING
"#,
)
.bind(1_i64)
.bind(crate::BackfillStatus::Pending.as_str())
.bind(Utc::now().timestamp())
.execute(pool)
.await?;
Ok(())
}
fn db_filename(base_name: &str, version: u32) -> String {
@@ -355,9 +464,13 @@ mod tests {
strict_pool.close().await;
let tolerant_migrator = runtime_state_migrator();
let tolerant_pool = open_state_sqlite(state_path.as_path(), &tolerant_migrator)
.await
.expect("runtime migrator should tolerate newer applied migrations");
let tolerant_pool = open_state_sqlite(
state_path.as_path(),
&tolerant_migrator,
/*metrics*/ None,
)
.await
.expect("runtime migrator should tolerate newer applied migrations");
tolerant_pool.close().await;
let _ = tokio::fs::remove_dir_all(codex_home).await;

View File

@@ -19,7 +19,7 @@ impl StateRuntime {
.map(i64::try_from)
.transpose()
.map_err(|_| anyhow::anyhow!("invalid max_runtime_seconds value"))?;
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
sqlx::query(
r#"
INSERT INTO agent_jobs (
@@ -122,7 +122,7 @@ WHERE id = ?
"#,
)
.bind(job_id)
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
row.map(AgentJob::try_from).transpose()
}
@@ -166,7 +166,7 @@ WHERE job_id =
}
let rows: Vec<AgentJobItemRow> = builder
.build_query_as::<AgentJobItemRow>()
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?;
rows.into_iter().map(AgentJobItem::try_from).collect()
}
@@ -199,7 +199,7 @@ WHERE job_id = ? AND item_id = ?
)
.bind(job_id)
.bind(item_id)
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
row.map(AgentJobItem::try_from).transpose()
}
@@ -222,7 +222,7 @@ WHERE id = ?
.bind(now)
.bind(now)
.bind(job_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -240,7 +240,7 @@ WHERE id = ?
.bind(now)
.bind(now)
.bind(job_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -263,7 +263,7 @@ WHERE id = ?
.bind(now)
.bind(error_message)
.bind(job_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -288,7 +288,7 @@ WHERE id = ? AND status IN (?, ?)
.bind(job_id)
.bind(AgentJobStatus::Pending.as_str())
.bind(AgentJobStatus::Running.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -302,7 +302,7 @@ WHERE id = ?
"#,
)
.bind(job_id)
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
let Some(row) = row else {
return Ok(false);
@@ -334,7 +334,7 @@ WHERE job_id = ? AND item_id = ? AND status = ?
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Pending.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -364,7 +364,7 @@ WHERE job_id = ? AND item_id = ? AND status = ?
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Pending.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -393,7 +393,7 @@ WHERE job_id = ? AND item_id = ? AND status = ?
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Running.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -417,7 +417,7 @@ WHERE job_id = ? AND item_id = ? AND status = ?
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Running.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -458,7 +458,7 @@ WHERE
.bind(item_id)
.bind(AgentJobItemStatus::Running.as_str())
.bind(reporting_thread_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -490,7 +490,7 @@ WHERE
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Running.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -524,7 +524,7 @@ WHERE
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Running.as_str())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -547,7 +547,7 @@ WHERE job_id = ?
.bind(AgentJobItemStatus::Completed.as_str())
.bind(AgentJobItemStatus::Failed.as_str())
.bind(job_id)
.fetch_one(self.pool.as_ref())
.fetch_one(self.state_db.pool())
.await?;
let total_items: i64 = row.try_get("total_items")?;

View File

@@ -2,17 +2,20 @@ use super::*;
impl StateRuntime {
pub async fn get_backfill_state(&self) -> anyhow::Result<crate::BackfillState> {
self.ensure_backfill_state_row().await?;
let row = sqlx::query(
r#"
self.state_db
.read(DbOperation::GetBackfillState, |pool| async move {
let row = sqlx::query(
r#"
SELECT status, last_watermark, last_success_at
FROM backfill_state
WHERE id = 1
"#,
)
.fetch_one(self.pool.as_ref())
.await?;
crate::BackfillState::try_from_row(&row)
)
.fetch_one(&pool)
.await?;
crate::BackfillState::try_from_row(&row)
})
.await
}
/// Attempt to claim ownership of rollout metadata backfill.
@@ -21,69 +24,83 @@ WHERE id = 1
/// Returns `false` if backfill is already complete or currently owned by a
/// non-expired worker.
pub async fn try_claim_backfill(&self, lease_seconds: i64) -> anyhow::Result<bool> {
self.ensure_backfill_state_row().await?;
let now = Utc::now().timestamp();
let lease_cutoff = now.saturating_sub(lease_seconds.max(0));
let result = sqlx::query(
r#"
self.state_db
.write(DbOperation::TryClaimBackfill, |pool| async move {
ensure_backfill_state_row_in_pool(&pool).await?;
let now = Utc::now().timestamp();
let lease_cutoff = now.saturating_sub(lease_seconds.max(0));
let result = sqlx::query(
r#"
UPDATE backfill_state
SET status = ?, updated_at = ?
WHERE id = 1
AND status != ?
AND (status != ? OR updated_at <= ?)
"#,
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(now)
.bind(crate::BackfillStatus::Complete.as_str())
.bind(crate::BackfillStatus::Running.as_str())
.bind(lease_cutoff)
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected() == 1)
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(now)
.bind(crate::BackfillStatus::Complete.as_str())
.bind(crate::BackfillStatus::Running.as_str())
.bind(lease_cutoff)
.execute(&pool)
.await?;
Ok(result.rows_affected() == 1)
})
.await
}
/// Mark rollout metadata backfill as running.
pub async fn mark_backfill_running(&self) -> anyhow::Result<()> {
self.ensure_backfill_state_row().await?;
sqlx::query(
r#"
self.state_db
.write(DbOperation::MarkBackfillRunning, |pool| async move {
ensure_backfill_state_row_in_pool(&pool).await?;
sqlx::query(
r#"
UPDATE backfill_state
SET status = ?, updated_at = ?
WHERE id = 1
"#,
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.await?;
Ok(())
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(Utc::now().timestamp())
.execute(&pool)
.await?;
Ok(())
})
.await
}
/// Persist rollout metadata backfill progress.
pub async fn checkpoint_backfill(&self, watermark: &str) -> anyhow::Result<()> {
self.ensure_backfill_state_row().await?;
sqlx::query(
r#"
self.state_db
.write(DbOperation::CheckpointBackfill, |pool| async move {
ensure_backfill_state_row_in_pool(&pool).await?;
sqlx::query(
r#"
UPDATE backfill_state
SET status = ?, last_watermark = ?, updated_at = ?
WHERE id = 1
"#,
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(watermark)
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.await?;
Ok(())
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(watermark)
.bind(Utc::now().timestamp())
.execute(&pool)
.await?;
Ok(())
})
.await
}
/// Mark rollout metadata backfill as complete.
pub async fn mark_backfill_complete(&self, last_watermark: Option<&str>) -> anyhow::Result<()> {
self.ensure_backfill_state_row().await?;
let now = Utc::now().timestamp();
sqlx::query(
r#"
self.state_db
.write(DbOperation::MarkBackfillComplete, |pool| async move {
ensure_backfill_state_row_in_pool(&pool).await?;
let now = Utc::now().timestamp();
sqlx::query(
r#"
UPDATE backfill_state
SET
status = ?,
@@ -92,30 +109,16 @@ SET
updated_at = ?
WHERE id = 1
"#,
)
.bind(crate::BackfillStatus::Complete.as_str())
.bind(last_watermark)
.bind(now)
.bind(now)
.execute(self.pool.as_ref())
.await?;
Ok(())
}
async fn ensure_backfill_state_row(&self) -> anyhow::Result<()> {
sqlx::query(
r#"
INSERT INTO backfill_state (id, status, last_watermark, last_success_at, updated_at)
VALUES (?, ?, NULL, NULL, ?)
ON CONFLICT(id) DO NOTHING
"#,
)
.bind(1_i64)
.bind(crate::BackfillStatus::Pending.as_str())
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.await?;
Ok(())
)
.bind(crate::BackfillStatus::Complete.as_str())
.bind(last_watermark)
.bind(now)
.bind(now)
.execute(&pool)
.await?;
Ok(())
})
.await
}
}
@@ -286,7 +289,7 @@ WHERE id = 1
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(stale_updated_at)
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("force stale backfill lease");

View File

@@ -0,0 +1,158 @@
use std::future::Future;
use std::sync::Arc;
use std::time::Instant;
use sqlx::SqlitePool;
use crate::telemetry::DbAccess;
use crate::telemetry::DbKind;
use crate::telemetry::DbMetricsRecorder;
use crate::telemetry::DbMetricsRecorderHandle;
/// SQLite pool plus Codex-level operation telemetry context.
#[derive(Clone)]
pub(super) struct InstrumentedDb {
    // Shared handle to the underlying SQLite connection pool.
    pool: Arc<SqlitePool>,
    // Logical database this wrapper fronts, used as a metric dimension.
    // NOTE(review): presumably distinguishes the state db from the logs db
    // (callers name fields `state_db` / `logs_db`) — confirm `DbKind` variants.
    kind: DbKind,
    // Recorder for per-operation metrics; `None` disables operation telemetry.
    metrics: Option<DbMetricsRecorderHandle>,
}
impl InstrumentedDb {
    /// Creates a wrapper that times operations against `pool` and reports
    /// them to `metrics` (when present) tagged with `kind`.
    pub(super) fn new(
        pool: Arc<SqlitePool>,
        kind: DbKind,
        metrics: Option<DbMetricsRecorderHandle>,
    ) -> Self {
        Self {
            pool,
            kind,
            metrics,
        }
    }

    /// Borrows the raw pool for call sites that run queries without routing
    /// them through the instrumented `read`/`write`/`transaction` helpers.
    pub(super) fn pool(&self) -> &SqlitePool {
        self.pool.as_ref()
    }

    /// Borrows the metrics recorder, if telemetry is enabled.
    pub(super) fn metrics(&self) -> Option<&dyn DbMetricsRecorder> {
        self.metrics.as_deref()
    }

    /// Clones the owned metrics-recorder handle, if telemetry is enabled.
    pub(super) fn metrics_handle(&self) -> Option<DbMetricsRecorderHandle> {
        self.metrics.clone()
    }

    /// Runs `f` as a read operation, recording its duration and outcome.
    pub(super) async fn read<T, F, Fut>(&self, operation: DbOperation, f: F) -> anyhow::Result<T>
    where
        F: FnOnce(SqlitePool) -> Fut,
        Fut: Future<Output = anyhow::Result<T>>,
    {
        self.record_operation(operation, DbAccess::Read, f).await
    }

    /// Runs `f` as a write operation, recording its duration and outcome.
    pub(super) async fn write<T, F, Fut>(&self, operation: DbOperation, f: F) -> anyhow::Result<T>
    where
        F: FnOnce(SqlitePool) -> Fut,
        Fut: Future<Output = anyhow::Result<T>>,
    {
        self.record_operation(operation, DbAccess::Write, f).await
    }

    /// Runs `f` as a transactional operation, recording its duration and
    /// outcome. Note the closure itself is responsible for `begin`/`commit`.
    pub(super) async fn transaction<T, F, Fut>(
        &self,
        operation: DbOperation,
        f: F,
    ) -> anyhow::Result<T>
    where
        F: FnOnce(SqlitePool) -> Fut,
        Fut: Future<Output = anyhow::Result<T>>,
    {
        self.record_operation(operation, DbAccess::Transaction, f)
            .await
    }

    /// Runs `f` as a maintenance operation (e.g. startup cleanup), recording
    /// its duration and outcome.
    pub(super) async fn maintenance<T, F, Fut>(
        &self,
        operation: DbOperation,
        f: F,
    ) -> anyhow::Result<T>
    where
        F: FnOnce(SqlitePool) -> Fut,
        Fut: Future<Output = anyhow::Result<T>>,
    {
        self.record_operation(operation, DbAccess::Maintenance, f)
            .await
    }

    /// Records one already-timed operation result.
    ///
    /// Exposed separately for call sites that cannot funnel their work through
    /// the closure helpers (e.g. `insert_logs`, which owns its own
    /// transaction) but still want the same telemetry.
    pub(super) fn record_result<T>(
        &self,
        operation: DbOperation,
        access: DbAccess,
        started: Instant,
        result: &anyhow::Result<T>,
    ) {
        crate::telemetry::record_operation_result(
            self.metrics(),
            self.kind,
            operation.as_str(),
            access,
            started.elapsed(),
            result,
        );
    }

    /// Times `f` and reports its result; the shared implementation behind
    /// `read`/`write`/`transaction`/`maintenance`.
    ///
    /// The closure receives an owned clone of the pool so callers can move it
    /// into an `async move` block.
    async fn record_operation<T, F, Fut>(
        &self,
        operation: DbOperation,
        access: DbAccess,
        f: F,
    ) -> anyhow::Result<T>
    where
        F: FnOnce(SqlitePool) -> Fut,
        Fut: Future<Output = anyhow::Result<T>>,
    {
        let started = Instant::now();
        let result = f(self.pool().clone()).await;
        self.record_result(operation, access, started, &result);
        result
    }
}
/// Codex-level database operations reported in SQLite telemetry.
///
/// Each variant maps to a stable snake_case label via [`DbOperation::as_str`].
/// The labels are emitted as metric dimensions, so renaming one is a breaking
/// change for anything aggregating on the existing strings.
//
// `Debug` is derived so operations can appear in logs/assertions; `PartialEq`/
// `Eq` allow direct comparison in tests. All three are additive and keep the
// type `Copy`-cheap.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) enum DbOperation {
    CheckpointBackfill,
    FindRolloutPathById,
    GetBackfillState,
    GetDynamicTools,
    GetThread,
    InsertLogs,
    ListThreads,
    LogsStartupMaintenance,
    MarkBackfillComplete,
    MarkBackfillRunning,
    PersistDynamicTools,
    TouchThreadUpdatedAt,
    TryClaimBackfill,
    UpsertThread,
}

impl DbOperation {
    /// Returns the stable metric label for this operation.
    fn as_str(self) -> &'static str {
        match self {
            Self::CheckpointBackfill => "checkpoint_backfill",
            Self::FindRolloutPathById => "find_rollout_path_by_id",
            Self::GetBackfillState => "get_backfill_state",
            Self::GetDynamicTools => "get_dynamic_tools",
            Self::GetThread => "get_thread",
            Self::InsertLogs => "insert_logs",
            Self::ListThreads => "list_threads",
            Self::LogsStartupMaintenance => "logs_startup_maintenance",
            Self::MarkBackfillComplete => "mark_backfill_complete",
            Self::MarkBackfillRunning => "mark_backfill_running",
            Self::PersistDynamicTools => "persist_dynamic_tools",
            Self::TouchThreadUpdatedAt => "touch_thread_updated_at",
            Self::TryClaimBackfill => "try_claim_backfill",
            Self::UpsertThread => "upsert_thread",
        }
    }
}

View File

@@ -42,7 +42,7 @@ WHERE thread_id = ?
"#,
)
.bind(thread_id.to_string())
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
row.map(|row| thread_goal_from_row(&row)).transpose()
@@ -99,7 +99,7 @@ RETURNING
.bind(token_budget)
.bind(now_ms)
.bind(now_ms)
.fetch_one(self.pool.as_ref())
.fetch_one(self.state_db.pool())
.await?;
thread_goal_from_row(&row)
@@ -148,7 +148,7 @@ RETURNING
.bind(token_budget)
.bind(now_ms)
.bind(now_ms)
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
row.map(|row| thread_goal_from_row(&row)).transpose()
@@ -196,7 +196,7 @@ WHERE thread_id = ?
.bind(thread_id.to_string())
.bind(expected_goal_id)
.bind(expected_goal_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
}
(Some(status), None) => {
@@ -224,7 +224,7 @@ WHERE thread_id = ?
.bind(thread_id.to_string())
.bind(expected_goal_id)
.bind(expected_goal_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
}
(None, Some(token_budget)) => {
@@ -250,7 +250,7 @@ WHERE thread_id = ?
.bind(thread_id.to_string())
.bind(expected_goal_id)
.bind(expected_goal_id)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
}
(None, None) => {
@@ -289,7 +289,7 @@ WHERE thread_id = ?
.bind(crate::ThreadGoalStatus::Paused.as_str())
.bind(now_ms)
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
if result.rows_affected() == 0 {
@@ -307,7 +307,7 @@ WHERE thread_id = ?
"#,
)
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
@@ -392,7 +392,7 @@ RETURNING
query = query.bind(expected_goal_id);
}
let row = query.fetch_optional(self.pool.as_ref()).await?;
let row = query.fetch_optional(self.state_db.pool()).await?;
let Some(row) = row else {
return Ok(ThreadGoalAccountingOutcome::Unchanged(

View File

@@ -9,41 +9,52 @@ impl StateRuntime {
/// Insert a batch of log entries into the logs table.
pub async fn insert_logs(&self, entries: &[LogEntry]) -> anyhow::Result<()> {
if entries.is_empty() {
return Ok(());
}
let started = Instant::now();
let result: anyhow::Result<()> = async {
if entries.is_empty() {
return Ok(());
}
let mut tx = self.logs_pool.begin().await?;
let mut builder = QueryBuilder::<Sqlite>::new(
"INSERT INTO logs (ts, ts_nanos, level, target, feedback_log_body, thread_id, process_uuid, module_path, file, line, estimated_bytes) ",
let mut tx = self.logs_db.pool().begin().await?;
let mut builder = QueryBuilder::<Sqlite>::new(
"INSERT INTO logs (ts, ts_nanos, level, target, feedback_log_body, thread_id, process_uuid, module_path, file, line, estimated_bytes) ",
);
builder.push_values(entries, |mut row, entry| {
let feedback_log_body = entry.feedback_log_body.as_ref().or(entry.message.as_ref());
// Keep about 10 MiB of reader-visible log content per partition.
// Both `query_logs` and `/feedback` read the persisted
// `feedback_log_body`, while `LogEntry.message` is only a write-time
// fallback for callers that still populate the old field.
let estimated_bytes = feedback_log_body.map_or(0, String::len) as i64
+ entry.level.len() as i64
+ entry.target.len() as i64
+ entry.module_path.as_ref().map_or(0, String::len) as i64
+ entry.file.as_ref().map_or(0, String::len) as i64;
row.push_bind(entry.ts)
.push_bind(entry.ts_nanos)
.push_bind(&entry.level)
.push_bind(&entry.target)
.push_bind(feedback_log_body)
.push_bind(&entry.thread_id)
.push_bind(&entry.process_uuid)
.push_bind(&entry.module_path)
.push_bind(&entry.file)
.push_bind(entry.line)
.push_bind(estimated_bytes);
});
builder.build().execute(&mut *tx).await?;
self.prune_logs_after_insert(entries, &mut tx).await?;
tx.commit().await?;
Ok(())
}
.await;
self.logs_db.record_result(
DbOperation::InsertLogs,
DbAccess::Transaction,
started,
&result,
);
builder.push_values(entries, |mut row, entry| {
let feedback_log_body = entry.feedback_log_body.as_ref().or(entry.message.as_ref());
// Keep about 10 MiB of reader-visible log content per partition.
// Both `query_logs` and `/feedback` read the persisted
// `feedback_log_body`, while `LogEntry.message` is only a write-time
// fallback for callers that still populate the old field.
let estimated_bytes = feedback_log_body.map_or(0, String::len) as i64
+ entry.level.len() as i64
+ entry.target.len() as i64
+ entry.module_path.as_ref().map_or(0, String::len) as i64
+ entry.file.as_ref().map_or(0, String::len) as i64;
row.push_bind(entry.ts)
.push_bind(entry.ts_nanos)
.push_bind(&entry.level)
.push_bind(&entry.target)
.push_bind(feedback_log_body)
.push_bind(&entry.thread_id)
.push_bind(&entry.process_uuid)
.push_bind(&entry.module_path)
.push_bind(&entry.file)
.push_bind(entry.line)
.push_bind(estimated_bytes);
});
builder.build().execute(&mut *tx).await?;
self.prune_logs_after_insert(entries, &mut tx).await?;
tx.commit().await?;
Ok(())
result
}
/// Enforce per-partition retained-log-content caps after a successful batch insert.
@@ -285,28 +296,27 @@ WHERE id IN (
Ok(())
}
pub(crate) async fn delete_logs_before(&self, cutoff_ts: i64) -> anyhow::Result<u64> {
let result = sqlx::query("DELETE FROM logs WHERE ts < ?")
.bind(cutoff_ts)
.execute(self.logs_pool.as_ref())
.await?;
Ok(result.rows_affected())
}
pub(crate) async fn run_logs_startup_maintenance(&self) -> anyhow::Result<()> {
let Some(cutoff) =
Utc::now().checked_sub_signed(chrono::Duration::days(LOG_RETENTION_DAYS))
else {
return Ok(());
};
self.delete_logs_before(cutoff.timestamp()).await?;
// Startup cleanup should not wait behind or block foreground work.
// PASSIVE checkpoints copy whatever is immediately available and skip
// frames that would require waiting on active readers or writers.
sqlx::query("PRAGMA wal_checkpoint(PASSIVE)")
.execute(self.logs_pool.as_ref())
.await?;
Ok(())
self.logs_db
.maintenance(DbOperation::LogsStartupMaintenance, |pool| async move {
let Some(cutoff) =
Utc::now().checked_sub_signed(chrono::Duration::days(LOG_RETENTION_DAYS))
else {
return Ok(());
};
sqlx::query("DELETE FROM logs WHERE ts < ?")
.bind(cutoff.timestamp())
.execute(&pool)
.await?;
// Startup cleanup should not wait behind or block foreground work.
// PASSIVE checkpoints copy whatever is immediately available and skip
// frames that would require waiting on active readers or writers.
sqlx::query("PRAGMA wal_checkpoint(PASSIVE)")
.execute(&pool)
.await?;
Ok(())
})
.await
}
/// Query logs with optional filters.
@@ -326,7 +336,7 @@ WHERE id IN (
let rows = builder
.build_query_as::<LogRow>()
.fetch_all(self.logs_pool.as_ref())
.fetch_all(self.logs_db.pool())
.await?;
Ok(rows)
}
@@ -398,7 +408,7 @@ ORDER BY ts DESC, ts_nanos DESC, id DESC
}
let rows = sql
.bind(LOG_PARTITION_SIZE_LIMIT_BYTES)
.fetch_all(self.logs_pool.as_ref())
.fetch_all(self.logs_db.pool())
.await?;
let mut lines = Vec::new();
@@ -431,7 +441,7 @@ ORDER BY ts DESC, ts_nanos DESC, id DESC
let mut builder =
QueryBuilder::<Sqlite>::new("SELECT MAX(id) AS max_id FROM logs WHERE 1 = 1");
push_log_filters(&mut builder, query);
let row = builder.build().fetch_one(self.logs_pool.as_ref()).await?;
let row = builder.build().fetch_one(self.logs_db.pool()).await?;
let max_id: Option<i64> = row.try_get("max_id")?;
Ok(max_id.unwrap_or(0))
}

View File

@@ -30,7 +30,7 @@ impl StateRuntime {
/// stage-1 (`memory_stage1`) and phase-2 (`memory_consolidate_global`)
/// memory pipelines.
pub async fn clear_memory_data(&self) -> anyhow::Result<()> {
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
sqlx::query(
r#"
@@ -68,7 +68,7 @@ WHERE kind = ? OR kind = ?
}
let now = Utc::now().timestamp();
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
let mut updated_rows = 0;
for thread_id in thread_ids {
@@ -209,7 +209,7 @@ LEFT JOIN jobs
let items = builder
.build()
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?
.into_iter()
.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
@@ -279,7 +279,7 @@ LIMIT ?
"#,
)
.bind(n as i64)
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?;
rows.into_iter()
@@ -323,7 +323,7 @@ WHERE thread_id IN (
)
.bind(cutoff)
.bind(limit as i64)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
.rows_affected();
@@ -400,7 +400,7 @@ ORDER BY selected.thread_id ASC
.bind(cutoff)
.bind(cutoff)
.bind(n as i64)
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?;
let mut selected = Vec::with_capacity(current_rows.len());
@@ -421,7 +421,7 @@ ORDER BY selected.thread_id ASC
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let thread_id = thread_id.to_string();
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
let rows_affected = sqlx::query(
r#"
UPDATE threads
@@ -489,7 +489,7 @@ WHERE thread_id = ?
let thread_id = thread_id.to_string();
let worker_id = worker_id.to_string();
let mut tx = self.pool.begin_with("BEGIN IMMEDIATE").await?;
let mut tx = self.state_db.pool().begin_with("BEGIN IMMEDIATE").await?;
let existing_output = sqlx::query(
r#"
@@ -673,7 +673,7 @@ WHERE kind = ? AND job_key = ?
let now = Utc::now().timestamp();
let thread_id = thread_id.to_string();
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
let rows_affected = sqlx::query(
r#"
UPDATE jobs
@@ -750,7 +750,7 @@ WHERE excluded.source_updated_at >= stage1_outputs.source_updated_at
let now = Utc::now().timestamp();
let thread_id = thread_id.to_string();
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
let rows_affected = sqlx::query(
r#"
UPDATE jobs
@@ -848,7 +848,7 @@ WHERE kind = ? AND job_key = ?
.bind(JOB_KIND_MEMORY_STAGE1)
.bind(thread_id.as_str())
.bind(ownership_token)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
.rows_affected();
@@ -863,7 +863,7 @@ WHERE kind = ? AND job_key = ?
/// Phase 2 does not use this watermark as a dirty check; git workspace diffing
/// decides whether consolidation work exists after the lock is claimed.
pub async fn enqueue_global_consolidation(&self, input_watermark: i64) -> anyhow::Result<()> {
enqueue_global_consolidation_with_executor(self.pool.as_ref(), input_watermark).await
enqueue_global_consolidation_with_executor(self.state_db.pool(), input_watermark).await
}
/// Attempts to claim the global phase-2 consolidation lock.
@@ -890,7 +890,7 @@ WHERE kind = ? AND job_key = ?
let ownership_token = Uuid::new_v4().to_string();
let worker_id = worker_id.to_string();
let mut tx = self.pool.begin_with("BEGIN IMMEDIATE").await?;
let mut tx = self.state_db.pool().begin_with("BEGIN IMMEDIATE").await?;
let existing_job = sqlx::query(
r#"
@@ -1035,7 +1035,7 @@ WHERE kind = ? AND job_key = ?
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.bind(MEMORY_CONSOLIDATION_JOB_KEY)
.bind(ownership_token)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
.rows_affected();
@@ -1058,7 +1058,7 @@ WHERE kind = ? AND job_key = ?
completed_watermark: i64,
selected_outputs: &[Stage1Output],
) -> anyhow::Result<bool> {
let mut tx = self.pool.begin().await?;
let mut tx = self.state_db.pool().begin().await?;
let rows_affected =
mark_global_phase2_job_succeeded_row(&mut *tx, ownership_token, completed_watermark)
.await?;
@@ -1136,7 +1136,7 @@ WHERE kind = ? AND job_key = ?
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.bind(MEMORY_CONSOLIDATION_JOB_KEY)
.bind(ownership_token)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
.rows_affected();
@@ -1178,7 +1178,7 @@ WHERE kind = ? AND job_key = ?
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.bind(MEMORY_CONSOLIDATION_JOB_KEY)
.bind(ownership_token)
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?
.rows_affected();
@@ -1300,7 +1300,7 @@ mod tests {
.bind(Utc::now().timestamp() - PHASE2_SUCCESS_COOLDOWN_SECONDS - 1)
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.bind(MEMORY_CONSOLIDATION_JOB_KEY)
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("age phase2 success beyond cooldown");
}
@@ -1409,7 +1409,7 @@ mod tests {
sqlx::query("UPDATE jobs SET lease_until = 0 WHERE kind = 'memory_stage1' AND job_key = ?")
.bind(thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("force stale lease");
@@ -1789,7 +1789,7 @@ mod tests {
.expect("upsert disabled thread");
sqlx::query("UPDATE threads SET memory_mode = 'disabled' WHERE id = ?")
.bind(disabled_thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("disable thread memory mode");
@@ -1890,7 +1890,7 @@ mod tests {
.expect("upsert disabled thread");
sqlx::query("UPDATE threads SET memory_mode = 'disabled' WHERE id = ?")
.bind(disabled_thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("disable existing thread");
@@ -1900,7 +1900,7 @@ mod tests {
.expect("clear memory data");
let stage1_outputs_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count stage1 outputs");
assert_eq!(stage1_outputs_count, 0);
@@ -1909,7 +1909,7 @@ mod tests {
sqlx::query_scalar("SELECT COUNT(*) FROM jobs WHERE kind = ? OR kind = ?")
.bind(JOB_KIND_MEMORY_STAGE1)
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count memory jobs");
assert_eq!(memory_jobs_count, 0);
@@ -1917,7 +1917,7 @@ mod tests {
let enabled_memory_mode: String =
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
.bind(enabled_thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("read enabled thread memory mode");
assert_eq!(enabled_memory_mode, "enabled");
@@ -1925,7 +1925,7 @@ mod tests {
let disabled_memory_mode: String =
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
.bind(disabled_thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("read disabled thread memory mode");
assert_eq!(disabled_memory_mode, "disabled");
@@ -2000,7 +2000,7 @@ INSERT INTO jobs (
.bind(lease_until)
.bind(3)
.bind(metadata.updated_at.timestamp())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("seed running stage1 job");
}
@@ -2034,7 +2034,7 @@ WHERE kind = 'memory_stage1'
"#,
)
.bind(Utc::now().timestamp())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count running stage1 jobs")
.try_get::<i64, _>("count")
@@ -2191,7 +2191,7 @@ WHERE kind = 'memory_stage1'
let count_before =
sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count before delete")
.try_get::<i64, _>("count")
@@ -2200,14 +2200,14 @@ WHERE kind = 'memory_stage1'
sqlx::query("DELETE FROM threads WHERE id = ?")
.bind(thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("delete thread");
let count_after =
sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count after delete")
.try_get::<i64, _>("count")
@@ -2258,7 +2258,7 @@ WHERE kind = 'memory_stage1'
let output_row_count =
sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load stage1 output count")
.try_get::<i64, _>("count")
@@ -2279,7 +2279,7 @@ WHERE kind = 'memory_stage1'
let global_job_row_count = sqlx::query("SELECT COUNT(*) AS count FROM jobs WHERE kind = ?")
.bind("memory_consolidate_global")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load phase2 job row count")
.try_get::<i64, _>("count")
@@ -2383,7 +2383,7 @@ WHERE kind = 'memory_stage1'
let output_row_count =
sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load stage1 output count after delete")
.try_get::<i64, _>("count")
@@ -2494,7 +2494,7 @@ WHERE kind = 'memory_stage1'
)
.bind("memory_stage1")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load stage1 job row after newer-source claim");
assert_eq!(
@@ -2620,7 +2620,7 @@ WHERE kind = 'memory_stage1'
sqlx::query("SELECT retry_remaining FROM jobs WHERE kind = ? AND job_key = ?")
.bind("memory_consolidate_global")
.bind("global")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load phase2 job row after retry exhaustion");
assert_eq!(
@@ -2787,7 +2787,7 @@ VALUES (?, ?, ?, ?, ?)
.bind("raw memory")
.bind("summary")
.bind(100_i64)
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("insert non-empty stage1 output");
sqlx::query(
@@ -2801,7 +2801,7 @@ VALUES (?, ?, ?, ?, ?)
.bind("")
.bind("")
.bind(101_i64)
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("insert empty stage1 output");
@@ -3292,7 +3292,7 @@ VALUES (?, ?, ?, ?, ?)
"SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
)
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load selected_for_phase2");
assert_eq!(selected_for_phase2, 1);
@@ -3585,7 +3585,7 @@ VALUES (?, ?, ?, ?, ?)
"SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
)
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load selected snapshot after phase2");
assert_eq!(selected_for_phase2, 1);
@@ -3698,7 +3698,7 @@ VALUES (?, ?, ?, ?, ?)
"SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
)
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load selected_for_phase2");
assert_eq!(selected_for_phase2, 0);
@@ -3802,13 +3802,13 @@ VALUES (?, ?, ?, ?, ?)
let row_a =
sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_a.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load stage1 usage row a");
let row_b =
sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_b.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("load stage1 usage row b");
@@ -3908,7 +3908,7 @@ VALUES (?, ?, ?, ?, ?)
.bind(usage_count)
.bind(last_usage.timestamp())
.bind(thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("update usage metadata");
}
@@ -4004,7 +4004,7 @@ VALUES (?, ?, ?, ?, ?)
.bind(usage_count)
.bind(last_usage.map(|value| value.timestamp()))
.bind(thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("update usage metadata");
}
@@ -4089,13 +4089,13 @@ VALUES (?, ?, ?, ?, ?)
sqlx::query("UPDATE stage1_outputs SET generated_at = ? WHERE thread_id = ?")
.bind(300_i64)
.bind(older_thread.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("update older generated_at");
sqlx::query("UPDATE stage1_outputs SET generated_at = ? WHERE thread_id = ?")
.bind(150_i64)
.bind(newer_thread.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("update newer generated_at");
@@ -4201,14 +4201,14 @@ VALUES (?, ?, ?, ?, ?)
.bind(3_i64)
.bind(now - Duration::days(40).num_seconds())
.bind(stale_used.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("set stale used metadata");
sqlx::query(
"UPDATE stage1_outputs SET selected_for_phase2 = 1, selected_for_phase2_source_updated_at = source_updated_at WHERE thread_id = ?",
)
.bind(stale_selected.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("mark selected for phase2");
sqlx::query(
@@ -4217,13 +4217,13 @@ VALUES (?, ?, ?, ?, ?)
.bind(8_i64)
.bind(now - Duration::days(2).num_seconds())
.bind(fresh_used.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("set fresh used metadata");
let before_jobs_count =
sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM jobs WHERE kind = 'memory_stage1'")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count stage1 jobs before prune");
@@ -4236,7 +4236,7 @@ VALUES (?, ?, ?, ?, ?)
let remaining = sqlx::query_scalar::<_, String>(
"SELECT thread_id FROM stage1_outputs ORDER BY thread_id",
)
.fetch_all(runtime.pool.as_ref())
.fetch_all(runtime.state_db.pool())
.await
.expect("load remaining stage1 outputs");
let mut expected_remaining = vec![fresh_used.to_string(), stale_selected.to_string()];
@@ -4245,7 +4245,7 @@ VALUES (?, ?, ?, ?, ?)
let after_jobs_count =
sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM jobs WHERE kind = 'memory_stage1'")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count stage1 jobs after prune");
assert_eq!(after_jobs_count, before_jobs_count);
@@ -4323,7 +4323,7 @@ VALUES (?, ?, ?, ?, ?)
assert_eq!(pruned, 2);
let remaining_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs")
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("count remaining stage1 outputs");
assert_eq!(remaining_count, 1);
@@ -4539,7 +4539,7 @@ VALUES (?, ?, ?, ?, ?)
.bind(Utc::now().timestamp() - 1)
.bind("memory_consolidate_global")
.bind("global")
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("expire global consolidation lease");
@@ -4675,7 +4675,7 @@ VALUES (?, ?, ?, ?, ?)
sqlx::query("UPDATE jobs SET ownership_token = NULL WHERE kind = ? AND job_key = ?")
.bind("memory_consolidate_global")
.bind("global")
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("clear ownership token");

View File

@@ -44,7 +44,7 @@ WHERE websocket_url = ? AND account_id = ? AND app_server_client_name = ?
.bind(remote_control_app_server_client_name_key(
app_server_client_name,
))
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
row.map(|row| {
@@ -92,7 +92,7 @@ ON CONFLICT(websocket_url, account_id, app_server_client_name) DO UPDATE SET
.bind(&enrollment.environment_id)
.bind(&enrollment.server_name)
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -114,7 +114,7 @@ WHERE websocket_url = ? AND account_id = ? AND app_server_client_name = ?
.bind(remote_control_app_server_client_name_key(
app_server_client_name,
))
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected())
}

View File

@@ -5,8 +5,10 @@ use std::sync::atomic::Ordering;
impl StateRuntime {
pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result<Option<crate::ThreadMetadata>> {
let row = sqlx::query(
r#"
self.state_db
.read(DbOperation::GetThread, |pool| async move {
let row = sqlx::query(
r#"
SELECT
threads.id,
threads.rollout_path,
@@ -34,18 +36,20 @@ SELECT
FROM threads
WHERE threads.id = ?
"#,
)
.bind(id.to_string())
.fetch_optional(self.pool.as_ref())
.await?;
row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
.transpose()
)
.bind(id.to_string())
.fetch_optional(&pool)
.await?;
row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
.transpose()
})
.await
}
pub async fn get_thread_memory_mode(&self, id: ThreadId) -> anyhow::Result<Option<String>> {
let row = sqlx::query("SELECT memory_mode FROM threads WHERE id = ?")
.bind(id.to_string())
.fetch_optional(self.pool.as_ref())
.fetch_optional(self.state_db.pool())
.await?;
Ok(row.and_then(|row| row.try_get("memory_mode").ok()))
}
@@ -55,33 +59,37 @@ WHERE threads.id = ?
&self,
thread_id: ThreadId,
) -> anyhow::Result<Option<Vec<DynamicToolSpec>>> {
let rows = sqlx::query(
r#"
self.state_db
.read(DbOperation::GetDynamicTools, |pool| async move {
let rows = sqlx::query(
r#"
SELECT namespace, name, description, input_schema, defer_loading
FROM thread_dynamic_tools
WHERE thread_id = ?
ORDER BY position ASC
"#,
)
.bind(thread_id.to_string())
.fetch_all(self.pool.as_ref())
.await?;
if rows.is_empty() {
return Ok(None);
}
let mut tools = Vec::with_capacity(rows.len());
for row in rows {
let input_schema: String = row.try_get("input_schema")?;
let input_schema = serde_json::from_str::<Value>(input_schema.as_str())?;
tools.push(DynamicToolSpec {
namespace: row.try_get("namespace")?,
name: row.try_get("name")?,
description: row.try_get("description")?,
input_schema,
defer_loading: row.try_get("defer_loading")?,
});
}
Ok(Some(tools))
)
.bind(thread_id.to_string())
.fetch_all(&pool)
.await?;
if rows.is_empty() {
return Ok(None);
}
let mut tools = Vec::with_capacity(rows.len());
for row in rows {
let input_schema: String = row.try_get("input_schema")?;
let input_schema = serde_json::from_str::<Value>(input_schema.as_str())?;
tools.push(DynamicToolSpec {
namespace: row.try_get("namespace")?,
name: row.try_get("name")?,
description: row.try_get("description")?,
input_schema,
defer_loading: row.try_get("defer_loading")?,
});
}
Ok(Some(tools))
})
.await
}
/// Persist or replace the directional parent-child edge for a spawned thread.
@@ -106,7 +114,7 @@ ON CONFLICT(child_thread_id) DO UPDATE SET
.bind(parent_thread_id.to_string())
.bind(child_thread_id.to_string())
.bind(status.as_ref())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -120,7 +128,7 @@ ON CONFLICT(child_thread_id) DO UPDATE SET
sqlx::query("UPDATE thread_spawn_edges SET status = ? WHERE child_thread_id = ?")
.bind(status.as_ref())
.bind(child_thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -186,7 +194,7 @@ LIMIT 2
)
.bind(parent_thread_id.to_string())
.bind(agent_path)
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?;
one_thread_id_from_rows(rows, agent_path)
}
@@ -218,7 +226,7 @@ LIMIT 2
)
.bind(root_thread_id.to_string())
.bind(agent_path)
.fetch_all(self.pool.as_ref())
.fetch_all(self.state_db.pool())
.await?;
one_thread_id_from_rows(rows, agent_path)
}
@@ -241,7 +249,7 @@ LIMIT 2
sql = sql.bind(status.to_string());
}
let rows = sql.fetch_all(self.pool.as_ref()).await?;
let rows = sql.fetch_all(self.state_db.pool()).await?;
rows.into_iter()
.map(|row| {
ThreadId::try_from(row.try_get::<String, _>("child_thread_id")?).map_err(Into::into)
@@ -283,7 +291,7 @@ ORDER BY depth ASC, child_thread_id ASC
sql = sql.bind(status.clone()).bind(status);
}
let rows = sql.fetch_all(self.pool.as_ref()).await?;
let rows = sql.fetch_all(self.state_db.pool()).await?;
rows.into_iter()
.map(|row| {
ThreadId::try_from(row.try_get::<String, _>("child_thread_id")?).map_err(Into::into)
@@ -309,7 +317,7 @@ ON CONFLICT(child_thread_id) DO NOTHING
.bind(parent_thread_id.to_string())
.bind(child_thread_id.to_string())
.bind(crate::DirectionalThreadSpawnEdgeStatus::Open.as_ref())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(())
}
@@ -332,22 +340,26 @@ ON CONFLICT(child_thread_id) DO NOTHING
id: ThreadId,
archived_only: Option<bool>,
) -> anyhow::Result<Option<PathBuf>> {
let mut builder =
QueryBuilder::<Sqlite>::new("SELECT rollout_path FROM threads WHERE id = ");
builder.push_bind(id.to_string());
match archived_only {
Some(true) => {
builder.push(" AND archived = 1");
}
Some(false) => {
builder.push(" AND archived = 0");
}
None => {}
}
let row = builder.build().fetch_optional(self.pool.as_ref()).await?;
Ok(row
.and_then(|r| r.try_get::<String, _>("rollout_path").ok())
.map(PathBuf::from))
self.state_db
.read(DbOperation::FindRolloutPathById, |pool| async move {
let mut builder =
QueryBuilder::<Sqlite>::new("SELECT rollout_path FROM threads WHERE id = ");
builder.push_bind(id.to_string());
match archived_only {
Some(true) => {
builder.push(" AND archived = 1");
}
Some(false) => {
builder.push(" AND archived = 0");
}
None => {}
}
let row = builder.build().fetch_optional(&pool).await?;
Ok(row
.and_then(|r| r.try_get::<String, _>("rollout_path").ok())
.map(PathBuf::from))
})
.await
}
/// Find the newest thread whose user-facing title exactly matches `title`.
@@ -389,7 +401,7 @@ ON CONFLICT(child_thread_id) DO NOTHING
/*limit*/ 1,
);
let row = builder.build().fetch_optional(self.pool.as_ref()).await?;
let row = builder.build().fetch_optional(self.state_db.pool()).await?;
row.map(|row| ThreadRow::try_from_row(&row).and_then(crate::ThreadMetadata::try_from))
.transpose()
}
@@ -400,35 +412,39 @@ ON CONFLICT(child_thread_id) DO NOTHING
page_size: usize,
filters: ThreadFilterOptions<'_>,
) -> anyhow::Result<crate::ThreadsPage> {
let limit = page_size.saturating_add(1);
let sort_key = filters.sort_key;
let sort_direction = filters.sort_direction;
self.state_db
.read(DbOperation::ListThreads, |pool| async move {
let limit = page_size.saturating_add(1);
let sort_key = filters.sort_key;
let sort_direction = filters.sort_direction;
let mut builder = QueryBuilder::<Sqlite>::new("");
push_thread_select_columns(&mut builder);
builder.push(" FROM threads");
push_thread_filters(&mut builder, filters);
push_thread_order_and_limit(&mut builder, sort_key, sort_direction, limit);
let mut builder = QueryBuilder::<Sqlite>::new("");
push_thread_select_columns(&mut builder);
builder.push(" FROM threads");
push_thread_filters(&mut builder, filters);
push_thread_order_and_limit(&mut builder, sort_key, sort_direction, limit);
let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
let mut items = rows
.into_iter()
.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
.collect::<Result<Vec<_>, _>>()?;
let num_scanned_rows = items.len();
let next_anchor = if items.len() > page_size {
items.pop();
items
.last()
.and_then(|item| anchor_from_item(item, sort_key))
} else {
None
};
Ok(ThreadsPage {
items,
next_anchor,
num_scanned_rows,
})
let rows = builder.build().fetch_all(&pool).await?;
let mut items = rows
.into_iter()
.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
.collect::<Result<Vec<_>, _>>()?;
let num_scanned_rows = items.len();
let next_anchor = if items.len() > page_size {
items.pop();
items
.last()
.and_then(|item| anchor_from_item(item, sort_key))
} else {
None
};
Ok(ThreadsPage {
items,
next_anchor,
num_scanned_rows,
})
})
.await
}
/// List thread ids using the underlying database (no rollout scanning).
@@ -457,7 +473,7 @@ ON CONFLICT(child_thread_id) DO NOTHING
);
push_thread_order_and_limit(&mut builder, sort_key, SortDirection::Desc, limit);
let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
let rows = builder.build().fetch_all(self.state_db.pool()).await?;
rows.into_iter()
.map(|row| {
let id: String = row.try_get("id")?;
@@ -547,7 +563,7 @@ ON CONFLICT(id) DO NOTHING
.bind(metadata.git_branch.as_deref())
.bind(metadata.git_origin_url.as_deref())
.bind("enabled")
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
self.insert_thread_spawn_edge_from_source_if_absent(metadata.id, metadata.source.as_str())
.await?;
@@ -562,7 +578,7 @@ ON CONFLICT(id) DO NOTHING
let result = sqlx::query("UPDATE threads SET memory_mode = ? WHERE id = ?")
.bind(memory_mode)
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -575,7 +591,7 @@ ON CONFLICT(id) DO NOTHING
let result = sqlx::query("UPDATE threads SET title = ? WHERE id = ?")
.bind(title)
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -586,14 +602,19 @@ ON CONFLICT(id) DO NOTHING
updated_at: DateTime<Utc>,
) -> anyhow::Result<bool> {
let updated_at = self.allocate_thread_updated_at(updated_at)?;
let result =
sqlx::query("UPDATE threads SET updated_at = ?, updated_at_ms = ? WHERE id = ?")
self.state_db
.write(DbOperation::TouchThreadUpdatedAt, |pool| async move {
let result = sqlx::query(
"UPDATE threads SET updated_at = ?, updated_at_ms = ? WHERE id = ?",
)
.bind(datetime_to_epoch_seconds(updated_at))
.bind(datetime_to_epoch_millis(updated_at))
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(&pool)
.await?;
Ok(result.rows_affected() > 0)
Ok(result.rows_affected() > 0)
})
.await
}
/// Allocate a persisted `updated_at` value for thread-list cursor ordering.
@@ -666,7 +687,7 @@ WHERE id = ?
.bind(git_origin_url.is_some())
.bind(git_origin_url.flatten())
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -676,12 +697,14 @@ WHERE id = ?
metadata: &crate::ThreadMetadata,
creation_memory_mode: Option<&str>,
) -> anyhow::Result<()> {
let updated_at = self.allocate_thread_updated_at(metadata.updated_at)?;
// Backfill/reconcile callers merge existing git info before upserting, but that
// read/modify/write is not atomic. Preserve non-null SQLite git fields here so
// an explicit metadata update cannot be lost if a stale rollout upsert lands later.
sqlx::query(
r#"
let started = Instant::now();
let result: anyhow::Result<()> = async {
let updated_at = self.allocate_thread_updated_at(metadata.updated_at)?;
// Backfill/reconcile callers merge existing git info before upserting, but that
// read/modify/write is not atomic. Preserve non-null SQLite git fields here so
// an explicit metadata update cannot be lost if a stale rollout upsert lands later.
sqlx::query(
r#"
INSERT INTO threads (
id,
rollout_path,
@@ -738,48 +761,56 @@ ON CONFLICT(id) DO UPDATE SET
git_branch = COALESCE(threads.git_branch, excluded.git_branch),
git_origin_url = COALESCE(threads.git_origin_url, excluded.git_origin_url)
"#,
)
.bind(metadata.id.to_string())
.bind(metadata.rollout_path.display().to_string())
.bind(datetime_to_epoch_seconds(metadata.created_at))
.bind(datetime_to_epoch_seconds(updated_at))
.bind(datetime_to_epoch_millis(metadata.created_at))
.bind(datetime_to_epoch_millis(updated_at))
.bind(metadata.source.as_str())
.bind(
metadata
.thread_source
.map(codex_protocol::protocol::ThreadSource::as_str),
)
.bind(metadata.agent_nickname.as_deref())
.bind(metadata.agent_role.as_deref())
.bind(metadata.agent_path.as_deref())
.bind(metadata.model_provider.as_str())
.bind(metadata.model.as_deref())
.bind(
metadata
.reasoning_effort
.as_ref()
.map(crate::extract::enum_to_string),
)
.bind(metadata.cwd.display().to_string())
.bind(metadata.cli_version.as_str())
.bind(metadata.title.as_str())
.bind(metadata.sandbox_policy.as_str())
.bind(metadata.approval_mode.as_str())
.bind(metadata.tokens_used)
.bind(metadata.first_user_message.as_deref().unwrap_or_default())
.bind(metadata.archived_at.is_some())
.bind(metadata.archived_at.map(datetime_to_epoch_seconds))
.bind(metadata.git_sha.as_deref())
.bind(metadata.git_branch.as_deref())
.bind(metadata.git_origin_url.as_deref())
.bind(creation_memory_mode.unwrap_or("enabled"))
.execute(self.pool.as_ref())
.await?;
self.insert_thread_spawn_edge_from_source_if_absent(metadata.id, metadata.source.as_str())
)
.bind(metadata.id.to_string())
.bind(metadata.rollout_path.display().to_string())
.bind(datetime_to_epoch_seconds(metadata.created_at))
.bind(datetime_to_epoch_seconds(updated_at))
.bind(datetime_to_epoch_millis(metadata.created_at))
.bind(datetime_to_epoch_millis(updated_at))
.bind(metadata.source.as_str())
.bind(
metadata
.thread_source
.map(codex_protocol::protocol::ThreadSource::as_str),
)
.bind(metadata.agent_nickname.as_deref())
.bind(metadata.agent_role.as_deref())
.bind(metadata.agent_path.as_deref())
.bind(metadata.model_provider.as_str())
.bind(metadata.model.as_deref())
.bind(
metadata
.reasoning_effort
.as_ref()
.map(crate::extract::enum_to_string),
)
.bind(metadata.cwd.display().to_string())
.bind(metadata.cli_version.as_str())
.bind(metadata.title.as_str())
.bind(metadata.sandbox_policy.as_str())
.bind(metadata.approval_mode.as_str())
.bind(metadata.tokens_used)
.bind(metadata.first_user_message.as_deref().unwrap_or_default())
.bind(metadata.archived_at.is_some())
.bind(metadata.archived_at.map(datetime_to_epoch_seconds))
.bind(metadata.git_sha.as_deref())
.bind(metadata.git_branch.as_deref())
.bind(metadata.git_origin_url.as_deref())
.bind(creation_memory_mode.unwrap_or("enabled"))
.execute(self.state_db.pool())
.await?;
Ok(())
self.insert_thread_spawn_edge_from_source_if_absent(
metadata.id,
metadata.source.as_str(),
)
.await?;
Ok(())
}
.await;
self.state_db
.record_result(DbOperation::UpsertThread, DbAccess::Write, started, &result);
result
}
/// Persist dynamic tools for a thread if none have been stored yet.
@@ -791,19 +822,21 @@ ON CONFLICT(id) DO UPDATE SET
thread_id: ThreadId,
tools: Option<&[DynamicToolSpec]>,
) -> anyhow::Result<()> {
let Some(tools) = tools else {
return Ok(());
};
if tools.is_empty() {
return Ok(());
}
let thread_id = thread_id.to_string();
let mut tx = self.pool.begin().await?;
for (idx, tool) in tools.iter().enumerate() {
let position = i64::try_from(idx).unwrap_or(i64::MAX);
let input_schema = serde_json::to_string(&tool.input_schema)?;
sqlx::query(
r#"
self.state_db
.transaction(DbOperation::PersistDynamicTools, |pool| async move {
let Some(tools) = tools else {
return Ok(());
};
if tools.is_empty() {
return Ok(());
}
let thread_id = thread_id.to_string();
let mut tx = pool.begin().await?;
for (idx, tool) in tools.iter().enumerate() {
let position = i64::try_from(idx).unwrap_or(i64::MAX);
let input_schema = serde_json::to_string(&tool.input_schema)?;
sqlx::query(
r#"
INSERT INTO thread_dynamic_tools (
thread_id,
position,
@@ -815,19 +848,21 @@ INSERT INTO thread_dynamic_tools (
) VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(thread_id, position) DO NOTHING
"#,
)
.bind(thread_id.as_str())
.bind(position)
.bind(tool.namespace.as_deref())
.bind(tool.name.as_str())
.bind(tool.description.as_str())
.bind(input_schema)
.bind(tool.defer_loading)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
)
.bind(thread_id.as_str())
.bind(position)
.bind(tool.namespace.as_deref())
.bind(tool.name.as_str())
.bind(tool.description.as_str())
.bind(input_schema)
.bind(tool.defer_loading)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
})
.await
}
/// Apply rollout items incrementally using the underlying database.
@@ -937,7 +972,7 @@ ON CONFLICT(thread_id, position) DO NOTHING
pub async fn delete_thread(&self, thread_id: ThreadId) -> anyhow::Result<u64> {
let result = sqlx::query("DELETE FROM threads WHERE id = ?")
.bind(thread_id.to_string())
.execute(self.pool.as_ref())
.execute(self.state_db.pool())
.await?;
Ok(result.rows_affected())
}
@@ -1171,7 +1206,7 @@ mod tests {
let memory_mode: String =
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("memory mode should be readable");
assert_eq!(memory_mode, "disabled");
@@ -1185,7 +1220,7 @@ mod tests {
let memory_mode: String =
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("memory mode should remain readable");
assert_eq!(memory_mode, "disabled");
@@ -1539,7 +1574,7 @@ mod tests {
.bind(123_i64)
.bind("newer preview")
.bind(thread_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("concurrent metadata write should succeed");
@@ -1739,7 +1774,7 @@ mod tests {
"SELECT created_at, updated_at, created_at_ms, updated_at_ms FROM threads WHERE id = ?",
)
.bind(second_id.to_string())
.fetch_one(runtime.pool.as_ref())
.fetch_one(runtime.state_db.pool())
.await
.expect("thread timestamp row should load");
assert_eq!(
@@ -1773,7 +1808,7 @@ mod tests {
sqlx::query("UPDATE threads SET updated_at = ? WHERE id = ?")
.bind(1_700_001_112_i64)
.bind(first_id.to_string())
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("legacy timestamp write should succeed");
let legacy = runtime
@@ -1985,7 +2020,7 @@ INSERT INTO thread_spawn_edges (
.bind(parent_thread_id.to_string())
.bind(future_child_thread_id.to_string())
.bind("future")
.execute(runtime.pool.as_ref())
.execute(runtime.state_db.pool())
.await
.expect("future-status child edge insert should succeed");

View File

@@ -0,0 +1,342 @@
use std::borrow::Cow;
use std::sync::Arc;
use std::time::Duration;
use crate::DB_FALLBACK_METRIC;
use crate::DB_INIT_DURATION_METRIC;
use crate::DB_INIT_METRIC;
use crate::DB_LOG_QUEUE_METRIC;
use crate::DB_OPERATION_DURATION_METRIC;
use crate::DB_OPERATION_METRIC;
/// Low-cardinality metrics sink used by the SQLite state runtime.
///
/// Implementations should ignore recording errors locally. Database operations
/// must never fail because telemetry delivery failed.
///
/// Tags are passed as borrowed slices; implementations must copy anything they
/// intend to keep beyond the call.
pub trait DbMetricsRecorder: Send + Sync + 'static {
    /// Increment a counter metric by `inc` with low-cardinality tags.
    fn counter(&self, name: &str, inc: i64, tags: &[(&str, &str)]);

    /// Record an elapsed duration metric with low-cardinality tags.
    fn record_duration(&self, name: &str, duration: Duration, tags: &[(&str, &str)]);
}
/// Shared recorder handle stored by `StateRuntime` and cloned by log layers.
///
/// Trait object so any `DbMetricsRecorder` backend can be plugged in and
/// shared cheaply across owners via `Arc`.
pub type DbMetricsRecorderHandle = Arc<dyn DbMetricsRecorder>;
/// Which database a metric refers to; rendered as the low-cardinality `db` tag.
#[derive(Clone, Copy)]
pub(crate) enum DbKind {
    State,
    Logs,
}

impl DbKind {
    /// Stable string value emitted for the `db` tag.
    fn as_str(self) -> &'static str {
        match self {
            DbKind::State => "state",
            DbKind::Logs => "logs",
        }
    }
}
/// Access pattern of a database call; rendered as the `access` tag.
#[derive(Clone, Copy)]
pub(crate) enum DbAccess {
    Read,
    Write,
    Transaction,
    Maintenance,
}

impl DbAccess {
    /// Stable string value emitted for the `access` tag.
    fn as_str(self) -> &'static str {
        match self {
            DbAccess::Read => "read",
            DbAccess::Write => "write",
            DbAccess::Transaction => "transaction",
            DbAccess::Maintenance => "maintenance",
        }
    }
}
pub(crate) fn record_init_result<T>(
metrics: Option<&dyn DbMetricsRecorder>,
db: DbKind,
phase: &'static str,
duration: Duration,
result: &anyhow::Result<T>,
) {
let outcome = DbOutcomeTags::from_result(result);
let tags = [
("status", outcome.status),
("phase", phase),
("db", db.as_str()),
("error", outcome.error),
];
record_counter(metrics, DB_INIT_METRIC, &tags);
record_duration(metrics, DB_INIT_DURATION_METRIC, duration, &tags);
}
/// Count one fallback away from SQLite, tagged with the calling site and the
/// reason it fell back. A `None` recorder makes this a no-op.
pub fn record_fallback(
    metrics: Option<&dyn DbMetricsRecorder>,
    caller: &'static str,
    reason: &'static str,
) {
    record_counter(
        metrics,
        DB_FALLBACK_METRIC,
        &[("caller", caller), ("reason", reason)],
    );
}
/// Record the outcome of the init-time backfill gate for the state database.
///
/// Thin wrapper over `record_init_result` with the phase fixed to
/// `backfill_gate` and the db fixed to `state`.
pub fn record_init_backfill_gate(
    metrics: Option<&dyn DbMetricsRecorder>,
    duration: Duration,
    result: &anyhow::Result<()>,
) {
    record_init_result(metrics, DbKind::State, "backfill_gate", duration, result);
}
/// Count one log-queue lifecycle event (e.g. drop/flush) with its reason.
/// A `None` recorder makes this a no-op.
pub(crate) fn record_log_queue(
    metrics: Option<&dyn DbMetricsRecorder>,
    event: &'static str,
    reason: &'static str,
) {
    record_counter(
        metrics,
        DB_LOG_QUEUE_METRIC,
        &[("event", event), ("reason", reason)],
    );
}
pub(crate) fn classify_error(err: &anyhow::Error) -> &'static str {
for cause in err.chain() {
if let Some(sqlx_err) = cause.downcast_ref::<sqlx::Error>() {
return classify_sqlx_error(sqlx_err);
}
if cause
.downcast_ref::<sqlx::migrate::MigrateError>()
.is_some()
{
return "migration";
}
if cause.downcast_ref::<serde_json::Error>().is_some() {
return "serde";
}
if cause.downcast_ref::<std::io::Error>().is_some() {
return "io";
}
}
"unknown"
}
/// Translate a SQLite result code string into a coarse tag value.
///
/// Extended result codes embed the primary code in their low byte, so the
/// code is masked with `0xff` before matching. Unparseable or unrecognized
/// codes map to "unknown".
pub(crate) fn classify_sqlite_code(code: &str) -> &'static str {
    let Ok(raw) = code.parse::<i32>() else {
        return "unknown";
    };
    match raw & 0xff {
        5 => "busy",
        6 => "locked",
        8 => "readonly",
        10 => "io",
        11 => "corrupt",
        13 => "full",
        14 => "cantopen",
        17 => "schema",
        19 => "constraint",
        _ => "unknown",
    }
}
pub(crate) fn record_operation_result<T>(
metrics: Option<&dyn DbMetricsRecorder>,
db: DbKind,
operation: &'static str,
access: DbAccess,
duration: Duration,
result: &anyhow::Result<T>,
) {
let outcome = DbOutcomeTags::from_result(result);
let tags = [
("status", outcome.status),
("db", db.as_str()),
("operation", operation),
("access", access.as_str()),
("error", outcome.error),
];
record_counter(metrics, DB_OPERATION_METRIC, &tags);
record_duration(metrics, DB_OPERATION_DURATION_METRIC, duration, &tags);
}
/// Status/error tag pair derived from an operation result.
struct DbOutcomeTags {
    status: &'static str,
    error: &'static str,
}

impl DbOutcomeTags {
    /// Success maps to `success`/`none`; failures carry a coarse error class.
    fn from_result<T>(result: &anyhow::Result<T>) -> Self {
        if let Err(err) = result {
            Self {
                status: "failed",
                error: classify_error(err),
            }
        } else {
            Self {
                status: "success",
                error: "none",
            }
        }
    }
}
/// Map a sqlx error variant to a coarse, low-cardinality tag value.
///
/// Database errors are further classified by their SQLite result code;
/// pool timeouts, I/O errors, and serde-backed decode failures get their own
/// buckets, everything else is "unknown".
fn classify_sqlx_error(err: &sqlx::Error) -> &'static str {
    match err {
        sqlx::Error::Database(database_error) => {
            // Borrow the code directly; the original round-tripped the
            // Option<Cow<str>> through an owned String for no benefit.
            // A missing code ("none") parses as non-numeric -> "unknown".
            let code = database_error.code().unwrap_or(Cow::Borrowed("none"));
            classify_sqlite_code(&code)
        }
        sqlx::Error::PoolTimedOut => "pool_timeout",
        sqlx::Error::Io(_) => "io",
        sqlx::Error::ColumnDecode { source, .. } if source.is::<serde_json::Error>() => "serde",
        sqlx::Error::Decode(source) if source.is::<serde_json::Error>() => "serde",
        _ => "unknown",
    }
}
/// Increment `name` by one on the recorder, if one is configured.
fn record_counter(metrics: Option<&dyn DbMetricsRecorder>, name: &str, tags: &[(&str, &str)]) {
    let Some(recorder) = metrics else {
        return;
    };
    recorder.counter(name, /*inc*/ 1, tags);
}
/// Forward an elapsed duration to the recorder, if one is configured.
fn record_duration(
    metrics: Option<&dyn DbMetricsRecorder>,
    name: &str,
    duration: Duration,
    tags: &[(&str, &str)],
) {
    let Some(recorder) = metrics else {
        return;
    };
    recorder.record_duration(name, duration, tags);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DB_FALLBACK_METRIC;
    use crate::DB_OPERATION_METRIC;
    use pretty_assertions::assert_eq;
    use std::collections::BTreeMap;
    use std::sync::Mutex;

    /// In-memory `DbMetricsRecorder` that captures counter events for assertions.
    #[derive(Default)]
    struct TestMetrics {
        // Interior mutability because the recorder trait only takes `&self`.
        events: Mutex<Vec<MetricEvent>>,
    }

    /// One captured metric emission: name plus tags collected into a sorted map
    /// so comparisons are independent of tag-slice order.
    #[derive(Debug, Eq, PartialEq)]
    struct MetricEvent {
        name: String,
        tags: BTreeMap<String, String>,
    }

    impl TestMetrics {
        /// Snapshot of all captured events, copied out of the mutex.
        fn events(&self) -> Vec<MetricEvent> {
            self.events
                .lock()
                .expect("metrics lock")
                .iter()
                .map(|event| MetricEvent {
                    name: event.name.clone(),
                    tags: event.tags.clone(),
                })
                .collect()
        }
    }

    impl DbMetricsRecorder for TestMetrics {
        fn counter(&self, name: &str, _inc: i64, tags: &[(&str, &str)]) {
            self.events.lock().expect("metrics lock").push(MetricEvent {
                name: name.to_string(),
                tags: tags_to_map(tags),
            });
        }

        // Durations are not asserted on in these tests; discard them.
        fn record_duration(&self, _name: &str, _duration: Duration, _tags: &[(&str, &str)]) {}
    }

    /// Convert a borrowed tag slice into an owned, sorted map for comparison.
    fn tags_to_map(tags: &[(&str, &str)]) -> BTreeMap<String, String> {
        tags.iter()
            .map(|(key, value)| ((*key).to_string(), (*value).to_string()))
            .collect()
    }

    #[test]
    fn classifies_sqlite_primary_codes() {
        assert_eq!(classify_sqlite_code("5"), "busy");
        assert_eq!(classify_sqlite_code("6"), "locked");
        assert_eq!(classify_sqlite_code("14"), "cantopen");
        // Extended code 2067 masks down to primary code 19 (constraint).
        assert_eq!(classify_sqlite_code("2067"), "constraint");
    }

    #[test]
    fn classifies_non_sqlite_errors() {
        let io_error =
            anyhow::Error::new(std::io::Error::new(std::io::ErrorKind::NotFound, "missing"));
        assert_eq!(classify_error(&io_error), "io");

        let serde_error =
            anyhow::Error::new(serde_json::from_str::<serde_json::Value>("not-json").unwrap_err());
        assert_eq!(classify_error(&serde_error), "serde");

        // Errors with no recognized cause in the chain fall back to "unknown".
        let unknown_error = anyhow::anyhow!("plain failure");
        assert_eq!(classify_error(&unknown_error), "unknown");
    }

    #[test]
    fn classifies_sqlx_pool_timeout() {
        let err = anyhow::Error::new(sqlx::Error::PoolTimedOut);
        assert_eq!(classify_error(&err), "pool_timeout");
    }

    #[test]
    fn records_operation_metric_with_stable_tags() {
        let metrics = TestMetrics::default();
        let result: anyhow::Result<()> = Ok(());
        record_operation_result(
            Some(&metrics),
            DbKind::State,
            "list_threads",
            DbAccess::Read,
            Duration::from_millis(3),
            &result,
        );
        // Only the counter emission is captured (durations are dropped above);
        // the tag set must stay exactly this low-cardinality shape.
        assert_eq!(
            metrics.events(),
            vec![MetricEvent {
                name: DB_OPERATION_METRIC.to_string(),
                tags: BTreeMap::from([
                    ("access".to_string(), "read".to_string()),
                    ("db".to_string(), "state".to_string()),
                    ("error".to_string(), "none".to_string()),
                    ("operation".to_string(), "list_threads".to_string()),
                    ("status".to_string(), "success".to_string()),
                ]),
            }]
        );
    }

    #[test]
    fn records_fallback_metric_with_reason() {
        let metrics = TestMetrics::default();
        record_fallback(Some(&metrics), "list_threads", "db_error");
        assert_eq!(
            metrics.events(),
            vec![MetricEvent {
                name: DB_FALLBACK_METRIC.to_string(),
                tags: BTreeMap::from([
                    ("caller".to_string(), "list_threads".to_string()),
                    ("reason".to_string(), "db_error".to_string()),
                ]),
            }]
        );
    }
}