diff --git a/codex-rs/app-server/src/request_processors/thread_processor.rs b/codex-rs/app-server/src/request_processors/thread_processor.rs index 1766c971b1..c9a42d8359 100644 --- a/codex-rs/app-server/src/request_processors/thread_processor.rs +++ b/codex-rs/app-server/src/request_processors/thread_processor.rs @@ -1513,9 +1513,13 @@ impl ThreadRequestProcessor { .clone() .ok_or_else(|| internal_error("sqlite state db unavailable for memory reset"))?; - state_db.clear_memory_data().await.map_err(|err| { - internal_error(format!("failed to clear memory rows in state db: {err}")) - })?; + state_db + .memories() + .clear_memory_data() + .await + .map_err(|err| { + internal_error(format!("failed to clear memory rows in memories db: {err}")) + })?; clear_memory_roots_contents(&self.config.codex_home) .await diff --git a/codex-rs/app-server/tests/suite/v2/memory_reset.rs b/codex-rs/app-server/tests/suite/v2/memory_reset.rs index 3c7ae38671..7b2eddeb53 100644 --- a/codex-rs/app-server/tests/suite/v2/memory_reset.rs +++ b/codex-rs/app-server/tests/suite/v2/memory_reset.rs @@ -49,7 +49,10 @@ async fn memory_reset_clears_memory_files_and_rows_preserves_threads() -> Result .await??; let _: MemoryResetResponse = to_response::(response)?; - let stage1_outputs = state_db.list_stage1_outputs_for_global(/*n*/ 10).await?; + let stage1_outputs = state_db + .memories() + .list_stage1_outputs_for_global(/*n*/ 10) + .await?; assert_eq!(stage1_outputs, Vec::new()); assert_eq!( state_db.get_thread_memory_mode(thread_id).await?.as_deref(), @@ -81,6 +84,7 @@ async fn seed_stage1_output(state_db: &Arc, codex_home: &Path) -> state_db.upsert_thread(&metadata).await?; let claim = state_db + .memories() .try_claim_stage1_job( thread_id, worker_id, @@ -94,6 +98,7 @@ async fn seed_stage1_output(state_db: &Arc, codex_home: &Path) -> }; assert!( state_db + .memories() .mark_stage1_job_succeeded( thread_id, ownership_token.as_str(), @@ -106,6 +111,7 @@ async fn seed_stage1_output(state_db: &Arc, codex_home: &Path) -> "stage1 success should be recorded" ); state_db + .memories() .enqueue_global_consolidation(now.timestamp()) .await?; diff --git a/codex-rs/cli/src/doctor/output.rs b/codex-rs/cli/src/doctor/output.rs index ec09a4d69f..3f6e2f05d2 100644 --- a/codex-rs/cli/src/doctor/output.rs +++ b/codex-rs/cli/src/doctor/output.rs @@ -702,6 +702,7 @@ fn state_summary(check: &DoctorCheck) -> String { "state DB integrity", "log DB integrity", "goals DB integrity", + "memories DB integrity", ] .into_iter() .all(|label| detail::detail_value(check, label).is_some_and(|value| value == "ok")); @@ -1363,6 +1364,37 @@ Run codex doctor without --summary for detailed diagnostics. ); } + #[test] + fn render_human_report_includes_memories_db_in_state_health_summary() { + let report = DoctorReport { + schema_version: 1, + generated_at: "0s since unix epoch".to_string(), + overall_status: CheckStatus::Ok, + codex_version: "0.0.0".to_string(), + checks: vec![ + DoctorCheck::new( + "state.paths", + "state", + CheckStatus::Ok, + "state paths inspectable", + ) + .detail("state DB: /tmp/state.sqlite") + .detail("state DB integrity: ok") + .detail("log DB: /tmp/logs.sqlite") + .detail("log DB integrity: ok") + .detail("goals DB: /tmp/goals.sqlite") + .detail("goals DB integrity: ok") + .detail("memories DB: /tmp/memories.sqlite") + .detail("memories DB integrity: ok"), + ], + }; + + let rendered = render_human_report(&report, detailed_no_color_unicode_options()); + + assert!(rendered.contains("✓ state databases healthy")); + assert!(rendered.contains("memories DB /tmp/memories.sqlite · integrity ok")); + } + #[test] fn render_human_report_supports_ascii_output() { let rendered = render_human_report( diff --git a/codex-rs/cli/src/doctor/output/detail.rs b/codex-rs/cli/src/doctor/output/detail.rs index 86cbf2a83f..6c6cf8633c 100644 --- a/codex-rs/cli/src/doctor/output/detail.rs +++ b/codex-rs/cli/src/doctor/output/detail.rs @@ -413,6 +413,7 @@ fn state_details(parsed: &[ParsedDetail]) -> Vec { push_database_row(&mut out, parsed, "state DB"); push_database_row(&mut out, parsed, "log DB"); push_database_row(&mut out, parsed, "goals DB"); + push_database_row(&mut out, parsed, "memories DB"); for (source, label) in [ ("active rollout files", "active rollouts"), @@ -440,6 +441,8 @@ fn state_details(parsed: &[ParsedDetail]) -> Vec { "state DB integrity", "log DB integrity", "goals DB integrity", + "memories DB", + "memories DB integrity", "active rollout files", "archived rollout files", ], diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index 4bb7ef74a0..856273354c 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -27,7 +27,7 @@ use codex_responses_api_proxy::Args as ResponsesApiProxyArgs; use codex_rollout_trace::REDUCED_STATE_FILE_NAME; use codex_rollout_trace::replay_bundle; use codex_state::StateRuntime; -use codex_state::state_db_path; +use codex_state::memories_db_path; use codex_tui::AppExitInfo; use codex_tui::Cli as TuiCli; use codex_tui::ExitReason; @@ -1751,22 +1751,16 @@ async fn run_debug_clear_memories_command( .build() .await?; - let state_path = state_db_path(config.sqlite_home.as_path()); - let mut cleared_state_db = false; - if tokio::fs::try_exists(&state_path).await? { - let state_db = - StateRuntime::init(config.sqlite_home.clone(), config.model_provider_id.clone()) - .await?; - state_db.clear_memory_data().await?; - cleared_state_db = true; - } + let memories_path = memories_db_path(config.sqlite_home.as_path()); + let cleared_memories_db = + StateRuntime::clear_memory_data_in_sqlite_home(config.sqlite_home.as_path()).await?; clear_memory_roots_contents(&config.codex_home).await?; - let mut message = if cleared_state_db { - format!("Cleared memory state from {}.", state_path.display()) + let mut message = if cleared_memories_db { + format!("Cleared memory state from {}.", memories_path.display()) } else { - format!("No state db found at {}.", state_path.display()) + format!("No memories db found at {}.", memories_path.display()) }; message.push_str(&format!( " Cleared memory directories under {}.", diff --git a/codex-rs/cli/tests/debug_clear_memories.rs b/codex-rs/cli/tests/debug_clear_memories.rs index 9d5e114dbf..e39d736ae0 100644 --- a/codex-rs/cli/tests/debug_clear_memories.rs +++ b/codex-rs/cli/tests/debug_clear_memories.rs @@ -2,6 +2,7 @@ use std::path::Path; use anyhow::Result; use codex_state::StateRuntime; +use codex_state::memories_db_path; use codex_state::state_db_path; use predicates::str::contains; use sqlx::SqlitePool; @@ -23,6 +24,9 @@ async fn debug_clear_memories_resets_state_and_removes_memory_dir() -> Result<() let thread_id = "00000000-0000-0000-0000-000000000123"; let db_path = state_db_path(codex_home.path()); let pool = SqlitePool::connect(&format!("sqlite://{}", db_path.display())).await?; + let memories_db_path = memories_db_path(codex_home.path()); + let memories_pool = + SqlitePool::connect(&format!("sqlite://{}", memories_db_path.display())).await?; sqlx::query( r#" @@ -74,7 +78,7 @@ INSERT INTO stage1_outputs ( "#, ) .bind(thread_id) - .execute(&pool) + .execute(&memories_pool) .await?; sqlx::query( @@ -99,13 +103,14 @@ INSERT INTO jobs ( "#, ) .bind(thread_id) - .execute(&pool) + .execute(&memories_pool) .await?; let memory_root = codex_home.path().join("memories"); std::fs::create_dir_all(&memory_root)?; std::fs::write(memory_root.join("memory_summary.md"), "stale memory")?; pool.close().await; + memories_pool.close().await; let mut cmd = codex_command(codex_home.path())?; cmd.args(["debug", "clear-memories"]) @@ -113,7 +118,7 @@ INSERT INTO jobs ( .success() .stdout(contains("Cleared memory state")); - let pool = SqlitePool::connect(&format!("sqlite://{}", db_path.display())).await?; + let pool = SqlitePool::connect(&format!("sqlite://{}", memories_db_path.display())).await?; let stage1_outputs_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs") .fetch_one(&pool) .await?; @@ -131,3 +136,54 @@ INSERT INTO jobs ( Ok(()) } + +#[tokio::test] +async fn debug_clear_memories_resets_memories_db_without_state_db() -> Result<()> { + let codex_home = TempDir::new()?; + let runtime = + StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string()).await?; + drop(runtime); + + let db_path = state_db_path(codex_home.path()); + let memories_db_path = memories_db_path(codex_home.path()); + let memories_pool = + SqlitePool::connect(&format!("sqlite://{}", memories_db_path.display())).await?; + + sqlx::query( + r#" +INSERT INTO stage1_outputs ( + thread_id, + source_updated_at, + raw_memory, + rollout_summary, + generated_at, + rollout_slug, + usage_count, + last_usage, + selected_for_phase2, + selected_for_phase2_source_updated_at +) VALUES ('00000000-0000-0000-0000-000000000123', 1, 'raw', 'summary', 1, NULL, 0, NULL, 0, NULL) + "#, + ) + .execute(&memories_pool) + .await?; + + memories_pool.close().await; + std::fs::remove_file(&db_path)?; + + let mut cmd = codex_command(codex_home.path())?; + cmd.args(["debug", "clear-memories"]) + .assert() + .success() + .stdout(contains("Cleared memory state")); + + let pool = SqlitePool::connect(&format!("sqlite://{}", memories_db_path.display())).await?; + let stage1_outputs_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs") + .fetch_one(&pool) + .await?; + assert_eq!(stage1_outputs_count, 0); + pool.close().await; + assert!(!db_path.exists()); + + Ok(()) +} diff --git a/codex-rs/core/src/stream_events_utils.rs b/codex-rs/core/src/stream_events_utils.rs index 0e32ccb9fb..1ee4b4b58f 100644 --- a/codex-rs/core/src/stream_events_utils.rs +++ b/codex-rs/core/src/stream_events_utils.rs @@ -233,7 +233,7 @@ async fn record_stage1_output_usage_for_memory_citation( } if let Some(db) = state_db_ctx { - let _ = db.record_stage1_output_usage(&thread_ids).await; + let _ = db.memories().record_stage1_output_usage(&thread_ids).await; } true } diff --git a/codex-rs/memories/write/src/phase1.rs b/codex-rs/memories/write/src/phase1.rs index 9c6c2561de..1844ab3b52 100644 --- a/codex-rs/memories/write/src/phase1.rs +++ b/codex-rs/memories/write/src/phase1.rs @@ -112,6 +112,7 @@ pub async fn prune(context: &MemoryStartupContext, config: &Config) { if let Some(db) = context.state_db() { let max_unused_days = config.memories.max_unused_days; match db + .memories() .prune_stage1_outputs_for_retention(max_unused_days, crate::stage_one::PRUNE_BATCH_SIZE) .await { @@ -124,7 +125,7 @@ pub async fn prune(context: &MemoryStartupContext, config: &Config) { } Err(err) => { warn!( - "state db prune_stage1_outputs_for_retention failed during memories startup: {err}" + "memories db prune_stage1_outputs_for_retention failed during memories startup: {err}" ); } } @@ -161,6 +162,7 @@ async fn claim_startup_jobs( .collect::>(); match state_db + .memories() .claim_stage1_jobs_for_startup( context.thread_id(), codex_state::Stage1StartupClaimParams { @@ -176,7 +178,9 @@ async fn claim_startup_jobs( { Ok(claims) => Some(claims), Err(err) => { - warn!("state db claim_stage1_jobs_for_startup failed during memories startup: {err}"); + warn!( + "memories db claim_stage1_jobs_for_startup failed during memories startup: {err}" + ); None } } @@ -329,6 +333,7 @@ mod job { tracing::warn!("Phase 1 job failed for thread {thread_id}: {reason}"); if let Some(state_db) = context.state_db() { let _ = state_db + .memories() .mark_stage1_job_failed( thread_id, ownership_token, @@ -349,6 +354,7 @@ mod job { }; if state_db + .memories() .mark_stage1_job_succeeded_no_output(thread_id, ownership_token) .await .unwrap_or(false) @@ -373,6 +379,7 @@ mod job { }; if state_db + .memories() .mark_stage1_job_succeeded( thread_id, ownership_token, diff --git a/codex-rs/memories/write/src/phase2.rs b/codex-rs/memories/write/src/phase2.rs index 7a3841f8c3..c78032d9c2 100644 --- a/codex-rs/memories/write/src/phase2.rs +++ b/codex-rs/memories/write/src/phase2.rs @@ -91,6 +91,7 @@ pub async fn run(context: Arc, config: Arc) { // 4. Load current DB-backed Phase 2 inputs. let raw_memories = match db + .memories() .get_phase2_input_selection(max_raw_memories, max_unused_days) .await { @@ -217,6 +218,7 @@ mod job { db: &StateRuntime, ) -> Result { let claim = db + .memories() .try_claim_global_phase2_job(context.thread_id(), crate::stage_two::JOB_LEASE_SECONDS) .await .map_err(|e| { @@ -255,15 +257,17 @@ mod job { ) { context.counter(MEMORY_PHASE_TWO_JOBS, /*inc*/ 1, &[("status", reason)]); if matches!( - db.mark_global_phase2_job_failed( - &claim.token, - reason, - crate::stage_two::JOB_RETRY_DELAY_SECONDS, - ) - .await, + db.memories() + .mark_global_phase2_job_failed( + &claim.token, + reason, + crate::stage_two::JOB_RETRY_DELAY_SECONDS, + ) + .await, Ok(false) ) { let _ = db + .memories() .mark_global_phase2_job_failed_if_unowned( &claim.token, reason, @@ -282,7 +286,8 @@ mod job { reason: &'static str, ) -> bool { context.counter(MEMORY_PHASE_TWO_JOBS, /*inc*/ 1, &[("status", reason)]); - db.mark_global_phase2_job_succeeded(&claim.token, completion_watermark, selected_outputs) + db.memories() + .mark_global_phase2_job_succeeded(&claim.token, completion_watermark, selected_outputs) .await .unwrap_or(false) } @@ -382,6 +387,7 @@ mod agent { } // Do not reset the workspace baseline if we lost the lock. let still_owns_lock = match db + .memories() .heartbeat_global_phase2_job( &claim.token, crate::stage_two::JOB_LEASE_SECONDS, @@ -479,6 +485,7 @@ mod agent { } _ = heartbeat_interval.tick() => { match db + .memories() .heartbeat_global_phase2_job( &token, crate::stage_two::JOB_LEASE_SECONDS, diff --git a/codex-rs/memories/write/src/startup_tests.rs b/codex-rs/memories/write/src/startup_tests.rs index 30240f04ca..4af8560e67 100644 --- a/codex-rs/memories/write/src/startup_tests.rs +++ b/codex-rs/memories/write/src/startup_tests.rs @@ -190,7 +190,8 @@ async fn memories_startup_phase2_prunes_old_extension_resources_without_stage1_i let server = start_mock_server().await; let home = Arc::new(TempDir::new()?); let db = init_state_db(&home).await?; - db.enqueue_global_consolidation(/*input_watermark*/ 1) + db.memories() + .enqueue_global_consolidation(/*input_watermark*/ 1) .await?; let now = chrono::Utc::now(); @@ -445,6 +446,7 @@ async fn seed_stage1_output_for_existing_thread( ) -> anyhow::Result<()> { let owner = ThreadId::new(); let claim = db + .memories() .try_claim_stage1_job( thread_id, owner, updated_at, /*lease_seconds*/ 3_600, /*max_running_jobs*/ 64, @@ -456,15 +458,16 @@ async fn seed_stage1_output_for_existing_thread( }; assert!( - db.mark_stage1_job_succeeded( - thread_id, - &ownership_token, - updated_at, - raw_memory, - rollout_summary, - rollout_slug, - ) - .await?, + db.memories() + .mark_stage1_job_succeeded( + thread_id, + &ownership_token, + updated_at, + raw_memory, + rollout_summary, + rollout_slug, + ) + .await?, "stage-1 success should enqueue global consolidation" ); diff --git a/codex-rs/rollout/src/state_db.rs b/codex-rs/rollout/src/state_db.rs index 3bd8dd8e01..ea087e5c6d 100644 --- a/codex-rs/rollout/src/state_db.rs +++ b/codex-rs/rollout/src/state_db.rs @@ -493,8 +493,12 @@ pub async fn mark_thread_memory_mode_polluted( let Some(ctx) = context else { return; }; - if let Err(err) = ctx.mark_thread_memory_mode_polluted(thread_id).await { - warn!("state db mark_thread_memory_mode_polluted failed during {stage}: {err}"); + if let Err(err) = ctx + .memories() + .mark_thread_memory_mode_polluted(thread_id) + .await + { + warn!("memories db mark_thread_memory_mode_polluted failed during {stage}: {err}"); } } diff --git a/codex-rs/state/BUILD.bazel b/codex-rs/state/BUILD.bazel index b3f0fecab7..9225b6c126 100644 --- a/codex-rs/state/BUILD.bazel +++ b/codex-rs/state/BUILD.bazel @@ -3,5 +3,10 @@ load("//:defs.bzl", "codex_rust_crate") codex_rust_crate( name = "state", crate_name = "codex_state", - compile_data = glob(["goals_migrations/**", "logs_migrations/**", "migrations/**"]), + compile_data = glob([ + "goals_migrations/**", + "logs_migrations/**", + "memory_migrations/**", + "migrations/**", + ]), ) diff --git a/codex-rs/state/memory_migrations/0001_memories.sql b/codex-rs/state/memory_migrations/0001_memories.sql new file mode 100644 index 0000000000..4e5b3a1e9f --- /dev/null +++ b/codex-rs/state/memory_migrations/0001_memories.sql @@ -0,0 +1,35 @@ +CREATE TABLE stage1_outputs ( + thread_id TEXT PRIMARY KEY, + source_updated_at INTEGER NOT NULL, + raw_memory TEXT NOT NULL, + rollout_summary TEXT NOT NULL, + rollout_slug TEXT, + generated_at INTEGER NOT NULL, + usage_count INTEGER, + last_usage INTEGER, + selected_for_phase2 INTEGER NOT NULL DEFAULT 0, + selected_for_phase2_source_updated_at INTEGER +); + +CREATE INDEX idx_stage1_outputs_source_updated_at + ON stage1_outputs(source_updated_at DESC, thread_id DESC); + +CREATE TABLE jobs ( + kind TEXT NOT NULL, + job_key TEXT NOT NULL, + status TEXT NOT NULL, + worker_id TEXT, + ownership_token TEXT, + started_at INTEGER, + finished_at INTEGER, + lease_until INTEGER, + retry_at INTEGER, + retry_remaining INTEGER NOT NULL, + last_error TEXT, + input_watermark INTEGER, + last_success_watermark INTEGER, + PRIMARY KEY (kind, job_key) +); + +CREATE INDEX idx_jobs_kind_status_retry_lease + ON jobs(kind, status, retry_at, lease_until); diff --git a/codex-rs/state/migrations/0035_drop_memory_tables.sql b/codex-rs/state/migrations/0035_drop_memory_tables.sql new file mode 100644 index 0000000000..49721e3e54 --- /dev/null +++ b/codex-rs/state/migrations/0035_drop_memory_tables.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS jobs; +DROP TABLE IF EXISTS stage1_outputs; diff --git a/codex-rs/state/src/lib.rs b/codex-rs/state/src/lib.rs index 678ee11eee..cb8c711e65 100644 --- a/codex-rs/state/src/lib.rs +++ b/codex-rs/state/src/lib.rs @@ -55,6 +55,7 @@ pub use runtime::GoalAccountingMode; pub use runtime::GoalAccountingOutcome; pub use runtime::GoalStore; pub use runtime::GoalUpdate; +pub use runtime::MemoryStore; pub use runtime::RemoteControlEnrollmentRecord; pub use runtime::RuntimeDbPath; pub use runtime::ThreadFilterOptions; @@ -62,6 +63,8 @@ pub use runtime::goals_db_filename; pub use runtime::goals_db_path; pub use runtime::logs_db_filename; pub use runtime::logs_db_path; +pub use runtime::memories_db_filename; +pub use runtime::memories_db_path; pub use runtime::runtime_db_paths; pub use runtime::sqlite_integrity_check; pub use runtime::state_db_filename; @@ -77,6 +80,7 @@ pub const SQLITE_HOME_ENV: &str = "CODEX_SQLITE_HOME"; pub const LOGS_DB_FILENAME: &str = "logs_2.sqlite"; pub const GOALS_DB_FILENAME: &str = "goals_1.sqlite"; +pub const MEMORIES_DB_FILENAME: &str = "memories_1.sqlite"; pub const STATE_DB_FILENAME: &str = "state_5.sqlite"; /// Errors encountered during DB operations. Tags: [stage] diff --git a/codex-rs/state/src/migrations.rs b/codex-rs/state/src/migrations.rs index 526e958b5a..641d584530 100644 --- a/codex-rs/state/src/migrations.rs +++ b/codex-rs/state/src/migrations.rs @@ -5,6 +5,7 @@ use sqlx::migrate::Migrator; pub(crate) static STATE_MIGRATOR: Migrator = sqlx::migrate!("./migrations"); pub(crate) static LOGS_MIGRATOR: Migrator = sqlx::migrate!("./logs_migrations"); pub(crate) static GOALS_MIGRATOR: Migrator = sqlx::migrate!("./goals_migrations"); +pub(crate) static MEMORIES_MIGRATOR: Migrator = sqlx::migrate!("./memory_migrations"); /// Allow an older Codex binary to open a database that has already been /// migrated by a newer binary running in parallel. @@ -32,3 +33,7 @@ pub(crate) fn runtime_logs_migrator() -> Migrator { pub(crate) fn runtime_goals_migrator() -> Migrator { runtime_migrator(&GOALS_MIGRATOR) } + +pub(crate) fn runtime_memories_migrator() -> Migrator { + runtime_migrator(&MEMORIES_MIGRATOR) +} diff --git a/codex-rs/state/src/model/memories.rs b/codex-rs/state/src/model/memories.rs index 9bb34405ae..650ede08da 100644 --- a/codex-rs/state/src/model/memories.rs +++ b/codex-rs/state/src/model/memories.rs @@ -1,9 +1,6 @@ -use anyhow::Result; use chrono::DateTime; use chrono::Utc; use codex_protocol::ThreadId; -use sqlx::Row; -use sqlx::sqlite::SqliteRow; use std::path::PathBuf; use super::ThreadMetadata; @@ -22,58 +19,6 @@ pub struct Stage1Output { pub generated_at: DateTime, } -#[derive(Debug)] -pub(crate) struct Stage1OutputRow { - thread_id: String, - rollout_path: String, - source_updated_at: i64, - raw_memory: String, - rollout_summary: String, - rollout_slug: Option, - cwd: String, - git_branch: Option, - generated_at: i64, -} - -impl Stage1OutputRow { - pub(crate) fn try_from_row(row: &SqliteRow) -> Result { - Ok(Self { - thread_id: row.try_get("thread_id")?, - rollout_path: row.try_get("rollout_path")?, - source_updated_at: row.try_get("source_updated_at")?, - raw_memory: row.try_get("raw_memory")?, - rollout_summary: row.try_get("rollout_summary")?, - rollout_slug: row.try_get("rollout_slug")?, - cwd: row.try_get("cwd")?, - git_branch: row.try_get("git_branch")?, - generated_at: row.try_get("generated_at")?, - }) - } -} - -impl TryFrom for Stage1Output { - type Error = anyhow::Error; - - fn try_from(row: Stage1OutputRow) -> std::result::Result { - Ok(Self { - thread_id: ThreadId::try_from(row.thread_id)?, - rollout_path: PathBuf::from(row.rollout_path), - source_updated_at: epoch_seconds_to_datetime(row.source_updated_at)?, - raw_memory: row.raw_memory, - rollout_summary: row.rollout_summary, - rollout_slug: row.rollout_slug, - cwd: PathBuf::from(row.cwd), - git_branch: row.git_branch, - generated_at: epoch_seconds_to_datetime(row.generated_at)?, - }) - } -} - -fn epoch_seconds_to_datetime(secs: i64) -> Result> { - DateTime::::from_timestamp(secs, 0) - .ok_or_else(|| anyhow::anyhow!("invalid unix timestamp: {secs}")) -} - /// Result of trying to claim a stage-1 memory extraction job. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Stage1JobClaimOutcome { diff --git a/codex-rs/state/src/model/mod.rs b/codex-rs/state/src/model/mod.rs index a431bc64c0..2c725e50ab 100644 --- a/codex-rs/state/src/model/mod.rs +++ b/codex-rs/state/src/model/mod.rs @@ -37,7 +37,6 @@ pub use thread_metadata::ThreadsPage; pub(crate) use agent_job::AgentJobItemRow; pub(crate) use agent_job::AgentJobRow; -pub(crate) use memories::Stage1OutputRow; pub(crate) use thread_goal::ThreadGoalRow; pub(crate) use thread_metadata::ThreadRow; pub(crate) use thread_metadata::anchor_from_item; diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index 8c9c6d0267..ba41021290 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -10,6 +10,7 @@ use crate::LOGS_DB_FILENAME; use crate::LogEntry; use crate::LogQuery; use crate::LogRow; +use crate::MEMORIES_DB_FILENAME; use crate::STATE_DB_FILENAME; use crate::SortKey; use crate::ThreadMetadata; @@ -18,6 +19,7 @@ use crate::ThreadsPage; use crate::apply_rollout_item; use crate::migrations::runtime_goals_migrator; use crate::migrations::runtime_logs_migrator; +use crate::migrations::runtime_memories_migrator; use crate::migrations::runtime_state_migrator; use crate::model::AgentJobRow; use crate::model::ThreadRow; @@ -70,6 +72,7 @@ pub use goals::GoalAccountingMode; pub use goals::GoalAccountingOutcome; pub use goals::GoalStore; pub use goals::GoalUpdate; +pub use memories::MemoryStore; pub use remote_control::RemoteControlEnrollmentRecord; pub use threads::ThreadFilterOptions; @@ -121,7 +124,15 @@ const GOALS_DB: RuntimeDbSpec = RuntimeDbSpec { migrate_phase: "migrate_goals", }; -const RUNTIME_DBS: [RuntimeDbSpec; 3] = [STATE_DB, LOGS_DB, GOALS_DB]; +const MEMORIES_DB: RuntimeDbSpec = RuntimeDbSpec { + label: "memories DB", + filename: MEMORIES_DB_FILENAME, + kind: DbKind::Memories, + open_phase: "open_memories", + migrate_phase: "migrate_memories", +}; + +const RUNTIME_DBS: [RuntimeDbSpec; 4] = [STATE_DB, LOGS_DB, GOALS_DB, MEMORIES_DB]; #[derive(Clone, Debug, Eq, PartialEq)] pub struct RuntimeDbPath { @@ -136,6 +147,7 @@ pub struct StateRuntime { pool: Arc, logs_pool: Arc, thread_goals: GoalStore, + memories: MemoryStore, thread_updated_at_millis: Arc, } @@ -172,9 +184,11 @@ impl StateRuntime { let state_migrator = runtime_state_migrator(); let logs_migrator = runtime_logs_migrator(); let goals_migrator = runtime_goals_migrator(); + let memories_migrator = runtime_memories_migrator(); let state_path = STATE_DB.path(codex_home.as_path()); let logs_path = LOGS_DB.path(codex_home.as_path()); let goals_path = GOALS_DB.path(codex_home.as_path()); + let memories_path = MEMORIES_DB.path(codex_home.as_path()); let pool = match open_state_sqlite(&state_path, &state_migrator, telemetry_override).await { Ok(db) => Arc::new(db), Err(err) => { @@ -198,6 +212,22 @@ impl StateRuntime { return Err(err); } }; + let memories_pool = match open_memories_sqlite( + &memories_path, + &memories_migrator, + telemetry_override, + ) + .await + { + Ok(db) => Arc::new(db), + Err(err) => { + warn!( + "failed to open memories db at {}: {err}", + memories_path.display() + ); + return Err(err); + } + }; let started = Instant::now(); let backfill_state_result = ensure_backfill_state_row_in_pool(pool.as_ref()).await; crate::telemetry::record_init_result( @@ -225,6 +255,7 @@ impl StateRuntime { let thread_updated_at_millis = thread_updated_at_millis.unwrap_or(0); let runtime = Arc::new(Self { thread_goals: GoalStore::new(Arc::clone(&goals_pool)), + memories: MemoryStore::new(Arc::clone(&memories_pool), Arc::clone(&pool)), pool, logs_pool, codex_home, @@ -248,6 +279,28 @@ impl StateRuntime { pub fn thread_goals(&self) -> &GoalStore { &self.thread_goals } + + pub fn memories(&self) -> &MemoryStore { + &self.memories + } + + pub async fn clear_memory_data_in_sqlite_home(sqlite_home: &Path) -> anyhow::Result { + let memories_path = MEMORIES_DB.path(sqlite_home); + if !tokio::fs::try_exists(&memories_path).await? { + return Ok(false); + } + + let memories_migrator = runtime_memories_migrator(); + let pool = open_memories_sqlite( + &memories_path, + &memories_migrator, + /*telemetry_override*/ None, + ) + .await?; + memories::clear_memory_data_in_pool(&pool).await?; + pool.close().await; + Ok(true) + } } fn base_sqlite_options(path: &Path) -> SqliteConnectOptions { @@ -287,6 +340,14 @@ async fn open_goals_sqlite( open_sqlite(path, migrator, GOALS_DB, telemetry_override).await } +async fn open_memories_sqlite( + path: &Path, + migrator: &Migrator, + telemetry_override: Option<&dyn DbTelemetry>, +) -> anyhow::Result { + open_sqlite(path, migrator, MEMORIES_DB, telemetry_override).await +} + async fn open_sqlite( path: &Path, migrator: &Migrator, @@ -363,6 +424,14 @@ pub fn goals_db_path(codex_home: &Path) -> PathBuf { GOALS_DB.path(codex_home) } +pub fn memories_db_filename() -> String { + MEMORIES_DB.filename.to_string() +} + +pub fn memories_db_path(codex_home: &Path) -> PathBuf { + MEMORIES_DB.path(codex_home) +} + pub fn runtime_db_paths(codex_home: &Path) -> Vec { RUNTIME_DBS .iter() @@ -579,6 +648,8 @@ mod tests { "migrate_logs", "open_goals", "migrate_goals", + "open_memories", + "migrate_memories", "ensure_backfill_state", "post_init_query", ] diff --git a/codex-rs/state/src/runtime/memories.rs b/codex-rs/state/src/runtime/memories.rs index aeafe62a77..350d22cc6d 100644 --- a/codex-rs/state/src/runtime/memories.rs +++ b/codex-rs/state/src/runtime/memories.rs @@ -1,15 +1,14 @@ use super::threads::ThreadFilterOptions; use super::threads::push_thread_filters; -use super::threads::push_thread_order_and_limit; use super::*; use crate::SortDirection; use crate::model::Phase2JobClaimOutcome; use crate::model::Stage1JobClaim; use crate::model::Stage1JobClaimOutcome; use crate::model::Stage1Output; -use crate::model::Stage1OutputRow; use crate::model::Stage1StartupClaimParams; use crate::model::ThreadRow; +use chrono::DateTime; use chrono::Duration; use sqlx::Executor; use sqlx::QueryBuilder; @@ -20,39 +19,29 @@ const JOB_KIND_MEMORY_STAGE1: &str = "memory_stage1"; const JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL: &str = "memory_consolidate_global"; const MEMORY_CONSOLIDATION_JOB_KEY: &str = "global"; const PHASE2_SUCCESS_COOLDOWN_SECONDS: i64 = 6 * 60 * 60; +const PHASE2_INPUT_SELECTION_PAGE_SIZE: usize = 512; const DEFAULT_RETRY_REMAINING: i64 = 3; -impl StateRuntime { +/// Store for generated memory state and memory extraction/consolidation jobs. +#[derive(Clone)] +pub struct MemoryStore { + pool: Arc, + state_pool: Arc, +} + +impl MemoryStore { + pub(crate) fn new(pool: Arc, state_pool: Arc) -> Self { + Self { pool, state_pool } + } + /// Deletes all persisted memory state in one transaction. /// /// This removes every `stage1_outputs` row and all `jobs` rows for the /// stage-1 (`memory_stage1`) and phase-2 (`memory_consolidate_global`) /// memory pipelines. pub async fn clear_memory_data(&self) -> anyhow::Result<()> { - let mut tx = self.pool.begin().await?; - - sqlx::query( - r#" -DELETE FROM stage1_outputs - "#, - ) - .execute(&mut *tx) - .await?; - - sqlx::query( - r#" -DELETE FROM jobs -WHERE kind = ? OR kind = ? - "#, - ) - .bind(JOB_KIND_MEMORY_STAGE1) - .bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL) - .execute(&mut *tx) - .await?; - - tx.commit().await?; - Ok(()) + clear_memory_data_in_pool(self.pool.as_ref()).await } /// Record usage for cited stage-1 outputs. @@ -92,6 +81,51 @@ WHERE thread_id = ? Ok(updated_rows) } + async fn stage1_source_needs_update( + &self, + thread_id: ThreadId, + source_updated_at: i64, + ) -> anyhow::Result { + let thread_id = thread_id.to_string(); + let existing_output = sqlx::query( + r#" +SELECT source_updated_at +FROM stage1_outputs +WHERE thread_id = ? + "#, + ) + .bind(thread_id.as_str()) + .fetch_optional(self.pool.as_ref()) + .await?; + if let Some(existing_output) = existing_output { + let existing_source_updated_at: i64 = existing_output.try_get("source_updated_at")?; + if existing_source_updated_at >= source_updated_at { + return Ok(false); + } + } + + let existing_job = sqlx::query( + r#" +SELECT last_success_watermark +FROM jobs +WHERE kind = ? AND job_key = ? + "#, + ) + .bind(JOB_KIND_MEMORY_STAGE1) + .bind(thread_id.as_str()) + .fetch_optional(self.pool.as_ref()) + .await?; + if let Some(existing_job) = existing_job { + let last_success_watermark = + existing_job.try_get::, _>("last_success_watermark")?; + if last_success_watermark.is_some_and(|watermark| watermark >= source_updated_at) { + return Ok(false); + } + } + + Ok(true) + } + /// Selects and claims stage-1 startup jobs for stale threads. /// /// Query behavior: @@ -100,8 +134,9 @@ WHERE thread_id = ? /// - excludes threads with `memory_mode != 'enabled'` /// - excludes the current thread id /// - keeps only threads whose millisecond `updated_at` is in the age window - /// - keeps only threads whose memory is stale compared to millisecond `updated_at` - /// - orders by `updated_at_ms DESC` and applies `scan_limit` + /// - checks memory staleness against the memories DB + /// - orders by `updated_at_ms DESC` and applies `scan_limit` to bound + /// state-DB work before probing the memories DB /// /// For each selected thread, this function calls [`Self::try_claim_stage1_job`] /// with `source_updated_at = thread.updated_at.timestamp()` and returns up to @@ -157,16 +192,6 @@ SELECT threads.git_branch, threads.git_origin_url FROM threads -LEFT JOIN stage1_outputs - ON stage1_outputs.thread_id = threads.id -LEFT JOIN jobs - ON jobs.kind = - "#, - ); - builder.push_bind(JOB_KIND_MEMORY_STAGE1); - builder.push( - r#" - AND jobs.job_key = threads.id "#, ); push_thread_filters( @@ -196,32 +221,29 @@ LEFT JOIN jobs .push("threads.updated_at_ms") .push(" <= ") .push_bind(idle_cutoff); - let updated_at = "threads.updated_at_ms"; - builder.push(" AND ((COALESCE(stage1_outputs.source_updated_at, -1) + 1) * 1000) <= "); - builder.push(updated_at); - builder.push(" AND ((COALESCE(jobs.last_success_watermark, -1) + 1) * 1000) <= "); - builder.push(updated_at); - push_thread_order_and_limit( - &mut builder, - SortKey::UpdatedAt, - SortDirection::Desc, - scan_limit, - ); + let scan_limit_i64 = i64::try_from(scan_limit).unwrap_or(i64::MAX); + builder.push(" ORDER BY threads.updated_at_ms DESC LIMIT "); + builder.push_bind(scan_limit_i64); let items = builder .build() - .fetch_all(self.pool.as_ref()) + .fetch_all(self.state_pool.as_ref()) .await? .into_iter() .map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from)) .collect::, _>>()?; let mut claimed = Vec::new(); - for item in items { if claimed.len() >= max_claimed { break; } + if !self + .stage1_source_needs_update(item.id, item.updated_at.timestamp()) + .await? + { + continue; + } if let Stage1JobClaimOutcome::Claimed { ownership_token } = self .try_claim_stage1_job( @@ -243,13 +265,64 @@ LEFT JOIN jobs Ok(claimed) } + pub(super) async fn delete_thread_memory(&self, thread_id: ThreadId) -> anyhow::Result<()> { + let now = Utc::now().timestamp(); + let thread_id = thread_id.to_string(); + let mut tx = self.pool.begin().await?; + + let existing_output = sqlx::query( + r#" +SELECT selected_for_phase2 +FROM stage1_outputs +WHERE thread_id = ? + "#, + ) + .bind(thread_id.as_str()) + .fetch_optional(&mut *tx) + .await?; + let was_selected_for_phase2 = existing_output + .map(|row| row.try_get::("selected_for_phase2")) + .transpose()? + .is_some_and(|selected| selected != 0); + + let deleted_rows = sqlx::query( + r#" +DELETE FROM stage1_outputs +WHERE thread_id = ? + "#, + ) + .bind(thread_id.as_str()) + .execute(&mut *tx) + .await? + .rows_affected(); + + sqlx::query( + r#" +DELETE FROM jobs +WHERE kind = ? AND job_key = ? + "#, + ) + .bind(JOB_KIND_MEMORY_STAGE1) + .bind(thread_id.as_str()) + .execute(&mut *tx) + .await?; + + if deleted_rows > 0 && was_selected_for_phase2 { + enqueue_global_consolidation_with_executor(&mut *tx, now).await?; + } + + tx.commit().await?; + Ok(()) + } + /// Lists the most recent non-empty stage-1 outputs for global consolidation. /// /// Query behavior: /// - filters out rows where both `raw_memory` and `rollout_summary` are blank - /// - joins `threads` to include thread `cwd`, `rollout_path`, and `git_branch` + /// - hydrates thread `cwd`, `rollout_path`, and `git_branch` from the state DB + /// - filters out missing or non-enabled threads /// - orders by `source_updated_at DESC, thread_id DESC` - /// - applies `LIMIT n` + /// - returns the first `n` visible outputs pub async fn list_stage1_outputs_for_global( &self, n: usize, @@ -262,30 +335,30 @@ LEFT JOIN jobs r#" SELECT so.thread_id, - COALESCE(t.rollout_path, '') AS rollout_path, so.source_updated_at, so.raw_memory, so.rollout_summary, so.rollout_slug, - so.generated_at, - COALESCE(t.cwd, '') AS cwd, - t.git_branch AS git_branch + so.generated_at FROM stage1_outputs AS so -LEFT JOIN threads AS t - ON t.id = so.thread_id -WHERE t.memory_mode = 'enabled' - AND (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0) +WHERE length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0 ORDER BY so.source_updated_at DESC, so.thread_id DESC -LIMIT ? "#, ) - .bind(n as i64) .fetch_all(self.pool.as_ref()) .await?; - rows.into_iter() - .map(|row| Stage1OutputRow::try_from_row(&row).and_then(Stage1Output::try_from)) - .collect::, _>>() + let mut outputs = Vec::new(); + for row in rows { + if let Some(output) = self.stage1_output_from_row_if_thread_enabled(&row).await? { + outputs.push(output); + if outputs.len() >= n { + break; + } + } + } + + Ok(outputs) } /// Prunes stale stage-1 outputs while preserving the latest phase-2 @@ -356,64 +429,150 @@ WHERE thread_id IN ( } let cutoff = (Utc::now() - Duration::days(max_unused_days.max(0))).timestamp(); - let current_rows = sqlx::query( - r#" -SELECT - selected.thread_id, - selected.rollout_path, - selected.source_updated_at, - selected.raw_memory, - selected.rollout_summary, - selected.rollout_slug, - selected.generated_at, - selected.cwd, - selected.git_branch -FROM ( - SELECT - so.thread_id, - COALESCE(t.rollout_path, '') AS rollout_path, - so.source_updated_at, - so.raw_memory, - so.rollout_summary, - so.rollout_slug, - so.generated_at, - COALESCE(t.cwd, '') AS cwd, - t.git_branch AS git_branch - FROM stage1_outputs AS so - LEFT JOIN threads AS t - ON t.id = so.thread_id - WHERE t.memory_mode = 'enabled' - AND (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0) - AND ( - (so.last_usage IS NOT NULL AND so.last_usage >= ?) - OR (so.last_usage IS NULL AND so.source_updated_at >= ?) - ) - ORDER BY - COALESCE(so.usage_count, 0) DESC, - COALESCE(so.last_usage, so.source_updated_at) DESC, - so.source_updated_at DESC, - so.thread_id DESC - LIMIT ? -) AS selected -ORDER BY selected.thread_id ASC - "#, - ) - .bind(cutoff) - .bind(cutoff) - .bind(n as i64) - .fetch_all(self.pool.as_ref()) - .await?; + let page_size = n.clamp(1, PHASE2_INPUT_SELECTION_PAGE_SIZE); + let page_size_i64 = i64::try_from(page_size).unwrap_or(i64::MAX); + let mut offset = 0_i64; + let mut selected_keys = Vec::with_capacity(n); - let mut selected = Vec::with_capacity(current_rows.len()); - for row in current_rows { - selected.push(Stage1Output::try_from(Stage1OutputRow::try_from_row( - &row, - )?)?); + while selected_keys.len() < n { + let candidate_rows = sqlx::query( + r#" +SELECT + so.thread_id, + so.source_updated_at +FROM stage1_outputs AS so +WHERE (length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0) + AND ( + (so.last_usage IS NOT NULL AND so.last_usage >= ?) + OR (so.last_usage IS NULL AND so.source_updated_at >= ?) + ) +ORDER BY + COALESCE(so.usage_count, 0) DESC, + COALESCE(so.last_usage, so.source_updated_at) DESC, + so.source_updated_at DESC, + so.thread_id DESC +LIMIT ? OFFSET ? + "#, + ) + .bind(cutoff) + .bind(cutoff) + .bind(page_size_i64) + .bind(offset) + .fetch_all(self.pool.as_ref()) + .await?; + + if candidate_rows.is_empty() { + break; + } + + let candidate_count = i64::try_from(candidate_rows.len()).unwrap_or(i64::MAX); + for row in candidate_rows { + let thread_id: String = row.try_get("thread_id")?; + let source_updated_at: i64 = row.try_get("source_updated_at")?; + if self + .enabled_thread_metadata(ThreadId::try_from(thread_id.as_str())?) + .await? + .is_some() + { + selected_keys.push((thread_id, source_updated_at)); + if selected_keys.len() >= n { + break; + } + } + } + + offset = offset.saturating_add(candidate_count); } + let mut selected = Vec::with_capacity(selected_keys.len()); + for (thread_id, source_updated_at) in selected_keys { + let Some(row) = sqlx::query( + r#" +SELECT + so.thread_id, + so.source_updated_at, + so.raw_memory, + so.rollout_summary, + so.rollout_slug, + so.generated_at +FROM stage1_outputs AS so +WHERE so.thread_id = ? AND so.source_updated_at = ? + "#, + ) + .bind(thread_id.as_str()) + .bind(source_updated_at) + .fetch_optional(self.pool.as_ref()) + .await? + else { + continue; + }; + if let Some(output) = self.stage1_output_from_row_if_thread_enabled(&row).await? { + selected.push(output); + } + } + + selected.sort_by(|a, b| a.thread_id.to_string().cmp(&b.thread_id.to_string())); + Ok(selected) } + async fn stage1_output_from_row_if_thread_enabled( + &self, + row: &sqlx::sqlite::SqliteRow, + ) -> anyhow::Result> { + let thread_id: String = row.try_get("thread_id")?; + let Some(thread) = self + .enabled_thread_metadata(ThreadId::try_from(thread_id.as_str())?) + .await? + else { + return Ok(None); + }; + Ok(Some(stage1_output_from_row_and_thread(row, thread)?)) + } + + async fn enabled_thread_metadata( + &self, + thread_id: ThreadId, + ) -> anyhow::Result> { + let row = sqlx::query( + r#" +SELECT + threads.id, + threads.rollout_path, + threads.created_at_ms AS created_at, + threads.updated_at_ms AS updated_at, + threads.source, + threads.thread_source, + threads.agent_nickname, + threads.agent_role, + threads.agent_path, + threads.model_provider, + threads.model, + threads.reasoning_effort, + threads.cwd, + threads.cli_version, + threads.title, + threads.preview, + threads.sandbox_policy, + threads.approval_mode, + threads.tokens_used, + threads.first_user_message, + threads.archived_at, + threads.git_sha, + threads.git_branch, + threads.git_origin_url +FROM threads +WHERE threads.id = ? AND threads.memory_mode = 'enabled' + "#, + ) + .bind(thread_id.to_string()) + .fetch_optional(self.state_pool.as_ref()) + .await?; + + row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from)) + .transpose() + } + /// Marks a thread as polluted and enqueues phase-2 forgetting when the /// thread participated in the last successful phase-2 baseline. pub async fn mark_thread_memory_mode_polluted( @@ -422,24 +581,6 @@ ORDER BY selected.thread_id ASC ) -> anyhow::Result { let now = Utc::now().timestamp(); let thread_id = thread_id.to_string(); - let mut tx = self.pool.begin().await?; - let rows_affected = sqlx::query( - r#" -UPDATE threads -SET memory_mode = 'polluted' -WHERE id = ? AND memory_mode != 'polluted' - "#, - ) - .bind(thread_id.as_str()) - .execute(&mut *tx) - .await? - .rows_affected(); - - if rows_affected == 0 { - tx.commit().await?; - return Ok(false); - } - let selected_for_phase2 = sqlx::query_scalar::<_, i64>( r#" SELECT selected_for_phase2 @@ -448,15 +589,26 @@ WHERE thread_id = ? "#, ) .bind(thread_id.as_str()) - .fetch_optional(&mut *tx) + .fetch_optional(self.pool.as_ref()) .await? .unwrap_or(0); + let rows_affected = sqlx::query( + r#" +UPDATE threads +SET memory_mode = 'polluted' +WHERE id = ? AND memory_mode != 'polluted' + "#, + ) + .bind(thread_id.as_str()) + .execute(self.state_pool.as_ref()) + .await? + .rows_affected(); + if selected_for_phase2 != 0 { - enqueue_global_consolidation_with_executor(&mut *tx, now).await?; + self.enqueue_global_consolidation(now).await?; } - tx.commit().await?; - Ok(true) + Ok(rows_affected > 0) } /// Attempts to claim a stage-1 job for a thread at `source_updated_at`. @@ -1221,6 +1373,58 @@ WHERE kind = ? AND job_key = ? Ok(rows_affected) } +pub(super) async fn clear_memory_data_in_pool(pool: &SqlitePool) -> anyhow::Result<()> { + let mut tx = pool.begin().await?; + + sqlx::query( + r#" +DELETE FROM stage1_outputs + "#, + ) + .execute(&mut *tx) + .await?; + + sqlx::query( + r#" +DELETE FROM jobs +WHERE kind = ? OR kind = ? + "#, + ) + .bind(JOB_KIND_MEMORY_STAGE1) + .bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + Ok(()) +} + +fn stage1_output_from_row_and_thread( + row: &sqlx::sqlite::SqliteRow, + thread: ThreadMetadata, +) -> anyhow::Result { + let source_updated_at: i64 = row.try_get("source_updated_at")?; + let generated_at: i64 = row.try_get("generated_at")?; + let source_updated_at = datetime_from_epoch_seconds(source_updated_at)?; + let generated_at = datetime_from_epoch_seconds(generated_at)?; + Ok(Stage1Output { + thread_id: thread.id, + rollout_path: thread.rollout_path, + source_updated_at, + raw_memory: row.try_get("raw_memory")?, + rollout_summary: row.try_get("rollout_summary")?, + rollout_slug: row.try_get("rollout_slug")?, + cwd: thread.cwd, + git_branch: thread.git_branch, + generated_at, + }) +} + +fn datetime_from_epoch_seconds(secs: i64) -> anyhow::Result> { + DateTime::::from_timestamp(secs, 0) + .ok_or_else(|| anyhow::anyhow!("invalid unix timestamp: {secs}")) +} + async fn enqueue_global_consolidation_with_executor<'e, E>( executor: E, input_watermark: i64, @@ -1272,6 +1476,181 @@ ON CONFLICT(kind, job_key) DO UPDATE SET Ok(()) } +#[cfg(test)] +impl StateRuntime { + async fn clear_memory_data(&self) -> anyhow::Result<()> { + self.memories.clear_memory_data().await + } + + async fn record_stage1_output_usage(&self, thread_ids: &[ThreadId]) -> anyhow::Result { + self.memories.record_stage1_output_usage(thread_ids).await + } + + async fn claim_stage1_jobs_for_startup( + &self, + current_thread_id: ThreadId, + params: Stage1StartupClaimParams<'_>, + ) -> anyhow::Result> { + self.memories + .claim_stage1_jobs_for_startup(current_thread_id, params) + .await + } + + async fn list_stage1_outputs_for_global(&self, n: usize) -> anyhow::Result> { + self.memories.list_stage1_outputs_for_global(n).await + } + + async fn prune_stage1_outputs_for_retention( + &self, + max_unused_days: i64, + limit: usize, + ) -> anyhow::Result { + self.memories + .prune_stage1_outputs_for_retention(max_unused_days, limit) + .await + } + + async fn get_phase2_input_selection( + &self, + n: usize, + max_unused_days: i64, + ) -> anyhow::Result> { + self.memories + .get_phase2_input_selection(n, max_unused_days) + .await + } + + async fn mark_thread_memory_mode_polluted(&self, thread_id: ThreadId) -> anyhow::Result { + self.memories + .mark_thread_memory_mode_polluted(thread_id) + .await + } + + async fn try_claim_stage1_job( + &self, + thread_id: ThreadId, + worker_id: ThreadId, + source_updated_at: i64, + lease_seconds: i64, + max_running_jobs: usize, + ) -> anyhow::Result { + self.memories + .try_claim_stage1_job( + thread_id, + worker_id, + source_updated_at, + lease_seconds, + max_running_jobs, + ) + .await + } + + async fn mark_stage1_job_succeeded( + &self, + thread_id: ThreadId, + ownership_token: &str, + source_updated_at: i64, + raw_memory: &str, + rollout_summary: &str, + rollout_slug: Option<&str>, + ) -> anyhow::Result { + self.memories + .mark_stage1_job_succeeded( + thread_id, + ownership_token, + source_updated_at, + raw_memory, + rollout_summary, + rollout_slug, + ) + .await + } + + async fn mark_stage1_job_succeeded_no_output( + &self, + thread_id: ThreadId, + ownership_token: &str, + ) -> anyhow::Result { + self.memories + .mark_stage1_job_succeeded_no_output(thread_id, ownership_token) + .await + } + + async fn mark_stage1_job_failed( + &self, + thread_id: ThreadId, + ownership_token: &str, + failure_reason: &str, + retry_delay_seconds: i64, + ) -> anyhow::Result { + self.memories + .mark_stage1_job_failed( + thread_id, + ownership_token, + failure_reason, + retry_delay_seconds, + ) + .await + } + + async fn enqueue_global_consolidation(&self, input_watermark: i64) -> anyhow::Result<()> { + self.memories + .enqueue_global_consolidation(input_watermark) + .await + } + + async fn try_claim_global_phase2_job( + &self, + worker_id: ThreadId, + lease_seconds: i64, + ) -> anyhow::Result { + self.memories + .try_claim_global_phase2_job(worker_id, lease_seconds) + .await + } + + async fn mark_global_phase2_job_succeeded( + &self, + ownership_token: &str, + completed_watermark: i64, + selected_outputs: &[Stage1Output], + ) -> anyhow::Result { + self.memories + .mark_global_phase2_job_succeeded( + ownership_token, + completed_watermark, + selected_outputs, + ) + .await + } + + async fn mark_global_phase2_job_failed( + &self, + ownership_token: &str, + failure_reason: &str, + retry_delay_seconds: i64, + ) -> anyhow::Result { + self.memories + .mark_global_phase2_job_failed(ownership_token, failure_reason, retry_delay_seconds) + .await + } + + async fn mark_global_phase2_job_failed_if_unowned( + &self, + ownership_token: &str, + failure_reason: &str, + retry_delay_seconds: i64, + ) -> anyhow::Result { + self.memories + .mark_global_phase2_job_failed_if_unowned( + ownership_token, + failure_reason, + retry_delay_seconds, + ) + .await + } +} + #[cfg(test)] mod tests { use super::JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL; @@ -1296,12 +1675,16 @@ mod tests { ThreadId::from_string(value).expect("thread id") } + fn memory_pool(runtime: &StateRuntime) -> &sqlx::SqlitePool { + runtime.memories().pool.as_ref() + } + async fn age_phase2_success_beyond_cooldown(runtime: &StateRuntime) { sqlx::query("UPDATE jobs SET finished_at = ? WHERE kind = ? AND job_key = ?") .bind(Utc::now().timestamp() - PHASE2_SUCCESS_COOLDOWN_SECONDS - 1) .bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL) .bind(MEMORY_CONSOLIDATION_JOB_KEY) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(runtime)) .await .expect("age phase2 success beyond cooldown"); } @@ -1410,7 +1793,7 @@ mod tests { sqlx::query("UPDATE jobs SET lease_until = 0 WHERE kind = 'memory_stage1' AND job_key = ?") .bind(thread_id.to_string()) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("force stale lease"); @@ -1656,7 +2039,7 @@ mod tests { } #[tokio::test] - async fn claim_stage1_jobs_prefilters_threads_with_up_to_date_memory() { + async fn claim_stage1_jobs_bounds_state_scan_before_memory_probes() { let codex_home = unique_temp_dir(); let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string()) .await @@ -1734,7 +2117,7 @@ mod tests { .expect("upsert stale thread"); let allowed_sources = vec!["cli".to_string()]; - let claims = runtime + let claims_with_one_scanned_thread = runtime .claim_stage1_jobs_for_startup( current_thread_id, Stage1StartupClaimParams { @@ -1748,6 +2131,22 @@ mod tests { ) .await .expect("claim stage1 startup jobs"); + assert_eq!(claims_with_one_scanned_thread.len(), 0); + + let claims = runtime + .claim_stage1_jobs_for_startup( + current_thread_id, + Stage1StartupClaimParams { + scan_limit: 2, + max_claimed: 1, + max_age_days: 30, + min_rollout_idle_hours: 12, + allowed_sources: allowed_sources.as_slice(), + lease_seconds: 3600, + }, + ) + .await + .expect("claim stage1 startup jobs with wider scan"); assert_eq!(claims.len(), 1); assert_eq!(claims[0].thread.id, stale_thread_id); @@ -1901,7 +2300,7 @@ mod tests { .expect("clear memory data"); let stage1_outputs_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs") - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("count stage1 outputs"); assert_eq!(stage1_outputs_count, 0); @@ -1910,7 +2309,7 @@ mod tests { sqlx::query_scalar("SELECT COUNT(*) FROM jobs WHERE kind = ? OR kind = ?") .bind(JOB_KIND_MEMORY_STAGE1) .bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("count memory jobs"); assert_eq!(memory_jobs_count, 0); @@ -2001,7 +2400,7 @@ INSERT INTO jobs ( .bind(lease_until) .bind(3) .bind(metadata.updated_at.timestamp()) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("seed running stage1 job"); } @@ -2035,7 +2434,7 @@ WHERE kind = 'memory_stage1' "#, ) .bind(Utc::now().timestamp()) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("count running stage1 jobs") .try_get::("count") @@ -2149,7 +2548,7 @@ WHERE kind = 'memory_stage1' } #[tokio::test] - async fn stage1_output_cascades_on_thread_delete() { + async fn delete_thread_removes_stage1_output_and_enqueues_phase2_when_selected() { let codex_home = unique_temp_dir(); let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string()) .await @@ -2192,29 +2591,84 @@ WHERE kind = 'memory_stage1' let count_before = sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("count before delete") .try_get::("count") .expect("count value"); assert_eq!(count_before, 1); - sqlx::query("DELETE FROM threads WHERE id = ?") - .bind(thread_id.to_string()) - .execute(runtime.pool.as_ref()) + let phase2_claim = runtime + .try_claim_global_phase2_job(owner, /*lease_seconds*/ 3600) .await - .expect("delete thread"); + .expect("claim phase2"); + let (phase2_token, input_watermark) = match phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + let selected_outputs = runtime + .list_stage1_outputs_for_global(/*n*/ 10) + .await + .expect("list stage1 outputs"); + assert!( + runtime + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), + input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 succeeded"), + "phase2 success should mark selected stage1 output" + ); + + let before_delete = Utc::now().timestamp(); + assert_eq!( + runtime + .delete_thread(thread_id) + .await + .expect("delete thread"), + 1 + ); let count_after = sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("count after delete") .try_get::("count") .expect("count value"); assert_eq!(count_after, 0); + let phase2_job = sqlx::query( + r#" +SELECT status, input_watermark +FROM jobs +WHERE kind = ? AND job_key = ? + "#, + ) + .bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL) + .bind(MEMORY_CONSOLIDATION_JOB_KEY) + .fetch_one(memory_pool(&runtime)) + .await + .expect("load phase2 job after delete"); + let status: String = phase2_job.try_get("status").expect("status"); + let input_watermark: i64 = phase2_job + .try_get("input_watermark") + .expect("input watermark"); + assert_eq!(status, "pending"); + assert!(input_watermark >= before_delete); + + let visible_outputs = runtime + .list_stage1_outputs_for_global(/*n*/ 10) + .await + .expect("list stage1 outputs after thread delete"); + assert_eq!(visible_outputs.len(), 0); + let _ = tokio::fs::remove_dir_all(codex_home).await; } @@ -2259,7 +2713,7 @@ WHERE kind = 'memory_stage1' let output_row_count = sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("load stage1 output count") .try_get::("count") @@ -2280,7 +2734,7 @@ WHERE kind = 'memory_stage1' let global_job_row_count = sqlx::query("SELECT COUNT(*) AS count FROM jobs WHERE kind = ?") .bind("memory_consolidate_global") - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("load phase2 job row count") .try_get::("count") @@ -2384,7 +2838,7 @@ WHERE kind = 'memory_stage1' let output_row_count = sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("load stage1 output count after delete") .try_get::("count") @@ -2495,7 +2949,7 @@ WHERE kind = 'memory_stage1' ) .bind("memory_stage1") .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("load stage1 job row after newer-source claim"); assert_eq!( @@ -2621,7 +3075,7 @@ WHERE kind = 'memory_stage1' sqlx::query("SELECT retry_remaining FROM jobs WHERE kind = ? AND job_key = ?") .bind("memory_consolidate_global") .bind("global") - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("load phase2 job row after retry exhaustion"); assert_eq!( @@ -2788,7 +3242,7 @@ VALUES (?, ?, ?, ?, ?) .bind("raw memory") .bind("summary") .bind(100_i64) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("insert non-empty stage1 output"); sqlx::query( @@ -2802,7 +3256,7 @@ VALUES (?, ?, ?, ?, ?) .bind("") .bind("") .bind(101_i64) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("insert empty stage1 output"); @@ -3183,6 +3637,101 @@ VALUES (?, ?, ?, ?, ?) let _ = tokio::fs::remove_dir_all(codex_home).await; } + #[tokio::test] + async fn mark_thread_memory_mode_polluted_enqueues_phase2_when_already_polluted() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string()) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + .await + .expect("upsert thread"); + + let claim = runtime + .try_claim_stage1_job( + thread_id, owner, /*source_updated_at*/ 100, /*lease_seconds*/ 3600, + /*max_running_jobs*/ 64, + ) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + /*source_updated_at*/ 100, + "raw", + "summary", + /*rollout_slug*/ None, + ) + .await + .expect("mark stage1 succeeded"), + "stage1 success should persist output" + ); + + let phase2_claim = runtime + .try_claim_global_phase2_job(owner, /*lease_seconds*/ 3600) + .await + .expect("claim phase2"); + let (phase2_token, input_watermark) = match phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + let selected_outputs = runtime + .list_stage1_outputs_for_global(/*n*/ 10) + .await + .expect("list stage1 outputs"); + assert!( + runtime + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), + input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 success"), + "phase2 success should persist selected rows" + ); + + sqlx::query("UPDATE threads SET memory_mode = 'polluted' WHERE id = ?") + .bind(thread_id.to_string()) + .execute(runtime.pool.as_ref()) + .await + .expect("mark thread polluted before memory enqueue"); + + assert!( + !runtime + .mark_thread_memory_mode_polluted(thread_id) + .await + .expect("mark already polluted thread"), + "already polluted thread should not report a state transition" + ); + + age_phase2_success_beyond_cooldown(&runtime).await; + let next_claim = runtime + .try_claim_global_phase2_job(owner, /*lease_seconds*/ 3600) + .await + .expect("claim phase2 after already-polluted enqueue"); + assert!(matches!(next_claim, Phase2JobClaimOutcome::Claimed { .. })); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + #[tokio::test] async fn get_phase2_input_selection_returns_regenerated_selected_rows() { let codex_home = unique_temp_dir(); @@ -3293,7 +3842,7 @@ VALUES (?, ?, ?, ?, ?) "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", ) .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("load selected_for_phase2"); assert_eq!(selected_for_phase2, 1); @@ -3586,7 +4135,7 @@ VALUES (?, ?, ?, ?, ?) "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", ) .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("load selected snapshot after phase2"); assert_eq!(selected_for_phase2, 1); @@ -3699,7 +4248,7 @@ VALUES (?, ?, ?, ?, ?) "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", ) .bind(thread_id.to_string()) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("load selected_for_phase2"); assert_eq!(selected_for_phase2, 0); @@ -3803,13 +4352,13 @@ VALUES (?, ?, ?, ?, ?) let row_a = sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?") .bind(thread_a.to_string()) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("load stage1 usage row a"); let row_b = sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?") .bind(thread_b.to_string()) - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("load stage1 usage row b"); @@ -3909,7 +4458,7 @@ VALUES (?, ?, ?, ?, ?) .bind(usage_count) .bind(last_usage.timestamp()) .bind(thread_id.to_string()) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("update usage metadata"); } @@ -4005,7 +4554,7 @@ VALUES (?, ?, ?, ?, ?) .bind(usage_count) .bind(last_usage.map(|value| value.timestamp())) .bind(thread_id.to_string()) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("update usage metadata"); } @@ -4090,13 +4639,13 @@ VALUES (?, ?, ?, ?, ?) sqlx::query("UPDATE stage1_outputs SET generated_at = ? WHERE thread_id = ?") .bind(300_i64) .bind(older_thread.to_string()) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("update older generated_at"); sqlx::query("UPDATE stage1_outputs SET generated_at = ? WHERE thread_id = ?") .bind(150_i64) .bind(newer_thread.to_string()) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("update newer generated_at"); @@ -4202,14 +4751,14 @@ VALUES (?, ?, ?, ?, ?) .bind(3_i64) .bind(now - Duration::days(40).num_seconds()) .bind(stale_used.to_string()) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("set stale used metadata"); sqlx::query( "UPDATE stage1_outputs SET selected_for_phase2 = 1, selected_for_phase2_source_updated_at = source_updated_at WHERE thread_id = ?", ) .bind(stale_selected.to_string()) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("mark selected for phase2"); sqlx::query( @@ -4218,13 +4767,13 @@ VALUES (?, ?, ?, ?, ?) .bind(8_i64) .bind(now - Duration::days(2).num_seconds()) .bind(fresh_used.to_string()) - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("set fresh used metadata"); let before_jobs_count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM jobs WHERE kind = 'memory_stage1'") - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("count stage1 jobs before prune"); @@ -4237,7 +4786,7 @@ VALUES (?, ?, ?, ?, ?) let remaining = sqlx::query_scalar::<_, String>( "SELECT thread_id FROM stage1_outputs ORDER BY thread_id", ) - .fetch_all(runtime.pool.as_ref()) + .fetch_all(memory_pool(&runtime)) .await .expect("load remaining stage1 outputs"); let mut expected_remaining = vec![fresh_used.to_string(), stale_selected.to_string()]; @@ -4246,7 +4795,7 @@ VALUES (?, ?, ?, ?, ?) let after_jobs_count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM jobs WHERE kind = 'memory_stage1'") - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("count stage1 jobs after prune"); assert_eq!(after_jobs_count, before_jobs_count); @@ -4324,7 +4873,7 @@ VALUES (?, ?, ?, ?, ?) assert_eq!(pruned, 2); let remaining_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs") - .fetch_one(runtime.pool.as_ref()) + .fetch_one(memory_pool(&runtime)) .await .expect("count remaining stage1 outputs"); assert_eq!(remaining_count, 1); @@ -4540,7 +5089,7 @@ VALUES (?, ?, ?, ?, ?) .bind(Utc::now().timestamp() - 1) .bind("memory_consolidate_global") .bind("global") - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("expire global consolidation lease"); @@ -4676,7 +5225,7 @@ VALUES (?, ?, ?, ?, ?) sqlx::query("UPDATE jobs SET ownership_token = NULL WHERE kind = ? AND job_key = ?") .bind("memory_consolidate_global") .bind("global") - .execute(runtime.pool.as_ref()) + .execute(memory_pool(&runtime)) .await .expect("clear ownership token"); diff --git a/codex-rs/state/src/runtime/threads.rs b/codex-rs/state/src/runtime/threads.rs index c7030d7e44..78180e7c7d 100644 --- a/codex-rs/state/src/runtime/threads.rs +++ b/codex-rs/state/src/runtime/threads.rs @@ -971,6 +971,7 @@ ON CONFLICT(thread_id, position) DO NOTHING .execute(self.pool.as_ref()) .await?; let rows_affected = result.rows_affected(); + self.memories.delete_thread_memory(thread_id).await?; if rows_affected > 0 { self.thread_goals.delete_thread_goal(thread_id).await?; } diff --git a/codex-rs/state/src/telemetry.rs b/codex-rs/state/src/telemetry.rs index da2b7de7a8..1a9e4f9951 100644 --- a/codex-rs/state/src/telemetry.rs +++ b/codex-rs/state/src/telemetry.rs @@ -40,6 +40,7 @@ pub(crate) enum DbKind { State, Logs, Goals, + Memories, } impl DbKind { @@ -48,6 +49,7 @@ impl DbKind { Self::State => "state", Self::Logs => "logs", Self::Goals => "goals", + Self::Memories => "memories", } } }