diff --git a/codex-rs/cli/src/state_db_recovery.rs b/codex-rs/cli/src/state_db_recovery.rs index a869e675b2..fbc92564e9 100644 --- a/codex-rs/cli/src/state_db_recovery.rs +++ b/codex-rs/cli/src/state_db_recovery.rs @@ -18,9 +18,7 @@ pub(crate) fn is_locked(detail: &str) -> bool { pub(crate) fn confirm_repair(startup_error: &LocalStateDbStartupError) -> std::io::Result { eprintln!("Codex couldn't start because its local database appears to be damaged."); - eprintln!( - "Codex can try a last-resort startup repair by backing up those files and rebuilding empty local databases." - ); + eprintln!("Codex can try to repair by backing up and rebuilding those files."); print_technical_details(startup_error); crate::confirm("Repair Codex local data now? [y/N]: ") } diff --git a/codex-rs/state/src/lib.rs b/codex-rs/state/src/lib.rs index cb8c711e65..a3ba576fd7 100644 --- a/codex-rs/state/src/lib.rs +++ b/codex-rs/state/src/lib.rs @@ -95,3 +95,7 @@ pub const DB_INIT_METRIC: &str = "codex.sqlite.init.count"; pub const DB_INIT_DURATION_METRIC: &str = "codex.sqlite.init.duration_ms"; /// Rollout fallback attempts. Tags: [caller, reason] pub const DB_FALLBACK_METRIC: &str = "codex.sqlite.fallback.count"; +/// SQLite automatic recovery attempts. Tags: [status, db, error, trigger_error] +pub const DB_RECOVERY_METRIC: &str = "codex.sqlite.recovery.count"; +/// SQLite automatic recovery latency. Tags: [status, db, error, trigger_error] +pub const DB_RECOVERY_DURATION_METRIC: &str = "codex.sqlite.recovery.duration_ms"; diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index c7ebd3d99d..1a2f7b6292 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -361,7 +361,7 @@ async fn open_sqlite( if !recovery::is_malformed_sqlite_error(&err) { return Err(err); } - recovery::recover_database(path, spec, migrator, &err).await?; + recovery::recover_database(path, spec, migrator, &err, telemetry_override).await?; connect_sqlite(options.clone(), spec, telemetry_override).await? } }; @@ -373,7 +373,7 @@ async fn open_sqlite( return Err(err); } pool.close().await; - recovery::recover_database(path, spec, migrator, &err).await?; + recovery::recover_database(path, spec, migrator, &err, telemetry_override).await?; let pool = connect_sqlite(options, spec, telemetry_override).await?; migrate_sqlite(&pool, migrator, spec, telemetry_override).await?; Ok(pool) @@ -507,6 +507,7 @@ mod tests { use super::state_db_path; use super::test_support::unique_temp_dir; use crate::DB_INIT_METRIC; + use crate::DB_RECOVERY_METRIC; use crate::DbTelemetry; use crate::migrations::STATE_MIGRATOR; use pretty_assertions::assert_eq; @@ -515,6 +516,10 @@ mod tests { use sqlx::sqlite::SqliteConnectOptions; use std::collections::BTreeMap; use std::collections::BTreeSet; + use std::fs::OpenOptions; + use std::io::Seek; + use std::io::SeekFrom; + use std::io::Write; use std::path::Path; use std::sync::Mutex; @@ -659,6 +664,102 @@ mod tests { let _ = tokio::fs::remove_dir_all(codex_home).await; } + #[tokio::test] + async fn open_state_sqlite_recovers_malformed_database_on_startup() { + let codex_home = unique_temp_dir(); + tokio::fs::create_dir_all(&codex_home) + .await + .expect("create codex home"); + let state_path = state_db_path(codex_home.as_path()); + let pool = SqlitePool::connect_with( + SqliteConnectOptions::new() + .filename(&state_path) + .create_if_missing(true), + ) + .await + .expect("open state db"); + STATE_MIGRATOR + .run(&pool) + .await + .expect("apply current state schema"); + let thread_id = "00000000-0000-0000-0000-000000000123"; + sqlx::query( + r#" +INSERT INTO threads ( + id, + rollout_path, + created_at, + updated_at, + source, + model_provider, + cwd, + title, + sandbox_policy, + approval_mode +) VALUES (?, ?, 1, 1, 'cli', 'test-provider', ?, 'startup recovery', 'read-only', 'on-request') + "#, + ) + .bind(thread_id) + .bind(codex_home.join("session.jsonl").display().to_string()) + .bind(codex_home.as_path().display().to_string()) + .execute(&pool) + .await + .expect("insert thread"); + let page_size: i64 = sqlx::query_scalar("PRAGMA page_size") + .fetch_one(&pool) + .await + .expect("read page size"); + let migration_root_page: i64 = sqlx::query_scalar( + "SELECT rootpage FROM sqlite_schema WHERE name = '_sqlx_migrations'", + ) + .fetch_one(&pool) + .await + .expect("read migration root page"); + pool.close().await; + corrupt_page( + state_path.as_path(), + page_size.try_into().expect("page size should fit u64"), + migration_root_page + .try_into() + .expect("root page should fit u64"), + ) + .expect("corrupt migration table page"); + + let telemetry = TestTelemetry::default(); + let tolerant_migrator = runtime_state_migrator(); + let recovered_pool = + open_state_sqlite(state_path.as_path(), &tolerant_migrator, Some(&telemetry)) + .await + .expect("startup should recover malformed state db"); + let title: String = sqlx::query_scalar("SELECT title FROM threads WHERE id = ?") + .bind(thread_id) + .fetch_one(&recovered_pool) + .await + .expect("recovered thread should exist"); + recovered_pool.close().await; + + assert_eq!(title, "startup recovery"); + let integrity = sqlite_integrity_check(state_path.as_path()) + .await + .expect("integrity check should run"); + assert_eq!(integrity, vec!["ok".to_string()]); + let recovery_event = telemetry + .counters() + .into_iter() + .find(|event| event.name == DB_RECOVERY_METRIC) + .expect("recovery metric should be recorded"); + assert_eq!( + recovery_event.tags, + BTreeMap::from([ + ("db".to_string(), "state".to_string()), + ("error".to_string(), "none".to_string()), + ("status".to_string(), "success".to_string()), + ("trigger_error".to_string(), "corrupt".to_string()), + ]) + ); + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + #[tokio::test] async fn init_records_successful_sqlite_init_phases_to_explicit_telemetry() { let codex_home = unique_temp_dir(); @@ -700,4 +801,14 @@ mod tests { runtime.logs_pool.close().await; let _ = tokio::fs::remove_dir_all(codex_home).await; } + + fn corrupt_page(path: &Path, page_size: u64, page_number: u64) -> std::io::Result<()> { + let offset = page_size + .checked_mul(page_number.saturating_sub(1)) + .ok_or_else(|| std::io::Error::other("corrupt page offset overflowed"))?; + let mut file = OpenOptions::new().write(true).open(path)?; + file.seek(SeekFrom::Start(offset))?; + file.write_all(&[0; 16])?; + Ok(()) + } } diff --git a/codex-rs/state/src/runtime/recovery.rs b/codex-rs/state/src/runtime/recovery.rs index 1964b5bf8a..a3ed00e383 100644 --- a/codex-rs/state/src/runtime/recovery.rs +++ b/codex-rs/state/src/runtime/recovery.rs @@ -1,7 +1,9 @@ use super::RuntimeDbSpec; +use crate::telemetry::DbTelemetry; use anyhow::Context; use anyhow::Result; use log::LevelFilter; +use sqlx::AssertSqlSafe; use sqlx::ConnectOptions; use sqlx::Row; use sqlx::SqlitePool; @@ -11,9 +13,14 @@ use sqlx::sqlite::SqlitePoolOptions; use std::borrow::Cow; use std::collections::BTreeSet; use std::ffi::OsString; +use std::fs::File; +use std::fs::OpenOptions; +use std::fs::TryLockError; use std::path::Path; use std::path::PathBuf; +use std::thread; use std::time::Duration; +use std::time::Instant; use std::time::SystemTime; use std::time::UNIX_EPOCH; use tracing::warn; @@ -23,6 +30,8 @@ mod recover_api; const SQLITE_CORRUPT: i32 = 11; const SQLITE_NOTADB: i32 = 26; +const RECOVERY_LOCK_POLL: Duration = Duration::from_millis(100); +const RECOVERY_LOCK_TIMEOUT: Duration = Duration::from_secs(60); #[derive(Debug)] struct RecoveryPaths { @@ -30,6 +39,11 @@ struct RecoveryPaths { backup_paths: Vec, } +struct RecoveryLock { + _file: File, + waited: bool, +} + pub(super) fn is_malformed_sqlite_error(err: &anyhow::Error) -> bool { // Prefer SQLite result codes, but keep a message fallback for migration // wrappers that stringify the underlying database error. @@ -51,7 +65,42 @@ pub(super) async fn recover_database( spec: RuntimeDbSpec, migrator: &Migrator, original_error: &anyhow::Error, + telemetry_override: Option<&dyn DbTelemetry>, ) -> Result<()> { + let started = Instant::now(); + let result = recover_database_inner(path, spec, migrator, original_error).await; + crate::telemetry::record_recovery_result( + telemetry_override, + spec.kind, + started.elapsed(), + original_error, + &result, + ); + result +} + +async fn recover_database_inner( + path: &Path, + spec: RuntimeDbSpec, + migrator: &Migrator, + original_error: &anyhow::Error, +) -> Result<()> { + let recovery_lock = acquire_recovery_lock(path).await.with_context(|| { + format!( + "failed to lock automatic recovery for {} at {}", + spec.label, + path.display() + ) + })?; + if recovery_lock.waited && database_is_healthy(path, migrator).await { + warn!( + "{} at {} was usable after waiting for the recovery lock; skipping duplicate recovery", + spec.label, + path.display() + ); + return Ok(()); + } + let recovery = prepare_recovery_paths(path).await.with_context(|| { format!( "failed to prepare automatic recovery for {} at {}", @@ -123,6 +172,71 @@ fn print_status(message: String) { eprintln!("{message}"); } +async fn acquire_recovery_lock(path: &Path) -> Result { + let lock_path = recovery_lock_path(path); + tokio::task::spawn_blocking(move || { + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(lock_path.as_path()) + .with_context(|| format!("failed to open {}", lock_path.display()))?; + let started = Instant::now(); + let mut waited = false; + loop { + match file.try_lock() { + Ok(()) => { + return Ok(RecoveryLock { + _file: file, + waited, + }); + } + Err(TryLockError::WouldBlock) if started.elapsed() < RECOVERY_LOCK_TIMEOUT => { + waited = true; + thread::sleep(RECOVERY_LOCK_POLL); + } + Err(TryLockError::WouldBlock) => { + anyhow::bail!( + "timed out waiting for another Codex process to finish recovering {}", + lock_path.display() + ); + } + Err(err) => { + return Err(std::io::Error::from(err)).with_context(|| { + format!("failed to lock recovery file {}", lock_path.display()) + }); + } + } + } + }) + .await + .context("recovery lock task panicked")? +} + +fn recovery_lock_path(path: &Path) -> PathBuf { + let mut lock_path = OsString::from(path.as_os_str()); + lock_path.push(".codex-recovery.lock"); + PathBuf::from(lock_path) +} + +async fn database_is_healthy(path: &Path, migrator: &Migrator) -> bool { + let Ok(pool) = open_recovered_pool(path).await else { + return false; + }; + let result = async { + // Check integrity before migrations so a still-corrupt database is not + // modified just because another process held the recovery lock first. + assert_integrity_ok(&pool).await?; + migrator.run(&pool).await?; + assert_integrity_ok(&pool).await?; + Ok::<(), anyhow::Error>(()) + } + .await; + pool.close().await; + result.is_ok() +} + fn sqlx_error_is_malformed(err: &sqlx::Error) -> bool { match err { sqlx::Error::Database(database_error) => { @@ -212,6 +326,19 @@ async fn run_recovery(path: &Path, recovered_path: &Path, migrator: &Migrator) - assert_integrity_ok(&pool).await?; match migrator.run(&pool).await { Ok(()) => { + if let Err(err) = + assert_expected_schema(&pool, recovered_path.as_path(), migrator).await + { + pool.close().await; + rebuild_recovered_database(recovered_path.as_path(), migrator) + .await + .with_context(|| { + format!( + "failed to normalize recovered database after schema validation failure: {err}" + ) + })?; + return Ok(()); + } assert_integrity_ok(&pool).await?; pool.close().await; } @@ -338,7 +465,8 @@ ORDER BY name "#, quote_identifier(schema) ); - let rows = sqlx::query_scalar::<_, String>(sql.as_str()) + // Dynamic identifiers are quoted with quote_identifier before interpolation. + let rows = sqlx::query_scalar::<_, String>(AssertSqlSafe(sql)) .fetch_all(pool) .await?; Ok(rows.into_iter().collect()) @@ -365,14 +493,16 @@ async fn copy_current_schema_table(pool: &SqlitePool, table: &str) -> Result Result<()> { let table_name = quote_identifier(table); let sql = format!("CREATE TABLE main.{table_name} AS SELECT * FROM recovered.{table_name}"); - sqlx::query(sql.as_str()).execute(pool).await?; + // Dynamic identifiers are quoted with quote_identifier before interpolation. + sqlx::query(AssertSqlSafe(sql)).execute(pool).await?; Ok(()) } @@ -381,7 +511,8 @@ async fn table_columns(pool: &SqlitePool, schema: &str, table: &str) -> Result