From 4d201e340e9ff079ba0d8adf1ee7ece6f829a30b Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Mon, 4 May 2026 11:46:03 -0700 Subject: [PATCH] state: pass state db handles through consumers (#20561) ## Why SQLite state was still being opened from consumer paths, including lazy `OnceCell`-backed thread-store call sites. That let one process construct multiple state DB connections for the same Codex home, which makes SQLite lock contention and `database is locked` failures much easier to hit. State DB lifetime should be chosen by main-like entrypoints and tests, then passed through explicitly. Consumers should use the supplied `Option` or `StateDbHandle` and keep their existing filesystem fallback or error behavior when no handle is available. The startup path also needs to keep the rollout crate in charge of SQLite state initialization. Opening `codex_state::StateRuntime` directly bypasses rollout metadata backfill, so entrypoints should initialize through `codex_rollout::state_db` and receive a handle only after required rollout backfills have completed. ## What Changed - Initialize the state DB in main-like entrypoints for CLI, TUI, app-server, exec, MCP server, and the thread-manager sample. - Pass `Option` through `ThreadManager`, `LocalThreadStore`, app-server processors, TUI app wiring, rollout listing/recording, personality migration, shell snapshot cleanup, session-name lookup, and memory/device-key consumers. - Remove the lazy local state DB wrapper from the thread store so non-test consumers use only the supplied handle or their existing fallback path. - Make `codex_rollout::state_db::init` the local state startup path: it opens/migrates SQLite, runs rollout metadata backfill when needed, waits for concurrent backfill workers up to a bounded timeout, verifies completion, and then returns the initialized handle. - Keep optional/non-owning SQLite helpers, such as remote TUI local reads, as open-only paths that do not run startup backfill. - Switch app-server startup from direct `codex_state::StateRuntime::init` to the rollout state initializer so app-server cannot skip rollout backfill. - Collapse split rollout lookup/list APIs so callers use the normal methods with an optional state handle instead of `_with_state_db` variants. - Restore `getConversationSummary(ThreadId)` to delegate through `ThreadStore::read_thread` instead of a LocalThreadStore-specific rollout path special case. - Keep DB-backed rollout path lookup keyed on the DB row and file existence, without imposing the filesystem filename convention on existing DB rows. - Verify readable DB-backed rollout paths against `session_meta.id` before returning them, so a stale SQLite row that points at another thread's JSONL falls back to filesystem search and read-repairs the DB row. - Keep `debug prompt-input` filesystem-only so a one-off debug command does not initialize or backfill SQLite state just to print prompt input. - Keep goal-session test Codex homes alive only in the goal-specific helper, rather than leaking tempdirs from the shared session test helper. - Update tests and call sites to pass explicit state handles where DB behavior is expected and explicit `None` where filesystem-only behavior is intended. ## Validation - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo check -p codex-rollout -p codex-thread-store -p codex-app-server -p codex-core -p codex-tui -p codex-exec -p codex-cli --tests` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-rollout state_db_` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-rollout find_thread_path` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-rollout find_thread_path -- --nocapture` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-rollout try_init_ -- --nocapture` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-rollout` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo clippy -p codex-rollout --lib -- -D warnings` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-thread-store read_thread_falls_back_when_sqlite_path_points_to_another_thread -- --nocapture` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-thread-store` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-core shell_snapshot` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-core --test all personality_migration` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-core --test all rollout_list_find` - `RUST_MIN_STACK=8388608 CODEX_SKIP_VENDORED_BWRAP=1 CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-core --test all rollout_list_find::find_prefers_sqlite_path_by_id -- --nocapture` - `RUST_MIN_STACK=8388608 CODEX_SKIP_VENDORED_BWRAP=1 CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-core --test all rollout_list_find -- --nocapture` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-core interrupt_accounts_active_goal_before_pausing` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-app-server get_auth_status -- --test-threads=1` - `CODEX_SKIP_VENDORED_BWRAP=1 CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo test -p codex-app-server --lib` - `CODEX_SKIP_VENDORED_BWRAP=1 CARGO_TARGET_DIR=/tmp/codex-target-state-db cargo check -p codex-rollout -p codex-app-server --tests` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db just fix -p codex-rollout -p codex-thread-store -p codex-core -p codex-app-server -p codex-tui -p codex-exec -p codex-cli` - `CODEX_SKIP_VENDORED_BWRAP=1 CARGO_TARGET_DIR=/tmp/codex-target-state-db just fix -p codex-rollout -p codex-app-server` - `CARGO_TARGET_DIR=/tmp/codex-target-state-db just fix -p codex-rollout` - `CODEX_SKIP_VENDORED_BWRAP=1 CARGO_TARGET_DIR=/tmp/codex-target-state-db just fix -p codex-core` - `just argument-comment-lint -p codex-core` - `just argument-comment-lint -p codex-rollout` Focused coverage added in `codex-rollout`: - `recorder::tests::state_db_init_backfills_before_returning` verifies the rollout metadata row exists before startup init returns. - `state_db::tests::try_init_waits_for_concurrent_startup_backfill` verifies startup waits for another worker to finish backfill instead of disabling the handle for the process. - `state_db::tests::try_init_times_out_waiting_for_stuck_startup_backfill` verifies startup does not hang indefinitely on a stuck backfill lease. - `tests::find_thread_path_accepts_existing_state_db_path_without_canonical_filename` verifies DB-backed lookup accepts valid existing rollout paths even when the filename does not include the thread UUID. - `tests::find_thread_path_falls_back_when_db_path_points_to_another_thread` verifies DB-backed lookup ignores a stale row whose existing path belongs to another thread and read-repairs the row after filesystem fallback. Focused coverage updated in `codex-core`: - `rollout_list_find::find_prefers_sqlite_path_by_id` now uses a DB-preferred rollout file with matching `session_meta.id`, so it still verifies that valid SQLite paths win without depending on stale/empty rollout contents. `cargo test -p codex-app-server thread_list_respects_search_term_filter -- --test-threads=1 --nocapture` was attempted locally but timed out waiting for the app-server test harness `initialize` response before reaching the changed thread-list code path. `bazel test //codex-rs/thread-store:thread-store-unit-tests --test_output=errors` was attempted locally after the thread-store fix, but this container failed before target analysis while fetching `v8+` through BuildBuddy/direct GitHub. The equivalent local crate coverage, including `cargo test -p codex-thread-store`, passes. A plain local `cargo check -p codex-rollout -p codex-app-server --tests` also requires system `libcap.pc` for `codex-linux-sandbox`; the follow-up app-server check above used `CODEX_SKIP_VENDORED_BWRAP=1` in this container. --- codex-rs/app-server-client/src/lib.rs | 7 + codex-rs/app-server/src/in_process.rs | 39 +++- codex-rs/app-server/src/lib.rs | 89 +++++--- codex-rs/app-server/src/message_processor.rs | 16 +- .../src/message_processor_tracing_tests.rs | 1 + codex-rs/app-server/src/request_processors.rs | 1 - .../device_key_processor.rs | 36 +--- .../request_processors/feedback_processor.rs | 5 +- .../thread_goal_processor.rs | 57 ++++-- .../request_processors/thread_processor.rs | 79 +++---- .../src/request_processors/thread_summary.rs | 8 - .../src/request_processors/turn_processor.rs | 27 ++- .../app-server/tests/suite/v2/mcp_resource.rs | 1 + .../tests/suite/v2/remote_thread_store.rs | 1 + .../tests/suite/v2/thread_archive.rs | 68 ++++--- .../app-server/tests/suite/v2/thread_list.rs | 1 + .../tests/suite/v2/thread_name_websocket.rs | 7 +- .../app-server/tests/suite/v2/thread_read.rs | 3 + .../tests/suite/v2/thread_unarchive.rs | 17 +- codex-rs/cli/src/main.rs | 2 +- codex-rs/core-api/src/lib.rs | 2 + codex-rs/core/src/agent/control.rs | 4 +- codex-rs/core/src/agent/control_tests.rs | 18 +- codex-rs/core/src/lib.rs | 2 +- codex-rs/core/src/personality_migration.rs | 23 ++- .../core/src/personality_migration_tests.rs | 10 +- codex-rs/core/src/prompt_debug.rs | 6 +- codex-rs/core/src/session/mod.rs | 11 +- codex-rs/core/src/session/session.rs | 1 + codex-rs/core/src/session/tests.rs | 74 +++++-- .../core/src/session/tests/guardian_tests.rs | 1 + codex-rs/core/src/shell_snapshot.rs | 28 ++- codex-rs/core/src/shell_snapshot_tests.rs | 33 ++- codex-rs/core/src/state_db_bridge.rs | 4 +- codex-rs/core/src/stream_events_utils.rs | 11 +- codex-rs/core/src/test_support.rs | 16 ++ codex-rs/core/src/thread_manager.rs | 39 +++- codex-rs/core/src/thread_manager_tests.rs | 29 ++- .../src/tools/handlers/multi_agents_tests.rs | 13 +- codex-rs/core/tests/common/test_codex.rs | 7 +- codex-rs/core/tests/suite/client.rs | 3 +- .../core/tests/suite/personality_migration.rs | 31 +-- .../core/tests/suite/prompt_debug_tests.rs | 1 + .../core/tests/suite/rollout_list_find.rs | 62 ++++-- codex-rs/core/tests/suite/skills.rs | 3 +- codex-rs/exec/src/lib.rs | 16 +- codex-rs/mcp-server/src/lib.rs | 2 + codex-rs/mcp-server/src/message_processor.rs | 5 +- codex-rs/rollout/src/list.rs | 56 +++-- codex-rs/rollout/src/metadata.rs | 17 +- codex-rs/rollout/src/recorder.rs | 18 +- codex-rs/rollout/src/recorder_tests.rs | 89 ++++++++ codex-rs/rollout/src/session_index.rs | 9 +- codex-rs/rollout/src/session_index_tests.rs | 6 +- codex-rs/rollout/src/state_db.rs | 192 +++++++++++++----- codex-rs/rollout/src/state_db_tests.rs | 62 ++++++ codex-rs/rollout/src/tests.rs | 81 +++++++- codex-rs/thread-manager-sample/src/main.rs | 5 +- .../thread-store/src/local/archive_thread.rs | 28 +-- .../thread-store/src/local/list_threads.rs | 17 +- codex-rs/thread-store/src/local/mod.rs | 47 ++--- .../thread-store/src/local/read_thread.rs | 84 ++++---- .../src/local/unarchive_thread.rs | 8 +- .../src/local/update_thread_metadata.rs | 58 +++--- codex-rs/tui/src/app.rs | 4 + codex-rs/tui/src/app/event_dispatch.rs | 1 + codex-rs/tui/src/app/session_lifecycle.rs | 2 +- codex-rs/tui/src/app/test_support.rs | 1 + codex-rs/tui/src/app/tests.rs | 2 + codex-rs/tui/src/app/thread_routing.rs | 2 +- codex-rs/tui/src/app/thread_session_state.rs | 2 +- codex-rs/tui/src/lib.rs | 76 +++++-- codex-rs/tui/src/onboarding/auth.rs | 1 + codex-rs/tui/src/session_resume.rs | 15 +- 74 files changed, 1286 insertions(+), 517 deletions(-) diff --git a/codex-rs/app-server-client/src/lib.rs b/codex-rs/app-server-client/src/lib.rs index bbbb109eff..539a1684c5 100644 --- a/codex-rs/app-server-client/src/lib.rs +++ b/codex-rs/app-server-client/src/lib.rs @@ -29,6 +29,7 @@ pub use codex_app_server::in_process::DEFAULT_IN_PROCESS_CHANNEL_CAPACITY; pub use codex_app_server::in_process::InProcessServerEvent; use codex_app_server::in_process::InProcessStartArgs; use codex_app_server::in_process::LogDbLayer; +pub use codex_app_server::in_process::StateDbHandle; use codex_app_server_protocol::ClientInfo; use codex_app_server_protocol::ClientNotification; use codex_app_server_protocol::ClientRequest; @@ -343,6 +344,8 @@ pub struct InProcessClientStartArgs { pub feedback: CodexFeedback, /// SQLite tracing layer used to flush recently emitted logs before feedback upload. pub log_db: Option, + /// Process-wide SQLite state handle shared with the embedded app-server. + pub state_db: Option, /// Environment manager used by core execution and filesystem operations. pub environment_manager: Arc, /// Startup warnings emitted after initialize succeeds. @@ -404,6 +407,7 @@ impl InProcessClientStartArgs { thread_config_loader, feedback: self.feedback, log_db: self.log_db, + state_db: self.state_db, environment_manager: self.environment_manager, config_warnings: self.config_warnings, session_source: self.session_source, @@ -983,6 +987,7 @@ mod tests { cloud_requirements: CloudRequirementsLoader::default(), feedback: CodexFeedback::new(), log_db: None, + state_db: None, environment_manager: Arc::new(EnvironmentManager::default_for_tests()), config_warnings: Vec::new(), session_source, @@ -2057,6 +2062,7 @@ mod tests { cloud_requirements: CloudRequirementsLoader::default(), feedback: CodexFeedback::new(), log_db: None, + state_db: None, environment_manager: environment_manager.clone(), config_warnings: Vec::new(), session_source: SessionSource::Exec, @@ -2096,6 +2102,7 @@ mod tests { cloud_requirements: CloudRequirementsLoader::default(), feedback: CodexFeedback::new(), log_db: None, + state_db: None, environment_manager: Arc::new(EnvironmentManager::default_for_tests()), config_warnings: Vec::new(), session_source: SessionSource::Exec, diff --git a/codex-rs/app-server/src/in_process.rs b/codex-rs/app-server/src/in_process.rs index 57f06e64f0..09d57a60a3 100644 --- a/codex-rs/app-server/src/in_process.rs +++ b/codex-rs/app-server/src/in_process.rs @@ -86,6 +86,7 @@ use codex_exec_server::EnvironmentManager; use codex_feedback::CodexFeedback; use codex_login::AuthManager; use codex_protocol::protocol::SessionSource; +pub use codex_rollout::StateDbHandle; pub use codex_state::log_db::LogDbLayer; use tokio::sync::mpsc; use tokio::sync::oneshot; @@ -126,6 +127,8 @@ pub struct InProcessStartArgs { pub feedback: CodexFeedback, /// SQLite tracing layer used to flush recently emitted logs before feedback upload. pub log_db: Option, + /// Process-wide SQLite state handle shared with embedded app-server consumers. + pub state_db: Option, /// Environment manager used by core execution and filesystem operations. pub environment_manager: Arc, /// Startup warnings emitted after initialize succeeds. @@ -251,6 +254,8 @@ pub struct InProcessClientHandle { client: InProcessClientSender, event_rx: mpsc::Receiver, runtime_handle: tokio::task::JoinHandle<()>, + #[cfg(test)] + _test_codex_home: Option, } impl InProcessClientHandle { @@ -418,6 +423,7 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle { environment_manager: args.environment_manager, feedback: args.feedback, log_db: args.log_db, + state_db: args.state_db, config_warnings: args.config_warnings, session_source: args.session_source, auth_manager, @@ -717,6 +723,8 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle { client: InProcessClientSender { client_tx }, event_rx, runtime_handle, + #[cfg(test)] + _test_codex_home: None, } } @@ -738,13 +746,22 @@ mod tests { use codex_app_server_protocol::TurnStatus; use codex_core::config::ConfigBuilder; use pretty_assertions::assert_eq; + use std::path::Path; + use tempfile::TempDir; - async fn build_test_config() -> Config { - match ConfigBuilder::default().build().await { + async fn build_test_config(codex_home: &Path) -> Config { + match ConfigBuilder::default() + .codex_home(codex_home.to_path_buf()) + .build() + .await + { Ok(config) => config, - Err(_) => Config::load_default_with_cli_overrides(Vec::new()) - .await - .expect("default config should load"), + Err(_) => Config::load_default_with_cli_overrides_for_codex_home( + codex_home.to_path_buf(), + Vec::new(), + ) + .await + .expect("default config should load"), } } @@ -752,15 +769,21 @@ mod tests { session_source: SessionSource, channel_capacity: usize, ) -> InProcessClientHandle { + let codex_home = TempDir::new().expect("temp dir"); + let config = Arc::new(build_test_config(codex_home.path()).await); + let state_db = codex_rollout::state_db::try_init(config.as_ref()) + .await + .expect("state db should initialize for in-process test"); let args = InProcessStartArgs { arg0_paths: Arg0DispatchPaths::default(), - config: Arc::new(build_test_config().await), + config, cli_overrides: Vec::new(), loader_overrides: LoaderOverrides::default(), cloud_requirements: CloudRequirementsLoader::default(), thread_config_loader: Arc::new(codex_config::NoopThreadConfigLoader), feedback: CodexFeedback::new(), log_db: None, + state_db: Some(state_db), environment_manager: Arc::new(EnvironmentManager::default_for_tests()), config_warnings: Vec::new(), session_source, @@ -775,7 +798,9 @@ mod tests { }, channel_capacity, }; - start(args).await.expect("in-process runtime should start") + let mut client = start(args).await.expect("in-process runtime should start"); + client._test_codex_home = Some(codex_home); + client } async fn start_test_client(session_source: SessionSource) -> InProcessClientHandle { diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 45caa04a58..4b1723d88c 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -54,6 +54,7 @@ use codex_exec_server::EnvironmentManager; use codex_exec_server::ExecServerRuntimePaths; use codex_feedback::CodexFeedback; use codex_protocol::protocol::SessionSource; +use codex_rollout::state_db as rollout_state_db; use codex_state::log_db; use tokio::sync::mpsc; use tokio::sync::oneshot; @@ -453,23 +454,6 @@ pub async fn run_main_with_transport_options( .await { Ok(config) => { - let effective_toml = config.config_layer_stack.effective_config(); - match effective_toml.try_into() { - Ok(config_toml) => { - if let Err(err) = codex_core::personality_migration::maybe_migrate_personality( - &config.codex_home, - &config_toml, - ) - .await - { - warn!(error = %err, "Failed to run personality migration"); - } - } - Err(err) => { - warn!(error = %err, "Failed to deserialize config for personality migration"); - } - } - let discovered_thread_config_loader = configured_thread_config_loader(&config); config_manager .replace_thread_config_loader(Arc::clone(&discovered_thread_config_loader)); @@ -483,23 +467,70 @@ pub async fn run_main_with_transport_options( } }; let mut config_warnings = Vec::new(); - let config = match config_manager + let (mut config, should_run_personality_migration) = match config_manager .load_latest_config(/*fallback_cwd*/ None) .await { - Ok(config) => config, + Ok(config) => (config, true), Err(err) => { let message = config_warning_from_error("Invalid configuration; using defaults.", &err); config_warnings.push(message); - config_manager.load_default_config().await.map_err(|e| { - std::io::Error::new( - ErrorKind::InvalidData, - format!("error loading default config after config error: {e}"), - ) - })? + ( + config_manager.load_default_config().await.map_err(|e| { + std::io::Error::new( + ErrorKind::InvalidData, + format!("error loading default config after config error: {e}"), + ) + })?, + false, + ) } }; + let state_db_result = rollout_state_db::try_init(&config).await; + let state_db_init_error = state_db_result.as_ref().err().map(ToString::to_string); + let state_db = state_db_result.ok(); + + if should_run_personality_migration { + let effective_toml = config.config_layer_stack.effective_config(); + match effective_toml.try_into() { + Ok(config_toml) => { + match codex_core::personality_migration::maybe_migrate_personality( + &config.codex_home, + &config_toml, + state_db.clone(), + ) + .await + { + Ok(codex_core::personality_migration::PersonalityMigrationStatus::Applied) => { + config = config_manager + .load_latest_config(/*fallback_cwd*/ None) + .await + .map_err(|err| { + std::io::Error::new( + ErrorKind::InvalidData, + format!( + "error reloading config after personality migration: {err}" + ), + ) + })?; + } + Ok( + codex_core::personality_migration::PersonalityMigrationStatus::SkippedMarker + | codex_core::personality_migration::PersonalityMigrationStatus::SkippedExplicitPersonality + | codex_core::personality_migration::PersonalityMigrationStatus::SkippedNoSessions, + ) => {} + Err(err) => { + warn!(error = %err, "Failed to run personality migration"); + } + } + } + Err(err) => { + warn!(error = %err, "Failed to deserialize config for personality migration"); + } + } + } + if let Ok(Some(err)) = check_execpolicy_for_warnings(&config.config_layer_stack).await { let (path, range) = exec_policy_warning_location(&err); let message = ConfigWarningNotification { @@ -567,13 +598,6 @@ pub async fn run_main_with_transport_options( let feedback_layer = feedback.logger_layer(); let feedback_metadata_layer = feedback.metadata_layer(); - let state_db_result = codex_state::StateRuntime::init( - config.sqlite_home.clone(), - config.model_provider_id.clone(), - ) - .await; - let state_db_init_error = state_db_result.as_ref().err().map(ToString::to_string); - let state_db = state_db_result.ok(); let log_db = state_db.clone().map(log_db::start); let log_db_layer = log_db .clone() @@ -745,6 +769,7 @@ pub async fn run_main_with_transport_options( environment_manager, feedback: feedback.clone(), log_db, + state_db: state_db.clone(), config_warnings, session_source, auth_manager, diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index d444dee33c..b4bce010de 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -72,6 +72,7 @@ use codex_login::auth::ExternalAuthTokens; use codex_protocol::ThreadId; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::W3cTraceContext; +use codex_rollout::StateDbHandle; use codex_state::log_db::LogDbLayer; use tokio::sync::Mutex; use tokio::sync::Semaphore; @@ -251,6 +252,7 @@ pub(crate) struct MessageProcessorArgs { pub(crate) environment_manager: Arc, pub(crate) feedback: CodexFeedback, pub(crate) log_db: Option, + pub(crate) state_db: Option, pub(crate) config_warnings: Vec, pub(crate) session_source: SessionSource, pub(crate) auth_manager: Arc, @@ -272,6 +274,7 @@ impl MessageProcessor { environment_manager, feedback, log_db, + state_db, config_warnings, session_source, auth_manager, @@ -285,7 +288,7 @@ impl MessageProcessor { // The thread store is intentionally process-scoped. Config reloads can // affect per-thread behavior, but they must not move newly started, // resumed, or forked threads to a different persistence backend/root. - let thread_store = thread_store_from_config(config.as_ref()); + let thread_store = thread_store_from_config(config.as_ref(), state_db.clone()); let thread_manager = Arc::new(ThreadManager::new( config.as_ref(), auth_manager.clone(), @@ -293,6 +296,7 @@ impl MessageProcessor { environment_manager, Some(analytics_events_client.clone()), Arc::clone(&thread_store), + state_db.clone(), )); thread_manager .plugins_manager() @@ -337,6 +341,7 @@ impl MessageProcessor { Arc::clone(&config), feedback, log_db, + state_db.clone(), ); let git_processor = GitRequestProcessor::new(); let initialize_processor = InitializeRequestProcessor::new( @@ -371,6 +376,7 @@ impl MessageProcessor { outgoing.clone(), Arc::clone(&config), thread_state_manager.clone(), + state_db.clone(), ); let thread_processor = ThreadRequestProcessor::new( auth_manager.clone(), @@ -386,6 +392,7 @@ impl MessageProcessor { thread_watch_manager.clone(), Arc::clone(&thread_list_state_permit), thread_goal_processor.clone(), + state_db.clone(), ); let turn_processor = TurnRequestProcessor::new( auth_manager.clone(), @@ -399,6 +406,7 @@ impl MessageProcessor { thread_state_manager, thread_watch_manager, thread_list_state_permit, + state_db.clone(), ); if matches!(plugin_startup_tasks, crate::PluginStartupTasks::Start) { // Keep plugin startup warmups aligned at app-server startup. @@ -429,11 +437,7 @@ impl MessageProcessor { arg0_paths, config.codex_home.to_path_buf(), ); - let device_key_processor = DeviceKeyRequestProcessor::new( - outgoing.clone(), - config.sqlite_home.clone(), - config.model_provider_id.clone(), - ); + let device_key_processor = DeviceKeyRequestProcessor::new(outgoing.clone(), state_db); let fs_processor = FsRequestProcessor::new( thread_manager .environment_manager() diff --git a/codex-rs/app-server/src/message_processor_tracing_tests.rs b/codex-rs/app-server/src/message_processor_tracing_tests.rs index 8caf1aaa96..45ea709180 100644 --- a/codex-rs/app-server/src/message_processor_tracing_tests.rs +++ b/codex-rs/app-server/src/message_processor_tracing_tests.rs @@ -290,6 +290,7 @@ async fn build_test_processor( environment_manager: Arc::new(EnvironmentManager::default_for_tests()), feedback: CodexFeedback::new(), log_db: None, + state_db: None, config_warnings: Vec::new(), session_source: SessionSource::VSCode, auth_manager, diff --git a/codex-rs/app-server/src/request_processors.rs b/codex-rs/app-server/src/request_processors.rs index b1c6ab9815..f59ea14402 100644 --- a/codex-rs/app-server/src/request_processors.rs +++ b/codex-rs/app-server/src/request_processors.rs @@ -369,7 +369,6 @@ use codex_rmcp_client::perform_oauth_login_return_url; use codex_rollout::EventPersistenceMode; use codex_rollout::is_persisted_rollout_item; use codex_rollout::state_db::StateDbHandle; -use codex_rollout::state_db::get_state_db; use codex_rollout::state_db::reconcile_rollout; use codex_state::StateRuntime; use codex_state::ThreadMetadata; diff --git a/codex-rs/app-server/src/request_processors/device_key_processor.rs b/codex-rs/app-server/src/request_processors/device_key_processor.rs index c469f2544a..ea0a96c2af 100644 --- a/codex-rs/app-server/src/request_processors/device_key_processor.rs +++ b/codex-rs/app-server/src/request_processors/device_key_processor.rs @@ -1,6 +1,5 @@ use std::fmt; use std::future::Future; -use std::path::PathBuf; use std::sync::Arc; use crate::error_code::internal_error; @@ -36,7 +35,6 @@ use codex_device_key::RemoteControlClientEnrollmentAudience; use codex_device_key::RemoteControlClientEnrollmentSignPayload; use codex_state::DeviceKeyBindingRecord; use codex_state::StateRuntime; -use tokio::sync::OnceCell; #[derive(Clone)] pub(crate) struct DeviceKeyRequestProcessor { @@ -47,15 +45,11 @@ pub(crate) struct DeviceKeyRequestProcessor { impl DeviceKeyRequestProcessor { pub(crate) fn new( outgoing: Arc, - sqlite_home: PathBuf, - default_provider: String, + state_db: Option>, ) -> Self { Self { outgoing, - store: DeviceKeyStore::new(Arc::new(StateDeviceKeyBindingStore::new( - sqlite_home, - default_provider, - ))), + store: DeviceKeyStore::new(Arc::new(StateDeviceKeyBindingStore::new(state_db))), } } @@ -176,39 +170,25 @@ async fn sign_device_key( } struct StateDeviceKeyBindingStore { - sqlite_home: PathBuf, - default_provider: String, - state_db: OnceCell>, + state_db: Option>, } impl StateDeviceKeyBindingStore { - fn new(sqlite_home: PathBuf, default_provider: String) -> Self { - Self { - sqlite_home, - default_provider, - state_db: OnceCell::new(), - } + fn new(state_db: Option>) -> Self { + Self { state_db } } async fn state_db(&self) -> Result, DeviceKeyError> { - let sqlite_home = self.sqlite_home.clone(); - let default_provider = self.default_provider.clone(); self.state_db - .get_or_try_init(|| async move { - StateRuntime::init(sqlite_home, default_provider) - .await - .map_err(|err| DeviceKeyError::Platform(err.to_string())) - }) - .await - .cloned() + .clone() + .ok_or_else(|| DeviceKeyError::Platform("sqlite state db unavailable".to_string())) } } impl fmt::Debug for StateDeviceKeyBindingStore { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("StateDeviceKeyBindingStore") - .field("sqlite_home", &self.sqlite_home) - .field("default_provider", &self.default_provider) + .field("has_state_db", &self.state_db.is_some()) .finish_non_exhaustive() } } diff --git a/codex-rs/app-server/src/request_processors/feedback_processor.rs b/codex-rs/app-server/src/request_processors/feedback_processor.rs index 666dcac83a..5e945d7b10 100644 --- a/codex-rs/app-server/src/request_processors/feedback_processor.rs +++ b/codex-rs/app-server/src/request_processors/feedback_processor.rs @@ -7,6 +7,7 @@ pub(crate) struct FeedbackRequestProcessor { config: Arc, feedback: CodexFeedback, log_db: Option, + state_db: Option, } impl FeedbackRequestProcessor { @@ -16,6 +17,7 @@ impl FeedbackRequestProcessor { config: Arc, feedback: CodexFeedback, log_db: Option, + state_db: Option, ) -> Self { Self { auth_manager, @@ -23,6 +25,7 @@ impl FeedbackRequestProcessor { config, feedback, log_db, + state_db, } } @@ -75,7 +78,7 @@ impl FeedbackRequestProcessor { if let Some(log_db) = self.log_db.as_ref() { log_db.flush().await; } - let state_db_ctx = get_state_db(&self.config).await; + let state_db_ctx = self.state_db.clone(); let feedback_thread_ids = match conversation_id { Some(conversation_id) => match self .thread_manager diff --git a/codex-rs/app-server/src/request_processors/thread_goal_processor.rs b/codex-rs/app-server/src/request_processors/thread_goal_processor.rs index ba9115f150..2e7c9909f0 100644 --- a/codex-rs/app-server/src/request_processors/thread_goal_processor.rs +++ b/codex-rs/app-server/src/request_processors/thread_goal_processor.rs @@ -7,6 +7,7 @@ pub(crate) struct ThreadGoalRequestProcessor { outgoing: Arc, config: Arc, thread_state_manager: ThreadStateManager, + state_db: Option, } impl ThreadGoalRequestProcessor { @@ -15,12 +16,14 @@ impl ThreadGoalRequestProcessor { outgoing: Arc, config: Arc, thread_state_manager: ThreadStateManager, + state_db: Option, ) -> Self { Self { thread_manager, outgoing, config, thread_state_manager, + state_db, } } @@ -78,7 +81,7 @@ impl ThreadGoalRequestProcessor { if let Some(state_db) = thread.state_db() { Some(state_db) } else { - open_state_db_for_direct_thread_lookup(&self.config).await + self.state_db.clone() } } else { None @@ -104,12 +107,16 @@ impl ThreadGoalRequestProcessor { "ephemeral thread does not support goals: {thread_id}" )) })?, - None => find_thread_path_by_id_str(&self.config.codex_home, &thread_id.to_string()) - .await - .map_err(|err| { - internal_error(format!("failed to locate thread id {thread_id}: {err}")) - })? - .ok_or_else(|| invalid_request(format!("thread not found: {thread_id}")))?, + None => find_thread_path_by_id_str( + &self.config.codex_home, + &thread_id.to_string(), + self.state_db.as_deref(), + ) + .await + .map_err(|err| { + internal_error(format!("failed to locate thread id {thread_id}: {err}")) + })? + .ok_or_else(|| invalid_request(format!("thread not found: {thread_id}")))?, }; reconcile_rollout( Some(&state_db), @@ -247,12 +254,16 @@ impl ThreadGoalRequestProcessor { "ephemeral thread does not support goals: {thread_id}" )) })?, - None => find_thread_path_by_id_str(&self.config.codex_home, &thread_id.to_string()) - .await - .map_err(|err| { - internal_error(format!("failed to locate thread id {thread_id}: {err}")) - })? - .ok_or_else(|| invalid_request(format!("thread not found: {thread_id}")))?, + None => find_thread_path_by_id_str( + &self.config.codex_home, + &thread_id.to_string(), + self.state_db.as_deref(), + ) + .await + .map_err(|err| { + internal_error(format!("failed to locate thread id {thread_id}: {err}")) + })? + .ok_or_else(|| invalid_request(format!("thread not found: {thread_id}")))?, }; reconcile_rollout( Some(&state_db), @@ -307,16 +318,20 @@ impl ThreadGoalRequestProcessor { return Ok(state_db); } } else { - find_thread_path_by_id_str(&self.config.codex_home, &thread_id.to_string()) - .await - .map_err(|err| { - internal_error(format!("failed to locate thread id {thread_id}: {err}")) - })? - .ok_or_else(|| invalid_request(format!("thread not found: {thread_id}")))?; + find_thread_path_by_id_str( + &self.config.codex_home, + &thread_id.to_string(), + self.state_db.as_deref(), + ) + .await + .map_err(|err| { + internal_error(format!("failed to locate thread id {thread_id}: {err}")) + })? + .ok_or_else(|| invalid_request(format!("thread not found: {thread_id}")))?; } - open_state_db_for_direct_thread_lookup(&self.config) - .await + self.state_db + .clone() .ok_or_else(|| internal_error("sqlite state db unavailable for thread goals")) } diff --git a/codex-rs/app-server/src/request_processors/thread_processor.rs b/codex-rs/app-server/src/request_processors/thread_processor.rs index f10652a029..40be4f34ef 100644 --- a/codex-rs/app-server/src/request_processors/thread_processor.rs +++ b/codex-rs/app-server/src/request_processors/thread_processor.rs @@ -259,6 +259,7 @@ pub(crate) struct ThreadRequestProcessor { pub(super) thread_watch_manager: ThreadWatchManager, pub(super) thread_list_state_permit: Arc, pub(super) thread_goal_processor: ThreadGoalRequestProcessor, + pub(super) state_db: Option, pub(super) background_tasks: TaskTracker, } @@ -278,6 +279,7 @@ impl ThreadRequestProcessor { thread_watch_manager: ThreadWatchManager, thread_list_state_permit: Arc, thread_goal_processor: ThreadGoalRequestProcessor, + state_db: Option, ) -> Self { Self { auth_manager, @@ -293,6 +295,7 @@ impl ThreadRequestProcessor { thread_watch_manager, thread_list_state_permit, thread_goal_processor, + state_db, background_tasks: TaskTracker::new(), } } @@ -1170,7 +1173,7 @@ impl ThreadRequestProcessor { .map_err(|err| invalid_request(format!("invalid thread id: {err}")))?; let mut thread_ids = vec![thread_id]; - if let Some(state_db_ctx) = get_state_db(&self.config).await { + if let Some(state_db_ctx) = self.state_db.as_ref() { let descendants = state_db_ctx .list_thread_spawn_descendants(thread_id) .await @@ -1391,14 +1394,10 @@ impl ThreadRequestProcessor { } async fn memory_reset_response_inner(&self) -> Result { - let state_db = StateRuntime::init( - self.config.sqlite_home.clone(), - self.config.model_provider_id.clone(), - ) - .await - .map_err(|err| { - internal_error(format!("failed to open state db for memory reset: {err}")) - })?; + let state_db = self + .state_db + .clone() + .ok_or_else(|| internal_error("sqlite state db unavailable for memory reset"))?; state_db.clear_memory_data().await.map_err(|err| { internal_error(format!("failed to clear memory rows in state db: {err}")) @@ -1445,7 +1444,7 @@ impl ThreadRequestProcessor { let loaded_thread = self.thread_manager.get_thread(thread_uuid).await.ok(); let mut state_db_ctx = loaded_thread.as_ref().and_then(|thread| thread.state_db()); if state_db_ctx.is_none() { - state_db_ctx = get_state_db(&self.config).await; + state_db_ctx = self.state_db.clone(); } let Some(state_db_ctx) = state_db_ctx else { return Err(internal_error(format!( @@ -1583,33 +1582,37 @@ impl ThreadRequestProcessor { return Ok(()); } - let rollout_path = - match find_thread_path_by_id_str(&self.config.codex_home, &thread_uuid.to_string()) - .await + let rollout_path = match find_thread_path_by_id_str( + &self.config.codex_home, + &thread_uuid.to_string(), + self.state_db.as_deref(), + ) + .await + { + Ok(Some(path)) => path, + Ok(None) => match find_archived_thread_path_by_id_str( + &self.config.codex_home, + &thread_uuid.to_string(), + self.state_db.as_deref(), + ) + .await { Ok(Some(path)) => path, - Ok(None) => match find_archived_thread_path_by_id_str( - &self.config.codex_home, - &thread_uuid.to_string(), - ) - .await - { - Ok(Some(path)) => path, - Ok(None) => { - return Err(invalid_request(format!("thread not found: {thread_uuid}"))); - } - Err(err) => { - return Err(internal_error(format!( - "failed to locate archived thread id {thread_uuid}: {err}" - ))); - } - }, + Ok(None) => { + return Err(invalid_request(format!("thread not found: {thread_uuid}"))); + } Err(err) => { return Err(internal_error(format!( - "failed to locate thread id {thread_uuid}: {err}" + "failed to locate archived thread id {thread_uuid}: {err}" ))); } - }; + }, + Err(err) => { + return Err(internal_error(format!( + "failed to locate thread id {thread_uuid}: {err}" + ))); + } + }; reconcile_rollout( Some(state_db_ctx), @@ -2555,7 +2558,7 @@ impl ThreadRequestProcessor { let InitialHistory::Resumed(resumed_history) = thread_history else { return None; }; - let state_db_ctx = get_state_db(&self.config).await?; + let state_db_ctx = self.state_db.clone()?; let persisted_metadata = state_db_ctx .get_thread(resumed_history.conversation_id) .await @@ -2922,7 +2925,9 @@ impl ThreadRequestProcessor { } async fn attach_thread_name(&self, thread_id: ThreadId, thread: &mut Thread) { - if let Some(title) = title_from_state_db(&self.config, thread_id).await { + if let Some(title) = + title_from_state_db(&self.config, self.state_db.as_ref(), thread_id).await + { set_thread_name_from_title(thread, title); } } @@ -3683,8 +3688,12 @@ async fn read_summary_from_state_db_context_by_thread_id( Some(summary_from_thread_metadata(&metadata)) } -async fn title_from_state_db(config: &Config, thread_id: ThreadId) -> Option { - if let Some(state_db_ctx) = open_state_db_for_direct_thread_lookup(config).await +async fn title_from_state_db( + config: &Config, + state_db_ctx: Option<&StateDbHandle>, + thread_id: ThreadId, +) -> Option { + if let Some(state_db_ctx) = state_db_ctx && let Some(metadata) = state_db_ctx.get_thread(thread_id).await.ok().flatten() && let Some(title) = distinct_title(&metadata) { diff --git a/codex-rs/app-server/src/request_processors/thread_summary.rs b/codex-rs/app-server/src/request_processors/thread_summary.rs index d528604928..ed2360ed13 100644 --- a/codex-rs/app-server/src/request_processors/thread_summary.rs +++ b/codex-rs/app-server/src/request_processors/thread_summary.rs @@ -1,13 +1,5 @@ use super::*; -pub(super) async fn open_state_db_for_direct_thread_lookup( - config: &Config, -) -> Option { - StateRuntime::init(config.sqlite_home.clone(), config.model_provider_id.clone()) - .await - .ok() -} - pub(crate) async fn read_summary_from_rollout( path: &Path, fallback_provider: &str, diff --git a/codex-rs/app-server/src/request_processors/turn_processor.rs b/codex-rs/app-server/src/request_processors/turn_processor.rs index cfeaedbdb0..c11b7166e2 100644 --- a/codex-rs/app-server/src/request_processors/turn_processor.rs +++ b/codex-rs/app-server/src/request_processors/turn_processor.rs @@ -13,6 +13,7 @@ pub(crate) struct TurnRequestProcessor { thread_state_manager: ThreadStateManager, thread_watch_manager: ThreadWatchManager, thread_list_state_permit: Arc, + state_db: Option, } impl TurnRequestProcessor { @@ -29,6 +30,7 @@ impl TurnRequestProcessor { thread_state_manager: ThreadStateManager, thread_watch_manager: ThreadWatchManager, thread_list_state_permit: Arc, + state_db: Option, ) -> Self { Self { auth_manager, @@ -42,6 +44,7 @@ impl TurnRequestProcessor { thread_state_manager, thread_watch_manager, thread_list_state_permit, + state_db, } } @@ -891,16 +894,20 @@ impl TurnRequestProcessor { let rollout_path = if let Some(path) = parent_thread.rollout_path() { path } else { - find_thread_path_by_id_str(&self.config.codex_home, &parent_thread_id.to_string()) - .await - .map_err(|err| { - internal_error(format!( - "failed to locate thread id {parent_thread_id}: {err}" - )) - })? - .ok_or_else(|| { - invalid_request(format!("no rollout found for thread id {parent_thread_id}")) - })? + find_thread_path_by_id_str( + &self.config.codex_home, + &parent_thread_id.to_string(), + self.state_db.as_deref(), + ) + .await + .map_err(|err| { + internal_error(format!( + "failed to locate thread id {parent_thread_id}: {err}" + )) + })? + .ok_or_else(|| { + invalid_request(format!("no rollout found for thread id {parent_thread_id}")) + })? }; let mut config = self.config.as_ref().clone(); diff --git a/codex-rs/app-server/tests/suite/v2/mcp_resource.rs b/codex-rs/app-server/tests/suite/v2/mcp_resource.rs index 3b1a495576..a51f4bbd4e 100644 --- a/codex-rs/app-server/tests/suite/v2/mcp_resource.rs +++ b/codex-rs/app-server/tests/suite/v2/mcp_resource.rs @@ -204,6 +204,7 @@ async fn mcp_resource_read_returns_error_for_unknown_thread() -> Result<()> { thread_config_loader: Arc::new(codex_config::NoopThreadConfigLoader), feedback: CodexFeedback::new(), log_db: None, + state_db: None, environment_manager: Arc::new(EnvironmentManager::default_for_tests()), config_warnings: Vec::new(), session_source: SessionSource::Cli, diff --git a/codex-rs/app-server/tests/suite/v2/remote_thread_store.rs b/codex-rs/app-server/tests/suite/v2/remote_thread_store.rs index a76caefebb..b04eb12a45 100644 --- a/codex-rs/app-server/tests/suite/v2/remote_thread_store.rs +++ b/codex-rs/app-server/tests/suite/v2/remote_thread_store.rs @@ -80,6 +80,7 @@ async fn thread_start_with_non_local_thread_store_does_not_create_local_persiste thread_config_loader: Arc::new(NoopThreadConfigLoader), feedback: CodexFeedback::new(), log_db: None, + state_db: None, environment_manager: Arc::new(EnvironmentManager::default_for_tests()), config_warnings: Vec::new(), session_source: SessionSource::Cli, diff --git a/codex-rs/app-server/tests/suite/v2/thread_archive.rs b/codex-rs/app-server/tests/suite/v2/thread_archive.rs index 7d884c9a7d..b441a23cb6 100644 --- a/codex-rs/app-server/tests/suite/v2/thread_archive.rs +++ b/codex-rs/app-server/tests/suite/v2/thread_archive.rs @@ -63,7 +63,7 @@ async fn thread_archive_requires_materialized_rollout() -> Result<()> { rollout_path.display() ); assert!( - find_thread_path_by_id_str(codex_home.path(), &thread.id) + find_thread_path_by_id_str(codex_home.path(), &thread.id, /*state_db_ctx*/ None) .await? .is_none(), "thread id should not be discoverable before rollout materialization" @@ -118,9 +118,10 @@ async fn thread_archive_requires_materialized_rollout() -> Result<()> { rollout_path.display() ); - let discovered_path = find_thread_path_by_id_str(codex_home.path(), &thread.id) - .await? - .expect("expected rollout path for thread id to exist after materialization"); + let discovered_path = + find_thread_path_by_id_str(codex_home.path(), &thread.id, /*state_db_ctx*/ None) + .await? + .expect("expected rollout path for thread id to exist after materialization"); assert_paths_match_on_disk(&discovered_path, &rollout_path)?; let archive_id = mcp @@ -252,15 +253,23 @@ async fn thread_archive_archives_spawned_descendants() -> Result<()> { for thread_id in [parent_thread_id, child_thread_id, grandchild_thread_id] { assert!( - find_thread_path_by_id_str(codex_home.path(), &thread_id.to_string()) - .await? - .is_none(), + find_thread_path_by_id_str( + codex_home.path(), + &thread_id.to_string(), + /*state_db_ctx*/ None, + ) + .await? + .is_none(), "expected active rollout for {thread_id} to be archived" ); assert!( - find_archived_thread_path_by_id_str(codex_home.path(), &thread_id.to_string()) - .await? - .is_some(), + find_archived_thread_path_by_id_str( + codex_home.path(), + &thread_id.to_string(), + /*state_db_ctx*/ None, + ) + .await? + .is_some(), "expected archived rollout for {thread_id} to exist" ); } @@ -322,9 +331,10 @@ async fn thread_archive_succeeds_when_descendant_archive_fails() -> Result<()> { ) .await?; - let child_rollout_path = find_thread_path_by_id_str(codex_home.path(), &child_id) - .await? - .expect("child rollout path"); + let child_rollout_path = + find_thread_path_by_id_str(codex_home.path(), &child_id, /*state_db_ctx*/ None) + .await? + .expect("child rollout path"); let archived_child_path = codex_home .path() .join(ARCHIVED_SESSIONS_SUBDIR) @@ -381,15 +391,23 @@ async fn thread_archive_succeeds_when_descendant_archive_fails() -> Result<()> { ); for thread_id in [parent_thread_id, grandchild_thread_id] { assert!( - find_thread_path_by_id_str(codex_home.path(), &thread_id.to_string()) - .await? - .is_none(), + find_thread_path_by_id_str( + codex_home.path(), + &thread_id.to_string(), + /*state_db_ctx*/ None, + ) + .await? + .is_none(), "expected active rollout for {thread_id} to be archived" ); assert!( - find_archived_thread_path_by_id_str(codex_home.path(), &thread_id.to_string()) - .await? - .is_some(), + find_archived_thread_path_by_id_str( + codex_home.path(), + &thread_id.to_string(), + /*state_db_ctx*/ None, + ) + .await? + .is_some(), "expected archived rollout for {thread_id} to exist" ); } @@ -455,15 +473,19 @@ async fn thread_archive_succeeds_when_spawned_descendant_is_missing() -> Result< assert_eq!(archived_notification.thread_id, parent_id); assert!( - find_thread_path_by_id_str(codex_home.path(), &parent_id) + find_thread_path_by_id_str(codex_home.path(), &parent_id, /*state_db_ctx*/ None) .await? .is_none(), "parent should be archived even when a descendant is missing" ); assert!( - find_archived_thread_path_by_id_str(codex_home.path(), &parent_id) - .await? - .is_some(), + find_archived_thread_path_by_id_str( + codex_home.path(), + &parent_id, + /*state_db_ctx*/ None, + ) + .await? + .is_some(), "parent should be moved into archived sessions" ); diff --git a/codex-rs/app-server/tests/suite/v2/thread_list.rs b/codex-rs/app-server/tests/suite/v2/thread_list.rs index 615692d70d..ebaba81852 100644 --- a/codex-rs/app-server/tests/suite/v2/thread_list.rs +++ b/codex-rs/app-server/tests/suite/v2/thread_list.rs @@ -614,6 +614,7 @@ sqlite = true generate_memories: false, }; let repaired_page = codex_core::RolloutRecorder::list_threads( + Some(state_db.clone()), &rollout_config, /*page_size*/ 10, /*cursor*/ None, diff --git a/codex-rs/app-server/tests/suite/v2/thread_name_websocket.rs b/codex-rs/app-server/tests/suite/v2/thread_name_websocket.rs index b41e2f1d18..6626b7a6cc 100644 --- a/codex-rs/app-server/tests/suite/v2/thread_name_websocket.rs +++ b/codex-rs/app-server/tests/suite/v2/thread_name_websocket.rs @@ -211,9 +211,10 @@ async fn thread_name_update_rollout_count( codex_home: &Path, conversation_id: &str, ) -> Result { - let rollout_path = find_thread_path_by_id_str(codex_home, conversation_id) - .await? - .context("rollout path")?; + let rollout_path = + find_thread_path_by_id_str(codex_home, conversation_id, /*state_db_ctx*/ None) + .await? + .context("rollout path")?; let contents = tokio::fs::read_to_string(rollout_path).await?; Ok(contents .lines() diff --git a/codex-rs/app-server/tests/suite/v2/thread_read.rs b/codex-rs/app-server/tests/suite/v2/thread_read.rs index 589c7c330a..8c46a5ad95 100644 --- a/codex-rs/app-server/tests/suite/v2/thread_read.rs +++ b/codex-rs/app-server/tests/suite/v2/thread_read.rs @@ -300,6 +300,7 @@ async fn thread_turns_list_reads_store_history_without_rollout_path() -> Result< thread_config_loader: Arc::new(codex_config::NoopThreadConfigLoader), feedback: CodexFeedback::new(), log_db: None, + state_db: None, environment_manager: Arc::new(EnvironmentManager::default_for_tests()), config_warnings: Vec::new(), session_source: SessionSource::Cli.into(), @@ -363,6 +364,7 @@ async fn thread_read_loaded_include_turns_reads_store_history_without_rollout_pa thread_config_loader: Arc::new(codex_config::NoopThreadConfigLoader), feedback: CodexFeedback::new(), log_db: None, + state_db: None, environment_manager: Arc::new(EnvironmentManager::default_for_tests()), config_warnings: Vec::new(), session_source: SessionSource::Cli.into(), @@ -447,6 +449,7 @@ async fn thread_list_includes_store_thread_without_rollout_path() -> Result<()> thread_config_loader: Arc::new(codex_config::NoopThreadConfigLoader), feedback: CodexFeedback::new(), log_db: None, + state_db: None, environment_manager: Arc::new(EnvironmentManager::default_for_tests()), config_warnings: Vec::new(), session_source: SessionSource::Cli.into(), diff --git a/codex-rs/app-server/tests/suite/v2/thread_unarchive.rs b/codex-rs/app-server/tests/suite/v2/thread_unarchive.rs index b2ae60ae35..588764edb8 100644 --- a/codex-rs/app-server/tests/suite/v2/thread_unarchive.rs +++ b/codex-rs/app-server/tests/suite/v2/thread_unarchive.rs @@ -75,9 +75,10 @@ async fn thread_unarchive_moves_rollout_back_into_sessions_directory() -> Result ) .await??; - let found_rollout_path = find_thread_path_by_id_str(codex_home.path(), &thread.id) - .await? - .expect("expected rollout path for thread id to exist"); + let found_rollout_path = + find_thread_path_by_id_str(codex_home.path(), &thread.id, /*state_db_ctx*/ None) + .await? + .expect("expected rollout path for thread id to exist"); assert_paths_match_on_disk(&found_rollout_path, &rollout_path)?; let archive_id = mcp @@ -92,9 +93,13 @@ async fn thread_unarchive_moves_rollout_back_into_sessions_directory() -> Result .await??; let _: ThreadArchiveResponse = to_response::(archive_resp)?; - let archived_path = find_archived_thread_path_by_id_str(codex_home.path(), &thread.id) - .await? - .expect("expected archived rollout path for thread id to exist"); + let archived_path = find_archived_thread_path_by_id_str( + codex_home.path(), + &thread.id, + /*state_db_ctx*/ None, + ) + .await? + .expect("expected archived rollout path for thread id to exist"); let archived_path_display = archived_path.display(); assert!( archived_path.exists(), diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index 18b811bc1d..c56055ea2d 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -1388,7 +1388,7 @@ async fn run_debug_prompt_input_command( }); } - let prompt_input = codex_core::build_prompt_input(config, input).await?; + let prompt_input = codex_core::build_prompt_input(config, input, /*state_db*/ None).await?; println!("{}", serde_json::to_string_pretty(&prompt_input)?); Ok(()) diff --git a/codex-rs/core-api/src/lib.rs b/codex-rs/core-api/src/lib.rs index dca169ed2b..aa68656d1b 100644 --- a/codex-rs/core-api/src/lib.rs +++ b/codex-rs/core-api/src/lib.rs @@ -27,6 +27,7 @@ pub use codex_core::ForkSnapshot; pub use codex_core::McpManager; pub use codex_core::NewThread; pub use codex_core::StartThreadOptions; +pub use codex_core::StateDbHandle; pub use codex_core::ThreadManager; pub use codex_core::ThreadShutdownReport; pub use codex_core::config::Config; @@ -37,6 +38,7 @@ pub use codex_core::config::Permissions; pub use codex_core::config::TerminalResizeReflowConfig; pub use codex_core::config::ThreadStoreConfig; pub use codex_core::config::find_codex_home; +pub use codex_core::init_state_db; pub use codex_core::skills::SkillsManager; pub use codex_core::thread_store_from_config; pub use codex_exec_server::EnvironmentManager; diff --git a/codex-rs/core/src/agent/control.rs b/codex-rs/core/src/agent/control.rs index 2ade6ab25c..1a13c0b00c 100644 --- a/codex-rs/core/src/agent/control.rs +++ b/codex-rs/core/src/agent/control.rs @@ -29,7 +29,6 @@ use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; use codex_protocol::protocol::TurnEnvironmentSelection; use codex_protocol::user_input::UserInput; -use codex_rollout::state_db; use codex_state::DirectionalThreadSpawnEdgeStatus; use codex_thread_store::ReadThreadParams; use serde::Serialize; @@ -526,6 +525,7 @@ impl AgentControl { let _ = config.features.disable(Feature::Collab); } let state = self.upgrade()?; + let state_db_ctx = state.state_db(); let mut reservation = self.state.reserve_spawn_slot(config.agent_max_threads)?; let (session_source, agent_metadata) = match session_source { SessionSource::SubAgent(SubAgentSource::ThreadSpawn { @@ -536,7 +536,7 @@ impl AgentControl { agent_nickname: _, }) => { let (resumed_agent_nickname, resumed_agent_role) = - if let Some(state_db_ctx) = state_db::get_state_db(&config).await { + if let Some(state_db_ctx) = state_db_ctx.as_ref() { match state_db_ctx.get_thread(thread_id).await { Ok(Some(metadata)) => (metadata.agent_nickname, metadata.agent_role), Ok(None) | Err(_) => (None, None), diff --git a/codex-rs/core/src/agent/control_tests.rs b/codex-rs/core/src/agent/control_tests.rs index 7ef2120d5c..b95aad4489 100644 --- a/codex-rs/core/src/agent/control_tests.rs +++ b/codex-rs/core/src/agent/control_tests.rs @@ -1,5 +1,6 @@ use super::*; use crate::CodexThread; +use crate::StateDbHandle; use crate::ThreadManager; use crate::agent::agent_status_from_event; use crate::config::AgentRoleConfig; @@ -7,6 +8,7 @@ use crate::config::Config; use crate::config::ConfigBuilder; use crate::context::ContextualUserFragment; use crate::context::SubagentNotification; +use crate::init_state_db; use assert_matches::assert_matches; use codex_features::Feature; use codex_login::CodexAuth; @@ -84,6 +86,7 @@ fn spawn_agent_call(call_id: &str) -> ResponseItem { struct AgentControlHarness { _home: TempDir, config: Config, + state_db: Option, manager: ThreadManager, control: AgentControl, } @@ -91,16 +94,19 @@ struct AgentControlHarness { impl AgentControlHarness { async fn new() -> Self { let (home, config) = test_config().await; - let manager = ThreadManager::with_models_provider_and_home_for_tests( + let state_db = init_state_db(&config).await; + let manager = ThreadManager::with_models_provider_home_and_state_for_tests( CodexAuth::from_api_key("dummy"), config.model_provider.clone(), config.codex_home.to_path_buf(), std::sync::Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), + state_db.clone(), ); let control = manager.agent_control(); Self { _home: home, config, + state_db, manager, control, } @@ -1537,16 +1543,19 @@ async fn resume_thread_subagent_restores_stored_nickname_and_role() { .features .enable(Feature::Sqlite) .expect("test config should allow sqlite"); - let manager = ThreadManager::with_models_provider_and_home_for_tests( + let state_db = init_state_db(&config).await; + let manager = ThreadManager::with_models_provider_home_and_state_for_tests( CodexAuth::from_api_key("dummy"), config.model_provider.clone(), config.codex_home.to_path_buf(), std::sync::Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), + state_db.clone(), ); let control = manager.agent_control(); let harness = AgentControlHarness { _home: home, config, + state_db, manager, control, }; @@ -1695,7 +1704,10 @@ async fn resume_agent_from_rollout_reads_archived_rollout_path() { .shutdown_live_agent(child_thread_id) .await .expect("child shutdown should succeed"); - let store = LocalThreadStore::new(LocalThreadStoreConfig::from_config(&harness.config)); + let store = LocalThreadStore::new( + LocalThreadStoreConfig::from_config(&harness.config), + harness.state_db.clone(), + ); store .archive_thread(ArchiveThreadParams { thread_id: child_thread_id, diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 721feee834..cf909c5e81 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -145,7 +145,7 @@ pub(crate) mod shell_snapshot; pub mod spawn; pub(crate) mod state_db_bridge; pub use state_db_bridge::StateDbHandle; -pub use state_db_bridge::get_state_db; +pub use state_db_bridge::init_state_db; mod thread_rollout_truncation; mod tools; pub(crate) mod turn_diff_tracker; diff --git a/codex-rs/core/src/personality_migration.rs b/codex-rs/core/src/personality_migration.rs index 5227cf07e3..975aecd4af 100644 --- a/codex-rs/core/src/personality_migration.rs +++ b/codex-rs/core/src/personality_migration.rs @@ -1,6 +1,7 @@ use crate::config::edit::ConfigEditsBuilder; use codex_config::config_toml::ConfigToml; use codex_protocol::config_types::Personality; +use codex_rollout::state_db::StateDbHandle; use codex_thread_store::ListThreadsParams; use codex_thread_store::LocalThreadStore; use codex_thread_store::LocalThreadStoreConfig; @@ -24,6 +25,7 @@ pub enum PersonalityMigrationStatus { pub async fn maybe_migrate_personality( codex_home: &Path, config_toml: &ConfigToml, + state_db: Option, ) -> io::Result { let marker_path = codex_home.join(PERSONALITY_MIGRATION_FILENAME); if tokio::fs::try_exists(&marker_path).await? { @@ -43,7 +45,7 @@ pub async fn maybe_migrate_personality( .or_else(|| config_toml.model_provider.clone()) .unwrap_or_else(|| "openai".to_string()); - if !has_recorded_sessions(codex_home, model_provider_id.as_str()).await? { + if !has_recorded_sessions(codex_home, model_provider_id.as_str(), state_db).await? { create_marker(&marker_path).await?; return Ok(PersonalityMigrationStatus::SkippedNoSessions); } @@ -60,12 +62,19 @@ pub async fn maybe_migrate_personality( Ok(PersonalityMigrationStatus::Applied) } -async fn has_recorded_sessions(codex_home: &Path, default_provider: &str) -> io::Result { - let store = LocalThreadStore::new(LocalThreadStoreConfig { - codex_home: codex_home.to_path_buf(), - sqlite_home: codex_home.to_path_buf(), - default_model_provider_id: default_provider.to_string(), - }); +async fn has_recorded_sessions( + codex_home: &Path, + default_provider: &str, + state_db: Option, +) -> io::Result { + let store = LocalThreadStore::new( + LocalThreadStoreConfig { + codex_home: codex_home.to_path_buf(), + sqlite_home: codex_home.to_path_buf(), + default_model_provider_id: default_provider.to_string(), + }, + state_db, + ); if has_threads(&store, /*archived*/ false).await? { return Ok(true); } diff --git a/codex-rs/core/src/personality_migration_tests.rs b/codex-rs/core/src/personality_migration_tests.rs index 4aef53a5c4..28d34bc3c9 100644 --- a/codex-rs/core/src/personality_migration_tests.rs +++ b/codex-rs/core/src/personality_migration_tests.rs @@ -87,7 +87,7 @@ async fn applies_when_sessions_exist_and_no_personality() -> io::Result<()> { write_session_with_user_event(temp.path()).await?; let config_toml = ConfigToml::default(); - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + let status = maybe_migrate_personality(temp.path(), &config_toml, /*state_db*/ None).await?; assert_eq!(status, PersonalityMigrationStatus::Applied); assert!(temp.path().join(PERSONALITY_MIGRATION_FILENAME).exists()); @@ -103,7 +103,7 @@ async fn applies_when_only_archived_sessions_exist_and_no_personality() -> io::R write_archived_session_with_user_event(temp.path()).await?; let config_toml = ConfigToml::default(); - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + let status = maybe_migrate_personality(temp.path(), &config_toml, /*state_db*/ None).await?; assert_eq!(status, PersonalityMigrationStatus::Applied); assert!(temp.path().join(PERSONALITY_MIGRATION_FILENAME).exists()); @@ -119,7 +119,7 @@ async fn skips_when_marker_exists() -> io::Result<()> { create_marker(&temp.path().join(PERSONALITY_MIGRATION_FILENAME)).await?; let config_toml = ConfigToml::default(); - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + let status = maybe_migrate_personality(temp.path(), &config_toml, /*state_db*/ None).await?; assert_eq!(status, PersonalityMigrationStatus::SkippedMarker); assert!(!temp.path().join("config.toml").exists()); @@ -136,7 +136,7 @@ async fn skips_when_personality_explicit() -> io::Result<()> { .map_err(|err| io::Error::other(format!("failed to write config: {err}")))?; let config_toml = read_config_toml(temp.path()).await?; - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + let status = maybe_migrate_personality(temp.path(), &config_toml, /*state_db*/ None).await?; assert_eq!( status, @@ -153,7 +153,7 @@ async fn skips_when_personality_explicit() -> io::Result<()> { async fn skips_when_no_sessions() -> io::Result<()> { let temp = TempDir::new()?; let config_toml = ConfigToml::default(); - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + let status = maybe_migrate_personality(temp.path(), &config_toml, /*state_db*/ None).await?; assert_eq!(status, PersonalityMigrationStatus::SkippedNoSessions); assert!(temp.path().join(PERSONALITY_MIGRATION_FILENAME).exists()); diff --git a/codex-rs/core/src/prompt_debug.rs b/codex-rs/core/src/prompt_debug.rs index d4f3130129..7c6144c10e 100644 --- a/codex-rs/core/src/prompt_debug.rs +++ b/codex-rs/core/src/prompt_debug.rs @@ -16,6 +16,7 @@ use crate::config::Config; use crate::session::session::Session; use crate::session::turn::build_prompt; use crate::session::turn::built_tools; +use crate::state_db_bridge::StateDbHandle; use crate::thread_manager::ThreadManager; use crate::thread_manager::thread_store_from_config; @@ -24,6 +25,7 @@ use crate::thread_manager::thread_store_from_config; pub async fn build_prompt_input( mut config: Config, input: Vec, + state_db: Option, ) -> CodexResult> { config.ephemeral = true; @@ -35,13 +37,15 @@ pub async fn build_prompt_input( config.codex_linux_sandbox_exe.clone(), )?; + let thread_store = thread_store_from_config(&config, state_db.clone()); let thread_manager = ThreadManager::new( &config, Arc::clone(&auth_manager), SessionSource::Exec, Arc::new(EnvironmentManager::new(EnvironmentManagerArgs::new(local_runtime_paths)).await), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store, + state_db.clone(), ); let thread = thread_manager.start_thread(config).await?; diff --git a/codex-rs/core/src/session/mod.rs b/codex-rs/core/src/session/mod.rs index cb8c21c4da..b697b8e623 100644 --- a/codex-rs/core/src/session/mod.rs +++ b/codex-rs/core/src/session/mod.rs @@ -556,7 +556,15 @@ impl Codex { }; match thread_id { Some(thread_id) => { - let state_db_ctx = state_db::get_state_db(&config).await; + let state_db_ctx = if config.ephemeral { + None + } else if let Some(local_store) = + thread_store.as_any().downcast_ref::() + { + local_store.state_db().await + } else { + None + }; state_db::get_dynamic_tools(state_db_ctx.as_deref(), thread_id, "codex_spawn") .await } @@ -1298,6 +1306,7 @@ impl Session { self.services.user_shell.as_ref().clone(), self.services.shell_snapshot_tx.clone(), self.services.session_telemetry.clone(), + self.services.state_db.clone(), ); } diff --git a/codex-rs/core/src/session/session.rs b/codex-rs/core/src/session/session.rs index ac6f766756..2444ded83f 100644 --- a/codex-rs/core/src/session/session.rs +++ b/codex-rs/core/src/session/session.rs @@ -726,6 +726,7 @@ impl Session { session_configuration.cwd.clone(), &mut default_shell, session_telemetry.clone(), + state_db_ctx.clone(), ) } } else { diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index b8554b25e9..fe38cd7f64 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -3492,6 +3492,7 @@ async fn session_new_fails_when_zsh_fork_enabled_without_zsh_path() { /*analytics_events_client*/ None, Arc::new(codex_thread_store::LocalThreadStore::new( codex_thread_store::LocalThreadStoreConfig::from_config(config.as_ref()), + /*state_db*/ None, )), codex_rollout_trace::ThreadTraceContext::disabled(), ) @@ -3639,6 +3640,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { live_thread: None, thread_store: Arc::new(codex_thread_store::LocalThreadStore::new( codex_thread_store::LocalThreadStoreConfig::from_config(config.as_ref()), + /*state_db*/ None, )), model_client: ModelClient::new( Some(auth_manager.clone()), @@ -3810,6 +3812,7 @@ async fn make_session_with_config_and_rx( /*analytics_events_client*/ None, Arc::new(codex_thread_store::LocalThreadStore::new( codex_thread_store::LocalThreadStoreConfig::from_config(config.as_ref()), + /*state_db*/ None, )), codex_rollout_trace::ThreadTraceContext::disabled(), ) @@ -4998,10 +5001,44 @@ async fn make_session_and_context_with_auth_and_config_and_rx( where F: FnOnce(&mut Config), { - let (tx_event, rx_event) = async_channel::unbounded(); let codex_home = tempfile::tempdir().expect("create temp dir"); - let mut config = build_test_config(codex_home.path()).await; + make_session_and_context_with_auth_config_home_and_rx( + auth, + dynamic_tools, + codex_home.path(), + configure_config, + ) + .await +} + +async fn make_session_and_context_with_auth_config_home_and_rx( + auth: CodexAuth, + dynamic_tools: Vec, + codex_home: &Path, + configure_config: F, +) -> ( + Arc, + Arc, + async_channel::Receiver, +) +where + F: FnOnce(&mut Config), +{ + let (tx_event, rx_event) = async_channel::unbounded(); + let mut config = build_test_config(codex_home).await; configure_config(&mut config); + let state_db = if config.features.enabled(Feature::Goals) { + Some( + codex_state::StateRuntime::init( + config.sqlite_home.clone(), + config.model_provider_id.clone(), + ) + .await + .expect("goal tests should initialize sqlite state db"), + ) + } else { + None + }; let config = Arc::new(config); let conversation_id = ThreadId::default(); let auth_manager = AuthManager::from_auth_for_testing(auth); @@ -5127,10 +5164,11 @@ where agent_control, network_proxy: None, network_approval: Arc::clone(&network_approval), - state_db: None, + state_db: state_db.clone(), live_thread: None, thread_store: Arc::new(codex_thread_store::LocalThreadStore::new( codex_thread_store::LocalThreadStoreConfig::from_config(config.as_ref()), + state_db, )), model_client: ModelClient::new( Some(Arc::clone(&auth_manager)), @@ -5225,10 +5263,13 @@ async fn make_goal_session_and_context_with_rx() -> ( Arc, Arc, async_channel::Receiver, + tempfile::TempDir, ) { - let (session, turn_context, rx) = make_session_and_context_with_auth_and_config_and_rx( + let codex_home = tempfile::tempdir().expect("create temp dir"); + let (session, turn_context, rx) = make_session_and_context_with_auth_config_home_and_rx( CodexAuth::from_api_key("Test API Key"), Vec::new(), + codex_home.path(), |config| { config .features @@ -5238,14 +5279,14 @@ async fn make_goal_session_and_context_with_rx() -> ( ) .await; upsert_goal_test_thread(session.as_ref()).await; - (session, turn_context, rx) + (session, turn_context, rx, codex_home) } async fn upsert_goal_test_thread(session: &Session) { let config = session.get_config().await; - let state_db = goal_test_state_db(session) - .await - .expect("goal test state db should initialize"); + let state_db = session + .state_db() + .expect("goal test session should have a state db"); let mut builder = codex_state::ThreadMetadataBuilder::new( session.conversation_id, config @@ -7062,7 +7103,7 @@ async fn abort_empty_active_turn_preserves_pending_input() { #[tokio::test] async fn interrupt_accounts_active_goal_before_pausing() -> anyhow::Result<()> { - let (sess, tc, _rx) = make_goal_session_and_context_with_rx().await; + let (sess, tc, _rx, _codex_home) = make_goal_session_and_context_with_rx().await; sess.set_thread_goal( tc.as_ref(), SetGoalRequest { @@ -7316,6 +7357,9 @@ fn post_goal_token_usage() -> TokenUsage { } async fn goal_test_state_db(sess: &Session) -> anyhow::Result { + if let Some(state_db) = sess.state_db() { + return Ok(state_db); + } let config = sess.get_config().await; codex_state::StateRuntime::init(config.sqlite_home.clone(), config.model_provider_id.clone()) .await @@ -7323,7 +7367,7 @@ async fn goal_test_state_db(sess: &Session) -> anyhow::Result anyhow::Result<()> { - let (sess, tc, rx) = make_goal_session_and_context_with_rx().await; + let (sess, tc, rx, _codex_home) = make_goal_session_and_context_with_rx().await; sess.set_thread_goal( tc.as_ref(), SetGoalRequest { @@ -7423,7 +7467,7 @@ async fn budget_limited_accounting_steers_active_turn_without_aborting() -> anyh #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn external_goal_mutation_accounts_active_turn_before_status_change() -> anyhow::Result<()> { - let (sess, tc, _rx) = make_goal_session_and_context_with_rx().await; + let (sess, tc, _rx, _codex_home) = make_goal_session_and_context_with_rx().await; sess.set_thread_goal( tc.as_ref(), SetGoalRequest { @@ -7485,7 +7529,7 @@ async fn external_goal_mutation_accounts_active_turn_before_status_change() -> a #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn external_active_goal_set_marks_current_turn_for_accounting() -> anyhow::Result<()> { - let (sess, tc, _rx) = make_goal_session_and_context_with_rx().await; + let (sess, tc, _rx, _codex_home) = make_goal_session_and_context_with_rx().await; sess.spawn_task( Arc::clone(&tc), Vec::new(), @@ -8149,7 +8193,7 @@ async fn sample_rollout( #[tokio::test] async fn create_goal_tool_rejects_existing_goal() { - let (session, turn_context, _rx) = make_goal_session_and_context_with_rx().await; + let (session, turn_context, _rx, _codex_home) = make_goal_session_and_context_with_rx().await; let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); let handler = GoalHandler; @@ -8211,7 +8255,7 @@ async fn create_goal_tool_rejects_existing_goal() { #[tokio::test] async fn update_goal_tool_rejects_pausing_goal() { - let (session, turn_context, _rx) = make_goal_session_and_context_with_rx().await; + let (session, turn_context, _rx, _codex_home) = make_goal_session_and_context_with_rx().await; let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); let handler = GoalHandler; @@ -8271,7 +8315,7 @@ async fn update_goal_tool_rejects_pausing_goal() { #[tokio::test] async fn update_goal_tool_marks_goal_complete() { - let (session, turn_context, _rx) = make_goal_session_and_context_with_rx().await; + let (session, turn_context, _rx, _codex_home) = make_goal_session_and_context_with_rx().await; let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); let handler = GoalHandler; diff --git a/codex-rs/core/src/session/tests/guardian_tests.rs b/codex-rs/core/src/session/tests/guardian_tests.rs index d6a87d466a..ad7dbb1054 100644 --- a/codex-rs/core/src/session/tests/guardian_tests.rs +++ b/codex-rs/core/src/session/tests/guardian_tests.rs @@ -731,6 +731,7 @@ async fn guardian_subagent_does_not_inherit_parent_exec_policy_rules() { let skills_watcher = Arc::new(SkillsWatcher::noop()); let thread_store = Arc::new(codex_thread_store::LocalThreadStore::new( codex_thread_store::LocalThreadStoreConfig::from_config(&config), + /*state_db*/ None, )); let CodexSpawnOk { codex, .. } = Codex::spawn(CodexSpawnArgs { diff --git a/codex-rs/core/src/shell_snapshot.rs b/codex-rs/core/src/shell_snapshot.rs index 40cb4a9605..b328a977d7 100644 --- a/codex-rs/core/src/shell_snapshot.rs +++ b/codex-rs/core/src/shell_snapshot.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use std::time::Duration; use std::time::SystemTime; +use crate::StateDbHandle; use crate::rollout::list::find_thread_path_by_id_str; use crate::shell::Shell; use crate::shell::ShellType; @@ -41,6 +42,7 @@ impl ShellSnapshot { session_cwd: AbsolutePathBuf, shell: &mut Shell, session_telemetry: SessionTelemetry, + state_db: Option, ) -> watch::Sender>> { let (shell_snapshot_tx, shell_snapshot_rx) = watch::channel(None); shell.shell_snapshot = shell_snapshot_rx; @@ -52,6 +54,7 @@ impl ShellSnapshot { shell.clone(), shell_snapshot_tx.clone(), session_telemetry, + state_db, ); shell_snapshot_tx @@ -64,6 +67,7 @@ impl ShellSnapshot { shell: Shell, shell_snapshot_tx: watch::Sender>>, session_telemetry: SessionTelemetry, + state_db: Option, ) { Self::spawn_snapshot_task( codex_home, @@ -72,6 +76,7 @@ impl ShellSnapshot { shell, shell_snapshot_tx, session_telemetry, + state_db, ); } @@ -82,15 +87,21 @@ impl ShellSnapshot { snapshot_shell: Shell, shell_snapshot_tx: watch::Sender>>, session_telemetry: SessionTelemetry, + state_db: Option, ) { let snapshot_span = info_span!("shell_snapshot", thread_id = %session_id); tokio::spawn( async move { let timer = session_telemetry.start_timer("codex.shell_snapshot.duration_ms", &[]); - let snapshot = - ShellSnapshot::try_new(&codex_home, session_id, &session_cwd, &snapshot_shell) - .await - .map(Arc::new); + let snapshot = ShellSnapshot::try_new( + &codex_home, + session_id, + &session_cwd, + &snapshot_shell, + state_db, + ) + .await + .map(Arc::new); let success = snapshot.is_ok(); let success_tag = if success { "true" } else { "false" }; let _ = timer.map(|timer| timer.record(&[("success", success_tag)])); @@ -110,6 +121,7 @@ impl ShellSnapshot { session_id: ThreadId, session_cwd: &AbsolutePathBuf, shell: &Shell, + state_db: Option, ) -> std::result::Result { // File to store the snapshot let extension = match shell.shell_type { @@ -131,7 +143,9 @@ impl ShellSnapshot { let codex_home = codex_home.clone(); let cleanup_session_id = session_id; tokio::spawn(async move { - if let Err(err) = cleanup_stale_snapshots(&codex_home, cleanup_session_id).await { + if let Err(err) = + cleanup_stale_snapshots(&codex_home, cleanup_session_id, state_db).await + { tracing::warn!("Failed to clean up shell snapshots: {err:?}"); } }); @@ -486,6 +500,7 @@ $envVars | ForEach-Object { pub async fn cleanup_stale_snapshots( codex_home: &AbsolutePathBuf, active_session_id: ThreadId, + state_db: Option, ) -> Result<()> { let snapshot_dir = codex_home.join(SNAPSHOT_DIR); @@ -515,7 +530,8 @@ pub async fn cleanup_stale_snapshots( continue; } - let rollout_path = find_thread_path_by_id_str(codex_home, session_id).await?; + let rollout_path = + find_thread_path_by_id_str(codex_home, session_id, state_db.as_deref()).await?; let Some(rollout_path) = rollout_path else { remove_snapshot_file(&path).await; continue; diff --git a/codex-rs/core/src/shell_snapshot_tests.rs b/codex-rs/core/src/shell_snapshot_tests.rs index 0f1aea2021..0199347b4e 100644 --- a/codex-rs/core/src/shell_snapshot_tests.rs +++ b/codex-rs/core/src/shell_snapshot_tests.rs @@ -202,6 +202,7 @@ async fn try_new_creates_and_deletes_snapshot_file() -> Result<()> { ThreadId::new(), &dir.path().abs(), &shell, + /*state_db*/ None, ) .await .expect("snapshot should be created"); @@ -227,14 +228,24 @@ async fn try_new_uses_distinct_generation_paths() -> Result<()> { shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), }; - let initial_snapshot = - ShellSnapshot::try_new(&dir.path().abs(), session_id, &dir.path().abs(), &shell) - .await - .expect("initial snapshot should be created"); - let refreshed_snapshot = - ShellSnapshot::try_new(&dir.path().abs(), session_id, &dir.path().abs(), &shell) - .await - .expect("refreshed snapshot should be created"); + let initial_snapshot = ShellSnapshot::try_new( + &dir.path().abs(), + session_id, + &dir.path().abs(), + &shell, + /*state_db*/ None, + ) + .await + .expect("initial snapshot should be created"); + let refreshed_snapshot = ShellSnapshot::try_new( + &dir.path().abs(), + session_id, + &dir.path().abs(), + &shell, + /*state_db*/ None, + ) + .await + .expect("refreshed snapshot should be created"); let initial_path = initial_snapshot.path.clone(); let refreshed_path = refreshed_snapshot.path.clone(); @@ -428,7 +439,7 @@ async fn cleanup_stale_snapshots_removes_orphans_and_keeps_live() -> Result<()> fs::write(&orphan_snapshot, "orphan").await?; fs::write(&invalid_snapshot, "invalid").await?; - cleanup_stale_snapshots(&codex_home, ThreadId::new()).await?; + cleanup_stale_snapshots(&codex_home, ThreadId::new(), /*state_db*/ None).await?; assert_eq!(live_snapshot.exists(), true); assert_eq!(orphan_snapshot.exists(), false); @@ -451,7 +462,7 @@ async fn cleanup_stale_snapshots_removes_stale_rollouts() -> Result<()> { set_file_mtime(&rollout_path, SNAPSHOT_RETENTION + Duration::from_secs(60))?; - cleanup_stale_snapshots(&codex_home, ThreadId::new()).await?; + cleanup_stale_snapshots(&codex_home, ThreadId::new(), /*state_db*/ None).await?; assert_eq!(stale_snapshot.exists(), false); Ok(()) @@ -472,7 +483,7 @@ async fn cleanup_stale_snapshots_skips_active_session() -> Result<()> { set_file_mtime(&rollout_path, SNAPSHOT_RETENTION + Duration::from_secs(60))?; - cleanup_stale_snapshots(&codex_home, active_session).await?; + cleanup_stale_snapshots(&codex_home, active_session, /*state_db*/ None).await?; assert_eq!(active_snapshot.exists(), true); Ok(()) diff --git a/codex-rs/core/src/state_db_bridge.rs b/codex-rs/core/src/state_db_bridge.rs index c588f039d2..78d3cb11f9 100644 --- a/codex-rs/core/src/state_db_bridge.rs +++ b/codex-rs/core/src/state_db_bridge.rs @@ -3,6 +3,6 @@ pub use codex_rollout::state_db::StateDbHandle; use crate::config::Config; -pub async fn get_state_db(config: &Config) -> Option { - rollout_state_db::get_state_db(config).await +pub async fn init_state_db(config: &Config) -> Option { + rollout_state_db::init(config).await } diff --git a/codex-rs/core/src/stream_events_utils.rs b/codex-rs/core/src/stream_events_utils.rs index 8ae4374e7b..29884da4ae 100644 --- a/codex-rs/core/src/stream_events_utils.rs +++ b/codex-rs/core/src/stream_events_utils.rs @@ -138,8 +138,11 @@ pub(crate) async fn record_completed_response_item( .await; } mark_thread_memory_mode_polluted_if_external_context(sess, turn_context, item).await; - let has_memory_citation = - record_stage1_output_usage_and_detect_memory_citation(turn_context, item).await; + let has_memory_citation = record_stage1_output_usage_and_detect_memory_citation( + sess.services.state_db.as_ref(), + item, + ) + .await; if has_memory_citation { sess.record_memory_citation_for_turn(&turn_context.sub_id) .await; @@ -174,7 +177,7 @@ pub(crate) async fn mark_thread_memory_mode_polluted_if_external_context( } async fn record_stage1_output_usage_and_detect_memory_citation( - turn_context: &TurnContext, + state_db_ctx: Option<&state_db::StateDbHandle>, item: &ResponseItem, ) -> bool { let Some(raw_text) = raw_assistant_output_text_from_item(item) else { @@ -190,7 +193,7 @@ async fn record_stage1_output_usage_and_detect_memory_citation( return true; } - if let Some(db) = state_db::get_state_db(turn_context.config.as_ref()).await { + if let Some(db) = state_db_ctx { let _ = db.record_stage1_output_usage(&thread_ids).await; } true diff --git a/codex-rs/core/src/test_support.rs b/codex-rs/core/src/test_support.rs index 6dbcf7a464..48eec66c58 100644 --- a/codex-rs/core/src/test_support.rs +++ b/codex-rs/core/src/test_support.rs @@ -73,6 +73,22 @@ pub fn thread_manager_with_models_provider_and_home( ) } +pub fn thread_manager_with_models_provider_home_and_state( + auth: CodexAuth, + provider: ModelProviderInfo, + codex_home: PathBuf, + environment_manager: Arc, + state_db: Option, +) -> ThreadManager { + ThreadManager::with_models_provider_home_and_state_for_tests( + auth, + provider, + codex_home, + environment_manager, + state_db, + ) +} + pub async fn start_thread_with_user_shell_override( thread_manager: &ThreadManager, config: Config, diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index 927f45279e..d4e36ab3c8 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -50,6 +50,7 @@ use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::TurnAbortedEvent; use codex_protocol::protocol::TurnEnvironmentSelection; use codex_protocol::protocol::W3cTraceContext; +use codex_rollout::state_db::StateDbHandle; use codex_state::DirectionalThreadSpawnEdgeStatus; use codex_thread_store::InMemoryThreadStore; use codex_thread_store::LocalThreadStore; @@ -248,6 +249,7 @@ pub(crate) struct ThreadManagerState { thread_store: Arc, session_source: SessionSource, analytics_events_client: Option, + state_db: Option, // Captures submitted ops for testing purpose when test mode is enabled. ops_log: Option, } @@ -263,10 +265,14 @@ pub fn build_models_manager( ) } -pub fn thread_store_from_config(config: &Config) -> Arc { +pub fn thread_store_from_config( + config: &Config, + state_db: Option, +) -> Arc { match &config.experimental_thread_store { ThreadStoreConfig::Local => Arc::new(LocalThreadStore::new( LocalThreadStoreConfig::from_config(config), + state_db, )), ThreadStoreConfig::Remote { endpoint } => Arc::new(RemoteThreadStore::new(endpoint)), ThreadStoreConfig::InMemory { id } => InMemoryThreadStore::for_id(id), @@ -281,6 +287,7 @@ impl ThreadManager { environment_manager: Arc, analytics_events_client: Option, thread_store: Arc, + state_db: Option, ) -> Self { let codex_home = config.codex_home.clone(); let restriction_product = session_source.restriction_product(); @@ -310,6 +317,7 @@ impl ThreadManager { auth_manager, session_source, analytics_events_client, + state_db, ops_log: should_use_test_thread_manager_behavior() .then(|| Arc::new(std::sync::Mutex::new(Vec::new()))), }), @@ -347,6 +355,22 @@ impl ThreadManager { provider: ModelProviderInfo, codex_home: PathBuf, environment_manager: Arc, + ) -> Self { + Self::with_models_provider_home_and_state_for_tests( + auth, + provider, + codex_home, + environment_manager, + /*state_db*/ None, + ) + } + + pub(crate) fn with_models_provider_home_and_state_for_tests( + auth: CodexAuth, + provider: ModelProviderInfo, + codex_home: PathBuf, + environment_manager: Arc, + state_db: Option, ) -> Self { set_thread_manager_test_mode_for_tests(/*enabled*/ true); let auth_manager = AuthManager::from_auth_for_testing(auth); @@ -369,12 +393,14 @@ impl ThreadManager { let skills_watcher = build_skills_watcher(Arc::clone(&skills_manager)); // This test constructor has no Config input. Tests that need a non-local // process store should construct ThreadManager::new with an explicit store. - let thread_store: Arc = - Arc::new(LocalThreadStore::new(LocalThreadStoreConfig { + let thread_store: Arc = Arc::new(LocalThreadStore::new( + LocalThreadStoreConfig { codex_home: codex_home.clone(), sqlite_home: codex_home.clone(), default_model_provider_id: OPENAI_PROVIDER_ID.to_string(), - })); + }, + state_db.clone(), + )); Self { state: Arc::new(ThreadManagerState { threads: Arc::new(RwLock::new(HashMap::new())), @@ -390,6 +416,7 @@ impl ThreadManager { auth_manager, session_source: SessionSource::Exec, analytics_events_client: None, + state_db, ops_log: should_use_test_thread_manager_behavior() .then(|| Arc::new(std::sync::Mutex::new(Vec::new()))), }), @@ -837,6 +864,10 @@ impl ThreadManager { } impl ThreadManagerState { + pub(crate) fn state_db(&self) -> Option { + self.state_db.clone() + } + pub(crate) async fn list_thread_ids(&self) -> Vec { self.threads .read() diff --git a/codex-rs/core/src/thread_manager_tests.rs b/codex-rs/core/src/thread_manager_tests.rs index 6c53452184..d50a3ae8f9 100644 --- a/codex-rs/core/src/thread_manager_tests.rs +++ b/codex-rs/core/src/thread_manager_tests.rs @@ -1,5 +1,6 @@ use super::*; use crate::config::test_config; +use crate::init_state_db; use crate::rollout::RolloutRecorder; use crate::session::session::SessionSettingsUpdate; use crate::session::tests::make_session_and_context; @@ -389,7 +390,8 @@ async fn resume_and_fork_do_not_restore_thread_environments_from_rollout() { SessionSource::Exec, Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store_from_config(&config, /*state_db*/ None), + /*state_db*/ None, ); let selected_cwd = AbsolutePathBuf::try_from(config.cwd.as_path().join("selected")).expect("absolute path"); @@ -498,7 +500,8 @@ async fn resume_active_thread_from_rollout_returns_running_thread() { SessionSource::Exec, Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store_from_config(&config, /*state_db*/ None), + /*state_db*/ None, ); let source = manager @@ -551,7 +554,8 @@ async fn resume_stopped_thread_from_rollout_spawns_new_thread() { SessionSource::Exec, Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store_from_config(&config, /*state_db*/ None), + /*state_db*/ None, ); let source = manager @@ -614,7 +618,8 @@ async fn new_uses_active_provider_for_model_refresh() { SessionSource::Exec, Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store_from_config(&config, /*state_db*/ None), + /*state_db*/ None, ); let _ = manager.list_models(RefreshStrategy::Online).await; @@ -819,13 +824,15 @@ async fn interrupted_fork_snapshot_does_not_synthesize_turn_id_for_legacy_histor let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let state_db = init_state_db(&config).await; let manager = ThreadManager::new( &config, auth_manager.clone(), SessionSource::Exec, Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store_from_config(&config, state_db.clone()), + state_db.clone(), ); let source = manager @@ -921,13 +928,15 @@ async fn interrupted_fork_snapshot_preserves_explicit_turn_id() { let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let state_db = init_state_db(&config).await; let manager = ThreadManager::new( &config, auth_manager.clone(), SessionSource::Exec, Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store_from_config(&config, state_db.clone()), + state_db.clone(), ); let source = manager @@ -1012,13 +1021,15 @@ async fn interrupted_fork_snapshot_uses_persisted_mid_turn_history_without_live_ let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let state_db = init_state_db(&config).await; let manager = ThreadManager::new( &config, auth_manager.clone(), SessionSource::Exec, Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store_from_config(&config, state_db.clone()), + state_db.clone(), ); let source = manager @@ -1148,13 +1159,15 @@ async fn resumed_thread_keeps_paused_goal_paused() -> anyhow::Result<()> { let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let state_db = init_state_db(&config).await; let manager = ThreadManager::new( &config, auth_manager.clone(), SessionSource::Exec, Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store_from_config(&config, state_db.clone()), + state_db.clone(), ); let source = manager diff --git a/codex-rs/core/src/tools/handlers/multi_agents_tests.rs b/codex-rs/core/src/tools/handlers/multi_agents_tests.rs index 61dc77eb36..88c1698d8f 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents_tests.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents_tests.rs @@ -3,8 +3,10 @@ use crate::ThreadManager; use crate::config::AgentRoleConfig; use crate::config::DEFAULT_AGENT_MAX_DEPTH; use crate::function_tool::FunctionCallError; +use crate::init_state_db; use crate::session::tests::make_session_and_context; use crate::session_prefix::format_subagent_notification_message; +use crate::thread_manager::thread_store_from_config; use crate::tools::context::ToolOutput; use crate::tools::handlers::multi_agents_v2::CloseAgentHandler as CloseAgentHandlerV2; use crate::tools::handlers::multi_agents_v2::FollowupTaskHandler as FollowupTaskHandlerV2; @@ -3149,13 +3151,22 @@ async fn close_agent_submits_shutdown_and_returns_previous_status() { #[tokio::test] async fn tool_handlers_cascade_close_and_resume_and_keep_explicitly_closed_subtrees_closed() { let (_session, turn) = make_session_and_context().await; - let manager = thread_manager(); let mut config = turn.config.as_ref().clone(); config.agent_max_depth = 3; config .features .enable(Feature::Sqlite) .expect("test config should allow sqlite"); + let state_db = init_state_db(&config).await; + let manager = ThreadManager::new( + &config, + AuthManager::from_auth_for_testing(CodexAuth::from_api_key("dummy")), + SessionSource::Exec, + Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), + /*analytics_events_client*/ None, + thread_store_from_config(&config, state_db.clone()), + state_db.clone(), + ); let parent = manager .start_thread(config.clone()) diff --git a/codex-rs/core/tests/common/test_codex.rs b/codex-rs/core/tests/common/test_codex.rs index 291a0795ce..14f652f58b 100644 --- a/codex-rs/core/tests/common/test_codex.rs +++ b/codex-rs/core/tests/common/test_codex.rs @@ -423,6 +423,7 @@ impl TestCodexBuilder { environment_manager: Arc, ) -> anyhow::Result { let auth = self.auth.clone(); + let state_db = codex_core::init_state_db(&config).await; let thread_manager = if config.model_catalog.is_some() { ThreadManager::new( &config, @@ -430,14 +431,16 @@ impl TestCodexBuilder { SessionSource::Exec, Arc::clone(&environment_manager), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store_from_config(&config, state_db.clone()), + state_db.clone(), ) } else { - codex_core::test_support::thread_manager_with_models_provider_and_home( + codex_core::test_support::thread_manager_with_models_provider_home_and_state( auth.clone(), config.model_provider.clone(), config.codex_home.to_path_buf(), Arc::clone(&environment_manager), + state_db.clone(), ) }; let thread_manager = Arc::new(thread_manager); diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index 6829ea932a..afa586a158 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -1112,7 +1112,8 @@ async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() { SessionSource::Exec, Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store_from_config(&config, /*state_db*/ None), + /*state_db*/ None, ); let NewThread { thread: codex, .. } = thread_manager .start_thread(config.clone()) diff --git a/codex-rs/core/tests/suite/personality_migration.rs b/codex-rs/core/tests/suite/personality_migration.rs index f300745129..68121d2e12 100644 --- a/codex-rs/core/tests/suite/personality_migration.rs +++ b/codex-rs/core/tests/suite/personality_migration.rs @@ -141,7 +141,8 @@ async fn migration_marker_exists_no_sessions_no_change() -> io::Result<()> { let marker_path = temp.path().join(PERSONALITY_MIGRATION_FILENAME); tokio::fs::write(&marker_path, "v1\n").await?; - let status = maybe_migrate_personality(temp.path(), &ConfigToml::default()).await?; + let status = + maybe_migrate_personality(temp.path(), &ConfigToml::default(), /*state_db*/ None).await?; assert_eq!(status, PersonalityMigrationStatus::SkippedMarker); assert_eq!( @@ -155,7 +156,8 @@ async fn migration_marker_exists_no_sessions_no_change() -> io::Result<()> { async fn no_marker_no_sessions_no_change() -> io::Result<()> { let temp = TempDir::new()?; - let status = maybe_migrate_personality(temp.path(), &ConfigToml::default()).await?; + let status = + maybe_migrate_personality(temp.path(), &ConfigToml::default(), /*state_db*/ None).await?; assert_eq!(status, PersonalityMigrationStatus::SkippedNoSessions); assert_eq!( @@ -174,7 +176,8 @@ async fn no_marker_sessions_sets_personality() -> io::Result<()> { let temp = TempDir::new()?; write_session_with_user_event(temp.path()).await?; - let status = maybe_migrate_personality(temp.path(), &ConfigToml::default()).await?; + let status = + maybe_migrate_personality(temp.path(), &ConfigToml::default(), /*state_db*/ None).await?; assert_eq!(status, PersonalityMigrationStatus::Applied); assert_eq!( @@ -194,7 +197,7 @@ async fn no_marker_sessions_preserves_existing_config_fields() -> io::Result<()> tokio::fs::write(temp.path().join("config.toml"), "model = \"gpt-5.4\"\n").await?; let config_toml = read_config_toml(temp.path()).await?; - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + let status = maybe_migrate_personality(temp.path(), &config_toml, /*state_db*/ None).await?; assert_eq!(status, PersonalityMigrationStatus::Applied); let persisted = read_config_toml(temp.path()).await?; @@ -208,7 +211,8 @@ async fn no_marker_meta_only_rollout_is_treated_as_no_sessions() -> io::Result<( let temp = TempDir::new()?; write_session_with_meta_only(temp.path()).await?; - let status = maybe_migrate_personality(temp.path(), &ConfigToml::default()).await?; + let status = + maybe_migrate_personality(temp.path(), &ConfigToml::default(), /*state_db*/ None).await?; assert_eq!(status, PersonalityMigrationStatus::SkippedNoSessions); assert_eq!( @@ -228,7 +232,7 @@ async fn no_marker_explicit_global_personality_skips_migration() -> io::Result<( write_session_with_user_event(temp.path()).await?; let config_toml = parse_config_toml("personality = \"friendly\"\n")?; - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + let status = maybe_migrate_personality(temp.path(), &config_toml, /*state_db*/ None).await?; assert_eq!( status, @@ -258,7 +262,7 @@ personality = "friendly" "#, )?; - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + let status = maybe_migrate_personality(temp.path(), &config_toml, /*state_db*/ None).await?; assert_eq!( status, @@ -281,7 +285,7 @@ async fn marker_short_circuits_invalid_profile_resolution() -> io::Result<()> { tokio::fs::write(temp.path().join(PERSONALITY_MIGRATION_FILENAME), "v1\n").await?; let config_toml = parse_config_toml("profile = \"missing\"\n")?; - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + let status = maybe_migrate_personality(temp.path(), &config_toml, /*state_db*/ None).await?; assert_eq!(status, PersonalityMigrationStatus::SkippedMarker); Ok(()) @@ -292,7 +296,7 @@ async fn invalid_selected_profile_returns_error_and_does_not_write_marker() -> i let temp = TempDir::new()?; let config_toml = parse_config_toml("profile = \"missing\"\n")?; - let err = maybe_migrate_personality(temp.path(), &config_toml) + let err = maybe_migrate_personality(temp.path(), &config_toml, /*state_db*/ None) .await .expect_err("missing profile should fail"); @@ -309,8 +313,10 @@ async fn applied_migration_is_idempotent_on_second_run() -> io::Result<()> { let temp = TempDir::new()?; write_session_with_user_event(temp.path()).await?; - let first_status = maybe_migrate_personality(temp.path(), &ConfigToml::default()).await?; - let second_status = maybe_migrate_personality(temp.path(), &ConfigToml::default()).await?; + let first_status = + maybe_migrate_personality(temp.path(), &ConfigToml::default(), /*state_db*/ None).await?; + let second_status = + maybe_migrate_personality(temp.path(), &ConfigToml::default(), /*state_db*/ None).await?; assert_eq!(first_status, PersonalityMigrationStatus::Applied); assert_eq!(second_status, PersonalityMigrationStatus::SkippedMarker); @@ -324,7 +330,8 @@ async fn no_marker_archived_sessions_sets_personality() -> io::Result<()> { let temp = TempDir::new()?; write_archived_session_with_user_event(temp.path()).await?; - let status = maybe_migrate_personality(temp.path(), &ConfigToml::default()).await?; + let status = + maybe_migrate_personality(temp.path(), &ConfigToml::default(), /*state_db*/ None).await?; assert_eq!(status, PersonalityMigrationStatus::Applied); assert_eq!( diff --git a/codex-rs/core/tests/suite/prompt_debug_tests.rs b/codex-rs/core/tests/suite/prompt_debug_tests.rs index 4fee438261..dc506bc474 100644 --- a/codex-rs/core/tests/suite/prompt_debug_tests.rs +++ b/codex-rs/core/tests/suite/prompt_debug_tests.rs @@ -29,6 +29,7 @@ async fn build_prompt_input_includes_context_and_user_message() -> Result<()> { text: "hello from debug prompt".to_string(), text_elements: Vec::new(), }], + /*state_db*/ None, ) .await?; diff --git a/codex-rs/core/tests/suite/rollout_list_find.rs b/codex-rs/core/tests/suite/rollout_list_find.rs index eef0d0f5f4..d0b39a9283 100644 --- a/codex-rs/core/tests/suite/rollout_list_find.rs +++ b/codex-rs/core/tests/suite/rollout_list_find.rs @@ -14,6 +14,7 @@ use codex_core::find_thread_path_by_id_str; use codex_protocol::ThreadId; use codex_protocol::models::BaseInstructions; use codex_protocol::protocol::SessionSource; +use codex_rollout::StateDbHandle; use codex_state::StateRuntime; use codex_state::ThreadMetadataBuilder; use pretty_assertions::assert_eq; @@ -27,7 +28,13 @@ fn write_minimal_rollout_with_id_in_subdir(codex_home: &Path, subdir: &str, id: std::fs::create_dir_all(&sessions).unwrap(); let file = sessions.join(format!("rollout-2024-01-01T00-00-00-{id}.jsonl")); - let mut f = std::fs::File::create(&file).unwrap(); + write_minimal_rollout_with_id_at_path(&file, id); + + file +} + +fn write_minimal_rollout_with_id_at_path(file: &Path, id: Uuid) { + let mut f = std::fs::File::create(file).unwrap(); // Minimal first line: session_meta with the id so content search can find it writeln!( f, @@ -46,8 +53,6 @@ fn write_minimal_rollout_with_id_in_subdir(codex_home: &Path, subdir: &str, id: }) ) .unwrap(); - - file } /// Create sessions/YYYY/MM/DD and write a minimal rollout file containing the @@ -56,7 +61,11 @@ fn write_minimal_rollout_with_id(codex_home: &Path, id: Uuid) -> PathBuf { write_minimal_rollout_with_id_in_subdir(codex_home, "sessions", id) } -async fn upsert_thread_metadata(codex_home: &Path, thread_id: ThreadId, rollout_path: PathBuf) { +async fn upsert_thread_metadata( + codex_home: &Path, + thread_id: ThreadId, + rollout_path: PathBuf, +) -> StateDbHandle { let runtime = StateRuntime::init(codex_home.to_path_buf(), "test-provider".to_string()) .await .unwrap(); @@ -73,6 +82,7 @@ async fn upsert_thread_metadata(codex_home: &Path, thread_id: ThreadId, rollout_ builder.cwd = codex_home.to_path_buf(); let metadata = builder.build("test-provider"); runtime.upsert_thread(&metadata).await.unwrap(); + runtime } #[tokio::test] @@ -81,9 +91,10 @@ async fn find_locates_rollout_file_by_id() { let id = Uuid::new_v4(); let expected = write_minimal_rollout_with_id(home.path(), id); - let found = find_thread_path_by_id_str(home.path(), &id.to_string()) - .await - .unwrap(); + let found = + find_thread_path_by_id_str(home.path(), &id.to_string(), /*state_db_ctx*/ None) + .await + .unwrap(); assert_eq!(found.unwrap(), expected); } @@ -97,9 +108,10 @@ async fn find_handles_gitignore_covering_codex_home_directory() { let id = Uuid::new_v4(); let expected = write_minimal_rollout_with_id(&codex_home, id); - let found = find_thread_path_by_id_str(&codex_home, &id.to_string()) - .await - .unwrap(); + let found = + find_thread_path_by_id_str(&codex_home, &id.to_string(), /*state_db_ctx*/ None) + .await + .unwrap(); assert_eq!(found, Some(expected)); } @@ -113,11 +125,11 @@ async fn find_prefers_sqlite_path_by_id() { "sessions/2030/12/30/rollout-2030-12-30T00-00-00-{id}.jsonl" )); std::fs::create_dir_all(db_path.parent().unwrap()).unwrap(); - std::fs::write(&db_path, "").unwrap(); + write_minimal_rollout_with_id_at_path(&db_path, id); write_minimal_rollout_with_id(home.path(), id); - upsert_thread_metadata(home.path(), thread_id, db_path.clone()).await; + let state_db = upsert_thread_metadata(home.path(), thread_id, db_path.clone()).await; - let found = find_thread_path_by_id_str(home.path(), &id.to_string()) + let found = find_thread_path_by_id_str(home.path(), &id.to_string(), Some(&state_db)) .await .unwrap(); @@ -134,9 +146,9 @@ async fn find_falls_back_to_filesystem_when_sqlite_has_no_match() { let unrelated_path = home .path() .join("sessions/2030/12/30/rollout-2030-12-30T00-00-00-unrelated.jsonl"); - upsert_thread_metadata(home.path(), unrelated_thread_id, unrelated_path).await; + let state_db = upsert_thread_metadata(home.path(), unrelated_thread_id, unrelated_path).await; - let found = find_thread_path_by_id_str(home.path(), &id.to_string()) + let found = find_thread_path_by_id_str(home.path(), &id.to_string(), Some(&state_db)) .await .unwrap(); @@ -150,9 +162,10 @@ async fn find_ignores_granular_gitignore_rules() { let expected = write_minimal_rollout_with_id(home.path(), id); std::fs::write(home.path().join("sessions/.gitignore"), "*.jsonl\n").unwrap(); - let found = find_thread_path_by_id_str(home.path(), &id.to_string()) - .await - .unwrap(); + let found = + find_thread_path_by_id_str(home.path(), &id.to_string(), /*state_db_ctx*/ None) + .await + .unwrap(); assert_eq!(found, Some(expected)); } @@ -197,7 +210,8 @@ async fn find_locates_rollout_file_written_by_recorder() -> std::io::Result<()> ), )?; - let found = find_thread_meta_by_name_str(home.path(), thread_name).await?; + let found = + find_thread_meta_by_name_str(home.path(), thread_name, /*state_db_ctx*/ None).await?; let (path, session_meta) = found.expect("expected rollout path to be found"); assert_eq!(session_meta.meta.id, thread_id); @@ -214,9 +228,13 @@ async fn find_archived_locates_rollout_file_by_id() { let id = Uuid::new_v4(); let expected = write_minimal_rollout_with_id_in_subdir(home.path(), "archived_sessions", id); - let found = find_archived_thread_path_by_id_str(home.path(), &id.to_string()) - .await - .unwrap(); + let found = find_archived_thread_path_by_id_str( + home.path(), + &id.to_string(), + /*state_db_ctx*/ None, + ) + .await + .unwrap(); assert_eq!(found, Some(expected)); } diff --git a/codex-rs/core/tests/suite/skills.rs b/codex-rs/core/tests/suite/skills.rs index a68af6a1e2..add1ee6ea7 100644 --- a/codex-rs/core/tests/suite/skills.rs +++ b/codex-rs/core/tests/suite/skills.rs @@ -246,7 +246,8 @@ async fn list_skills_skips_cwd_roots_when_environment_disabled() -> Result<()> { )?, )), /*analytics_events_client*/ None, - thread_store_from_config(&config), + thread_store_from_config(&config, /*state_db*/ None), + /*state_db*/ None, ); let new_thread = thread_manager.start_thread(config.clone()).await?; let cwd = config.cwd.to_path_buf(); diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index f2f0ed030b..d61346f1d0 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -57,6 +57,7 @@ use codex_cloud_requirements::cloud_requirements_loader_for_storage; use codex_config::ConfigLoadError; use codex_config::LoaderOverrides; use codex_config::format_config_error_with_source; +use codex_core::StateDbHandle; use codex_core::check_execpolicy_for_warnings; use codex_core::config::Config; use codex_core::config::ConfigBuilder; @@ -194,6 +195,7 @@ impl RequestIdSequencer { struct ExecRunArgs { in_process_start_args: InProcessClientStartArgs, + state_db: Option, command: Option, config: Config, dangerously_bypass_approvals_and_sandbox: bool, @@ -503,6 +505,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result arg0_paths.codex_self_exe.clone(), arg0_paths.codex_linux_sandbox_exe.clone(), )?; + let state_db = codex_core::init_state_db(&config).await; let in_process_start_args = InProcessClientStartArgs { arg0_paths, config: std::sync::Arc::new(config.clone()), @@ -511,6 +514,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result cloud_requirements: run_cloud_requirements, feedback: CodexFeedback::new(), log_db: None, + state_db: state_db.clone(), environment_manager: std::sync::Arc::new( EnvironmentManager::new(EnvironmentManagerArgs::new(local_runtime_paths)).await, ), @@ -525,6 +529,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result }; run_exec_session(ExecRunArgs { in_process_start_args, + state_db, command, config, dangerously_bypass_approvals_and_sandbox, @@ -546,6 +551,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result async fn run_exec_session(args: ExecRunArgs) -> anyhow::Result<()> { let ExecRunArgs { in_process_start_args, + state_db, command, config, dangerously_bypass_approvals_and_sandbox, @@ -672,7 +678,9 @@ async fn run_exec_session(args: ExecRunArgs) -> anyhow::Result<()> { let (primary_thread_id, fallback_session_configured) = if let Some(ExecCommand::Resume(args)) = command.as_ref() { - if let Some(thread_id) = resolve_resume_thread_id(&client, &config, args).await? { + if let Some(thread_id) = + resolve_resume_thread_id(&client, &config, state_db.as_ref(), args).await? + { let response: ThreadResumeResponse = send_request_with_response( &client, ClientRequest::ThreadResume { @@ -1318,6 +1326,7 @@ fn cwds_match(current_cwd: &Path, session_cwd: &Path) -> bool { async fn resolve_resume_thread_id( client: &InProcessAppServerClient, config: &Config, + state_db: Option<&StateDbHandle>, args: &crate::cli::ResumeArgs, ) -> anyhow::Result> { let model_providers = resume_lookup_model_providers(config, args); @@ -1365,7 +1374,7 @@ async fn resolve_resume_thread_id( if Uuid::parse_str(session_id).is_ok() { return Ok(Some(session_id.to_string())); } - if let Some(state_db) = codex_core::get_state_db(config).await { + if let Some(state_db) = state_db { let cwd = (!args.all).then_some(config.cwd.as_path()); let resolved = state_db .find_thread_by_exact_title( @@ -1380,7 +1389,8 @@ async fn resolve_resume_thread_id( return Ok(Some(thread.id.to_string())); } if let Some((_, session_meta)) = - find_thread_meta_by_name_str(&config.codex_home, session_id).await? + find_thread_meta_by_name_str(&config.codex_home, session_id, Some(state_db.as_ref())) + .await? && (args.all || cwds_match(config.cwd.as_path(), &session_meta.meta.cwd)) { return Ok(Some(session_meta.meta.id.to_string())); diff --git a/codex-rs/mcp-server/src/lib.rs b/codex-rs/mcp-server/src/lib.rs index bb54ffcc53..75372c6e7c 100644 --- a/codex-rs/mcp-server/src/lib.rs +++ b/codex-rs/mcp-server/src/lib.rs @@ -83,6 +83,7 @@ pub async fn run_main( std::io::Error::new(ErrorKind::InvalidData, format!("error loading config: {e}")) })?; set_default_client_residency_requirement(config.enforce_residency.value()); + let state_db = codex_core::init_state_db(&config).await; let otel = codex_core::otel_init::build_provider( &config, @@ -144,6 +145,7 @@ pub async fn run_main( arg0_paths, Arc::new(config), environment_manager, + state_db, ) .await; async move { diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index 99076650cc..9ce46dfeda 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use codex_arg0::Arg0DispatchPaths; +use codex_core::StateDbHandle; use codex_core::ThreadManager; use codex_core::config::Config; use codex_core::thread_store_from_config; @@ -53,6 +54,7 @@ impl MessageProcessor { arg0_paths: Arg0DispatchPaths, config: Arc, environment_manager: Arc, + state_db: Option, ) -> Self { let outgoing = Arc::new(outgoing); let auth_manager = AuthManager::shared_from_config( @@ -66,7 +68,8 @@ impl MessageProcessor { SessionSource::Mcp, environment_manager, /*analytics_events_client*/ None, - thread_store_from_config(config.as_ref()), + thread_store_from_config(config.as_ref(), state_db.clone()), + state_db.clone(), )); Self { outgoing, diff --git a/codex-rs/rollout/src/list.rs b/codex-rs/rollout/src/list.rs index bdb7198835..8ba63a713e 100644 --- a/codex-rs/rollout/src/list.rs +++ b/codex-rs/rollout/src/list.rs @@ -1239,6 +1239,7 @@ async fn find_thread_path_by_id_str_in_subdir( codex_home: &Path, subdir: &str, id_str: &str, + state_db_ctx: Option<&codex_state::StateRuntime>, ) -> io::Result> { // Validate UUID format early. if Uuid::parse_str(id_str).is_err() { @@ -1253,8 +1254,8 @@ async fn find_thread_path_by_id_str_in_subdir( _ => None, }; let thread_id = ThreadId::from_string(id_str).ok(); - let state_db_ctx = state_db::open_if_present(codex_home, "").await; - if let Some(state_db_ctx) = state_db_ctx.as_deref() + let mut unverified_db_path = None; + if let Some(state_db_ctx) = state_db_ctx && let Some(thread_id) = thread_id && let Some(db_path) = state_db::find_rollout_path_by_id( Some(state_db_ctx), @@ -1265,21 +1266,43 @@ async fn find_thread_path_by_id_str_in_subdir( .await { if tokio::fs::try_exists(&db_path).await.unwrap_or(false) { - return Ok(Some(db_path)); + match read_session_meta_line(&db_path).await { + Ok(meta_line) if meta_line.meta.id == thread_id => { + return Ok(Some(db_path)); + } + Ok(meta_line) => { + tracing::error!( + "state db returned rollout path for thread {id_str} but file belongs to thread {}: {}", + meta_line.meta.id, + db_path.display() + ); + tracing::warn!( + "state db discrepancy during find_thread_path_by_id_str_in_subdir: mismatched_db_path" + ); + } + Err(err) => { + tracing::debug!( + "state db returned rollout path for thread {id_str} that could not be verified: {}: {err}", + db_path.display() + ); + unverified_db_path = Some(db_path); + } + } + } else { + tracing::error!( + "state db returned stale rollout path for thread {id_str}: {}", + db_path.display() + ); + tracing::warn!( + "state db discrepancy during find_thread_path_by_id_str_in_subdir: stale_db_path" + ); } - tracing::error!( - "state db returned stale rollout path for thread {id_str}: {}", - db_path.display() - ); - tracing::warn!( - "state db discrepancy during find_thread_path_by_id_str_in_subdir: stale_db_path" - ); } let mut root = codex_home.to_path_buf(); root.push(subdir); if !root.exists() { - return Ok(None); + return Ok(unverified_db_path); } // This is safe because we know the values are valid. #[allow(clippy::unwrap_used)] @@ -1301,7 +1324,7 @@ async fn find_thread_path_by_id_str_in_subdir( "state db discrepancy during find_thread_path_by_id_str_in_subdir: falling_back" ); state_db::read_repair_rollout_path( - state_db_ctx.as_deref(), + state_db_ctx, thread_id, archived_only, found_path.as_path(), @@ -1309,7 +1332,7 @@ async fn find_thread_path_by_id_str_in_subdir( .await; } - Ok(found) + Ok(found.or(unverified_db_path)) } /// Locate a recorded thread rollout file by its UUID string using the existing @@ -1318,16 +1341,19 @@ async fn find_thread_path_by_id_str_in_subdir( pub async fn find_thread_path_by_id_str( codex_home: &Path, id_str: &str, + state_db_ctx: Option<&codex_state::StateRuntime>, ) -> io::Result> { - find_thread_path_by_id_str_in_subdir(codex_home, SESSIONS_SUBDIR, id_str).await + find_thread_path_by_id_str_in_subdir(codex_home, SESSIONS_SUBDIR, id_str, state_db_ctx).await } /// Locate an archived thread rollout file by its UUID string. pub async fn find_archived_thread_path_by_id_str( codex_home: &Path, id_str: &str, + state_db_ctx: Option<&codex_state::StateRuntime>, ) -> io::Result> { - find_thread_path_by_id_str_in_subdir(codex_home, ARCHIVED_SESSIONS_SUBDIR, id_str).await + find_thread_path_by_id_str_in_subdir(codex_home, ARCHIVED_SESSIONS_SUBDIR, id_str, state_db_ctx) + .await } /// Extract the `YYYY/MM/DD` directory components from a rollout filename. diff --git a/codex-rs/rollout/src/metadata.rs b/codex-rs/rollout/src/metadata.rs index e7a25f0cda..2dd2df3a41 100644 --- a/codex-rs/rollout/src/metadata.rs +++ b/codex-rs/rollout/src/metadata.rs @@ -136,6 +136,21 @@ pub(crate) async fn backfill_sessions( runtime: &codex_state::StateRuntime, codex_home: &Path, default_provider: &str, +) { + backfill_sessions_with_lease( + runtime, + codex_home, + default_provider, + BACKFILL_LEASE_SECONDS, + ) + .await; +} + +pub(crate) async fn backfill_sessions_with_lease( + runtime: &codex_state::StateRuntime, + codex_home: &Path, + default_provider: &str, + backfill_lease_seconds: i64, ) { let metric_client = codex_otel::global(); let timer = metric_client @@ -154,7 +169,7 @@ pub(crate) async fn backfill_sessions( if backfill_state.status == BackfillStatus::Complete { return; } - let claimed = match runtime.try_claim_backfill(BACKFILL_LEASE_SECONDS).await { + let claimed = match runtime.try_claim_backfill(backfill_lease_seconds).await { Ok(claimed) => claimed, Err(err) => { warn!( diff --git a/codex-rs/rollout/src/recorder.rs b/codex-rs/rollout/src/recorder.rs index ea0c7e3e36..512c223263 100644 --- a/codex-rs/rollout/src/recorder.rs +++ b/codex-rs/rollout/src/recorder.rs @@ -79,7 +79,6 @@ pub struct RolloutRecorder { tx: Sender, writer_task: Arc, pub(crate) rollout_path: PathBuf, - state_db: Option, event_persistence_mode: EventPersistenceMode, } @@ -230,6 +229,7 @@ impl RolloutRecorder { /// List threads (rollout files) under the provided Codex home directory. #[allow(clippy::too_many_arguments)] pub async fn list_threads( + state_db_ctx: Option, config: &impl RolloutConfigView, page_size: usize, cursor: Option<&Cursor>, @@ -242,6 +242,7 @@ impl RolloutRecorder { search_term: Option<&str>, ) -> std::io::Result { Self::list_threads_with_db_fallback( + state_db_ctx, config, page_size, cursor, @@ -260,6 +261,7 @@ impl RolloutRecorder { #[allow(clippy::too_many_arguments)] pub async fn list_threads_from_state_db( + state_db_ctx: Option, config: &impl RolloutConfigView, page_size: usize, cursor: Option<&Cursor>, @@ -272,6 +274,7 @@ impl RolloutRecorder { search_term: Option<&str>, ) -> std::io::Result { Self::list_threads_with_db_fallback( + state_db_ctx, config, page_size, cursor, @@ -291,6 +294,7 @@ impl RolloutRecorder { /// List archived threads (rollout files) under the archived sessions directory. #[allow(clippy::too_many_arguments)] pub async fn list_archived_threads( + state_db_ctx: Option, config: &impl RolloutConfigView, page_size: usize, cursor: Option<&Cursor>, @@ -303,6 +307,7 @@ impl RolloutRecorder { search_term: Option<&str>, ) -> std::io::Result { Self::list_threads_with_db_fallback( + state_db_ctx, config, page_size, cursor, @@ -321,6 +326,7 @@ impl RolloutRecorder { #[allow(clippy::too_many_arguments)] pub async fn list_archived_threads_from_state_db( + state_db_ctx: Option, config: &impl RolloutConfigView, page_size: usize, cursor: Option<&Cursor>, @@ -333,6 +339,7 @@ impl RolloutRecorder { search_term: Option<&str>, ) -> std::io::Result { Self::list_threads_with_db_fallback( + state_db_ctx, config, page_size, cursor, @@ -351,6 +358,7 @@ impl RolloutRecorder { #[allow(clippy::too_many_arguments)] async fn list_threads_with_db_fallback( + state_db_ctx: Option, config: &impl RolloutConfigView, page_size: usize, cursor: Option<&Cursor>, @@ -365,7 +373,6 @@ impl RolloutRecorder { search_term: Option<&str>, ) -> std::io::Result { let codex_home = config.codex_home(); - let state_db_ctx = state_db::get_state_db(config).await; let archived = match archive_filter { ThreadListArchiveFilter::Active => false, ThreadListArchiveFilter::Archived => true, @@ -575,6 +582,7 @@ impl RolloutRecorder { /// Find the newest recorded thread path, optionally filtering to a matching cwd. #[allow(clippy::too_many_arguments)] pub async fn find_latest_thread_path( + state_db_ctx: Option, config: &impl RolloutConfigView, page_size: usize, cursor: Option<&Cursor>, @@ -585,7 +593,6 @@ impl RolloutRecorder { filter_cwd: Option<&Path>, ) -> std::io::Result> { let codex_home = config.codex_home(); - let state_db_ctx = state_db::get_state_db(config).await; let cwd_filter = filter_cwd.map(Path::to_path_buf); if state_db_ctx.is_some() { let mut db_cursor = cursor.cloned(); @@ -770,7 +777,6 @@ impl RolloutRecorder { tx, writer_task, rollout_path, - state_db: state_db_ctx, event_persistence_mode, }) } @@ -779,10 +785,6 @@ impl RolloutRecorder { self.rollout_path.as_path() } - pub fn state_db(&self) -> Option { - self.state_db.clone() - } - pub async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> { let mut filtered = Vec::new(); for item in items { diff --git a/codex-rs/rollout/src/recorder_tests.rs b/codex-rs/rollout/src/recorder_tests.rs index 0138db72df..5711c47bad 100644 --- a/codex-rs/rollout/src/recorder_tests.rs +++ b/codex-rs/rollout/src/recorder_tests.rs @@ -3,6 +3,7 @@ use super::*; use crate::config::RolloutConfig; use chrono::TimeZone; +use codex_protocol::ThreadId; use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::AgentMessageEvent; @@ -11,6 +12,9 @@ use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::RolloutItem; use codex_protocol::protocol::RolloutLine; use codex_protocol::protocol::SandboxPolicy; +use codex_protocol::protocol::SessionMeta; +use codex_protocol::protocol::SessionMetaLine; +use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::TurnContextItem; use codex_protocol::protocol::UserMessageEvent; use pretty_assertions::assert_eq; @@ -65,6 +69,77 @@ fn write_session_file(root: &Path, ts: &str, uuid: Uuid) -> std::io::Result anyhow::Result<()> { + let home = TempDir::new().expect("temp dir"); + let uuid = Uuid::new_v4(); + let thread_id = ThreadId::from_string(&uuid.to_string())?; + let rollout_path = home.path().join(format!( + "sessions/2026/01/27/rollout-2026-01-27T12-34-56-{uuid}.jsonl" + )); + let parent = rollout_path + .parent() + .expect("rollout path should have parent"); + fs::create_dir_all(parent)?; + + let session_meta_line = SessionMetaLine { + meta: SessionMeta { + id: thread_id, + forked_from_id: None, + timestamp: "2026-01-27T12:34:56Z".to_string(), + cwd: home.path().to_path_buf(), + originator: "test".to_string(), + cli_version: "test".to_string(), + source: SessionSource::Cli, + agent_path: None, + agent_nickname: None, + agent_role: None, + model_provider: None, + base_instructions: None, + dynamic_tools: None, + memory_mode: None, + }, + git: None, + }; + let lines = [ + RolloutLine { + timestamp: "2026-01-27T12:34:56Z".to_string(), + item: RolloutItem::SessionMeta(session_meta_line), + }, + RolloutLine { + timestamp: "2026-01-27T12:34:57Z".to_string(), + item: RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { + message: "hello from startup backfill".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + })), + }, + ]; + let jsonl = lines + .iter() + .map(serde_json::to_string) + .collect::, _>>()? + .join("\n"); + fs::write(&rollout_path, format!("{jsonl}\n"))?; + + let runtime = crate::state_db::init(&test_config(home.path())) + .await + .expect("state db should initialize"); + + let metadata = runtime + .get_thread(thread_id) + .await? + .expect("thread should be backfilled before init returns"); + assert_eq!(metadata.rollout_path, rollout_path); + assert_eq!( + runtime.get_backfill_state().await?.status, + codex_state::BackfillStatus::Complete + ); + + Ok(()) +} + #[tokio::test] async fn load_rollout_items_skips_legacy_ghost_snapshot_lines() -> std::io::Result<()> { let home = TempDir::new().expect("temp dir"); @@ -526,6 +601,7 @@ async fn list_threads_db_disabled_does_not_skip_paginated_items() -> std::io::Re let default_provider = config.model_provider_id.clone(); let page1 = RolloutRecorder::list_threads( + /*state_db_ctx*/ None, &config, /*page_size*/ 1, /*cursor*/ None, @@ -543,6 +619,7 @@ async fn list_threads_db_disabled_does_not_skip_paginated_items() -> std::io::Re let cursor = page1.next_cursor.clone().expect("cursor should be present"); let page2 = RolloutRecorder::list_threads( + /*state_db_ctx*/ None, &config, /*page_size*/ 1, Some(&cursor), @@ -602,6 +679,7 @@ async fn list_threads_db_enabled_drops_missing_rollout_paths() -> std::io::Resul let default_provider = config.model_provider_id.clone(); let page = RolloutRecorder::list_threads( + Some(runtime.clone()), &config, /*page_size*/ 10, /*cursor*/ None, @@ -666,6 +744,7 @@ async fn list_threads_db_enabled_repairs_stale_rollout_paths() -> std::io::Resul let default_provider = config.model_provider_id.clone(); let page = RolloutRecorder::list_threads( + Some(runtime.clone()), &config, /*page_size*/ 1, /*cursor*/ None, @@ -738,6 +817,7 @@ async fn list_threads_state_db_only_skips_jsonl_repair_scan() -> std::io::Result let cwd_filters = [home.path().to_path_buf()]; let state_db_only_page = RolloutRecorder::list_threads_from_state_db( + Some(runtime.clone()), &config, /*page_size*/ 10, /*cursor*/ None, @@ -753,6 +833,7 @@ async fn list_threads_state_db_only_skips_jsonl_repair_scan() -> std::io::Result assert_eq!(state_db_only_page.items.len(), 0); let repaired_page = RolloutRecorder::list_threads( + Some(runtime.clone()), &config, /*page_size*/ 10, /*cursor*/ None, @@ -768,6 +849,7 @@ async fn list_threads_state_db_only_skips_jsonl_repair_scan() -> std::io::Result assert_eq!(repaired_page.items.len(), 1); let repaired_state_db_only_page = RolloutRecorder::list_threads_from_state_db( + Some(runtime.clone()), &config, /*page_size*/ 10, /*cursor*/ None, @@ -825,6 +907,7 @@ async fn list_threads_default_filter_returns_filesystem_scan_results() -> std::i let cwd_filters = [stale_cwd]; let state_db_only_page = RolloutRecorder::list_threads_from_state_db( + Some(runtime.clone()), &config, /*page_size*/ 10, /*cursor*/ None, @@ -840,6 +923,7 @@ async fn list_threads_default_filter_returns_filesystem_scan_results() -> std::i assert_eq!(state_db_only_page.items.len(), 1); let scanned_page = RolloutRecorder::list_threads( + Some(runtime.clone()), &config, /*page_size*/ 10, /*cursor*/ None, @@ -855,6 +939,7 @@ async fn list_threads_default_filter_returns_filesystem_scan_results() -> std::i assert_eq!(scanned_page.items.len(), 0); let repaired_state_db_only_page = RolloutRecorder::list_threads_from_state_db( + Some(runtime.clone()), &config, /*page_size*/ 10, /*cursor*/ None, @@ -913,6 +998,7 @@ async fn list_threads_metadata_filter_overlays_state_db_list_metadata() -> std:: .expect("state db upsert should succeed"); let page = RolloutRecorder::list_threads( + Some(runtime.clone()), &config, /*page_size*/ 10, /*cursor*/ None, @@ -1039,6 +1125,7 @@ async fn list_threads_search_repairs_stale_state_db_hits_before_returning() -> s .expect("state db upsert should succeed"); let stale_state_db_only_page = RolloutRecorder::list_threads_from_state_db( + Some(runtime.clone()), &config, /*page_size*/ 10, /*cursor*/ None, @@ -1054,6 +1141,7 @@ async fn list_threads_search_repairs_stale_state_db_hits_before_returning() -> s assert_eq!(stale_state_db_only_page.items.len(), 1); let scanned_page = RolloutRecorder::list_threads( + Some(runtime.clone()), &config, /*page_size*/ 10, /*cursor*/ None, @@ -1069,6 +1157,7 @@ async fn list_threads_search_repairs_stale_state_db_hits_before_returning() -> s assert_eq!(scanned_page.items.len(), 0); let repaired_state_db_only_page = RolloutRecorder::list_threads_from_state_db( + Some(runtime.clone()), &config, /*page_size*/ 10, /*cursor*/ None, diff --git a/codex-rs/rollout/src/session_index.rs b/codex-rs/rollout/src/session_index.rs index 039ce27f24..e227515357 100644 --- a/codex-rs/rollout/src/session_index.rs +++ b/codex-rs/rollout/src/session_index.rs @@ -117,6 +117,7 @@ pub async fn find_thread_names_by_ids( pub async fn find_thread_meta_by_name_str( codex_home: &Path, name: &str, + state_db_ctx: Option<&codex_state::StateRuntime>, ) -> std::io::Result> { if name.trim().is_empty() { return Ok(None); @@ -135,8 +136,12 @@ pub async fn find_thread_meta_by_name_str( while let Some(thread_id) = rx.recv().await { // Keep walking until a matching id resolves to a loadable rollout so an unsaved or partial // rename cannot shadow an older persisted session with the same name. - if let Some(path) = - super::list::find_thread_path_by_id_str(codex_home, &thread_id.to_string()).await? + if let Some(path) = super::list::find_thread_path_by_id_str( + codex_home, + &thread_id.to_string(), + state_db_ctx, + ) + .await? && let Ok(session_meta) = super::list::read_session_meta_line(&path).await { drop(rx); diff --git a/codex-rs/rollout/src/session_index_tests.rs b/codex-rs/rollout/src/session_index_tests.rs index c6a539fb28..fbef7eb4f9 100644 --- a/codex-rs/rollout/src/session_index_tests.rs +++ b/codex-rs/rollout/src/session_index_tests.rs @@ -99,7 +99,7 @@ async fn find_thread_meta_by_name_str_skips_newest_entry_without_rollout() -> st ]; write_index(&path, &lines)?; - let found = find_thread_meta_by_name_str(temp.path(), "same").await?; + let found = find_thread_meta_by_name_str(temp.path(), "same", /*state_db_ctx*/ None).await?; assert_eq!( found.map(|(path, session_meta)| (path, session_meta.meta.id)), @@ -136,7 +136,7 @@ async fn find_thread_meta_by_name_str_skips_partial_rollout() -> std::io::Result ]; write_index(&path, &lines)?; - let found = find_thread_meta_by_name_str(temp.path(), "same").await?; + let found = find_thread_meta_by_name_str(temp.path(), "same", /*state_db_ctx*/ None).await?; assert_eq!(found.map(|(path, _)| path), Some(saved_rollout_path)); Ok(()) @@ -174,7 +174,7 @@ async fn find_thread_meta_by_name_str_ignores_historical_name_after_rename() -> ]; write_index(&path, &lines)?; - let found = find_thread_meta_by_name_str(temp.path(), "same").await?; + let found = find_thread_meta_by_name_str(temp.path(), "same", /*state_db_ctx*/ None).await?; assert_eq!(found.map(|(path, _)| path), Some(current_rollout_path)); Ok(()) diff --git a/codex-rs/rollout/src/state_db.rs b/codex-rs/rollout/src/state_db.rs index 41b59c9760..d039e16d68 100644 --- a/codex-rs/rollout/src/state_db.rs +++ b/codex-rs/rollout/src/state_db.rs @@ -17,15 +17,52 @@ use serde_json::Value; use std::path::Path; use std::path::PathBuf; use std::sync::Arc; +use std::time::Duration; +use std::time::Instant; +use tracing::info; use tracing::warn; /// Core-facing handle to the SQLite-backed state runtime. pub type StateDbHandle = Arc; -/// Initialize the state runtime for thread state persistence and backfill checks. +#[cfg(not(test))] +const STARTUP_BACKFILL_POLL_INTERVAL: Duration = Duration::from_secs(1); +#[cfg(test)] +const STARTUP_BACKFILL_POLL_INTERVAL: Duration = Duration::from_millis(10); +#[cfg(not(test))] +const STARTUP_BACKFILL_WAIT_TIMEOUT: Duration = Duration::from_secs(30); +#[cfg(test)] +const STARTUP_BACKFILL_WAIT_TIMEOUT: Duration = Duration::from_secs(2); + +/// Initialize the state runtime for thread state persistence. +/// +/// This is the process entry point for local state: it opens the SQLite-backed +/// runtime, applies rollout metadata backfills as needed, and returns the +/// initialized handle. pub async fn init(config: &impl RolloutConfigView) -> Option { let config = RolloutConfig::from_view(config); - init_with_roots( + match try_init_with_roots( + config.codex_home, + config.sqlite_home, + config.model_provider_id, + ) + .await + { + Ok(runtime) => Some(runtime), + Err(err) => { + emit_startup_warning(&format!("failed to initialize state runtime: {err}")); + None + } + } +} + +/// Initialize the state runtime and return any initialization error to the caller. +/// +/// Prefer [`init`] unless the caller needs to surface the exact failure after +/// tracing or UI setup has completed. +pub async fn try_init(config: &impl RolloutConfigView) -> anyhow::Result { + let config = RolloutConfig::from_view(config); + try_init_with_roots( config.codex_home, config.sqlite_home, config.model_provider_id, @@ -33,52 +70,128 @@ pub async fn init(config: &impl RolloutConfigView) -> Option { .await } -/// Initialize the state runtime for a local thread store. -pub async fn init_with_roots( +async fn try_init_with_roots( codex_home: PathBuf, sqlite_home: PathBuf, default_model_provider_id: String, -) -> Option { - let runtime = match codex_state::StateRuntime::init( - sqlite_home.clone(), - default_model_provider_id.clone(), +) -> anyhow::Result { + try_init_with_roots_inner( + codex_home, + sqlite_home, + default_model_provider_id, + /*backfill_lease_seconds*/ None, ) .await - { - Ok(runtime) => runtime, - Err(err) => { - warn!( - "failed to initialize state runtime at {}: {err}", - sqlite_home.display() - ); - return None; - } - }; - let backfill_state = match runtime.get_backfill_state().await { - Ok(state) => state, - Err(err) => { - warn!( +} + +#[cfg(test)] +async fn try_init_with_roots_and_backfill_lease( + codex_home: PathBuf, + sqlite_home: PathBuf, + default_model_provider_id: String, + backfill_lease_seconds: i64, +) -> anyhow::Result { + try_init_with_roots_inner( + codex_home, + sqlite_home, + default_model_provider_id, + Some(backfill_lease_seconds), + ) + .await +} + +async fn try_init_with_roots_inner( + codex_home: PathBuf, + sqlite_home: PathBuf, + default_model_provider_id: String, + backfill_lease_seconds: Option, +) -> anyhow::Result { + let runtime = + codex_state::StateRuntime::init(sqlite_home.clone(), default_model_provider_id.clone()) + .await + .map_err(|err| { + anyhow::anyhow!( + "failed to initialize state runtime at {}: {err}", + sqlite_home.display() + ) + })?; + let wait_started = Instant::now(); + let mut reported_wait = false; + loop { + let backfill_state = runtime.get_backfill_state().await.map_err(|err| { + anyhow::anyhow!( "failed to read backfill state at {}: {err}", codex_home.display() - ); - return None; + ) + })?; + if backfill_state.status == codex_state::BackfillStatus::Complete { + return Ok(runtime); } - }; - if backfill_state.status != codex_state::BackfillStatus::Complete { - let runtime_for_backfill = runtime.clone(); - tokio::spawn(async move { + + if let Some(backfill_lease_seconds) = backfill_lease_seconds { + metadata::backfill_sessions_with_lease( + runtime.as_ref(), + codex_home.as_path(), + default_model_provider_id.as_str(), + backfill_lease_seconds, + ) + .await; + } else { metadata::backfill_sessions( - runtime_for_backfill.as_ref(), + runtime.as_ref(), codex_home.as_path(), default_model_provider_id.as_str(), ) .await; - }); + } + let backfill_state = runtime.get_backfill_state().await.map_err(|err| { + anyhow::anyhow!( + "failed to read backfill state at {} after startup backfill: {err}", + codex_home.display() + ) + })?; + if backfill_state.status == codex_state::BackfillStatus::Complete { + return Ok(runtime); + } + if wait_started.elapsed() >= STARTUP_BACKFILL_WAIT_TIMEOUT { + return Err(anyhow::anyhow!( + "timed out waiting for state db backfill at {} after {:?} (status: {})", + codex_home.display(), + STARTUP_BACKFILL_WAIT_TIMEOUT, + backfill_state.status.as_str() + )); + } + + let message = format!( + "state db backfill is {} at {}; waiting up to {:?} before retrying startup initialization", + backfill_state.status.as_str(), + codex_home.display(), + STARTUP_BACKFILL_WAIT_TIMEOUT, + ); + if reported_wait { + info!("{message}"); + } else { + emit_startup_warning(&message); + reported_wait = true; + } + tokio::time::sleep(STARTUP_BACKFILL_POLL_INTERVAL).await; } - Some(runtime) } -/// Get the DB if the feature is enabled and the DB exists. +fn emit_startup_warning(message: &str) { + warn!("{message}"); + if !tracing::dispatcher::has_been_set() { + #[allow(clippy::print_stderr)] + { + eprintln!("{message}"); + } + } +} + +/// Open the DB if it exists and its startup backfill has already completed. +/// +/// Unlike [`init`], this helper does not run rollout backfill. It is for +/// optional local reads from non-owning contexts such as remote app-server mode. pub async fn get_state_db(config: &impl RolloutConfigView) -> Option { let state_path = codex_state::state_db_path(config.sqlite_home()); if !tokio::fs::try_exists(&state_path).await.unwrap_or(false) { @@ -93,21 +206,6 @@ pub async fn get_state_db(config: &impl RolloutConfigView) -> Option Option { - let db_path = codex_state::state_db_path(codex_home); - if !tokio::fs::try_exists(&db_path).await.unwrap_or(false) { - return None; - } - let runtime = - codex_state::StateRuntime::init(codex_home.to_path_buf(), default_provider.to_string()) - .await - .ok()?; - require_backfill_complete(runtime, codex_home).await -} - async fn require_backfill_complete( runtime: StateDbHandle, codex_home: &Path, diff --git a/codex-rs/rollout/src/state_db_tests.rs b/codex-rs/rollout/src/state_db_tests.rs index a4e59db9d0..10a9a3da13 100644 --- a/codex-rs/rollout/src/state_db_tests.rs +++ b/codex-rs/rollout/src/state_db_tests.rs @@ -7,6 +7,7 @@ use chrono::NaiveDateTime; use chrono::Timelike; use chrono::Utc; use pretty_assertions::assert_eq; +use tempfile::TempDir; #[test] fn cursor_to_anchor_normalizes_timestamp_format() { @@ -22,3 +23,64 @@ fn cursor_to_anchor_normalizes_timestamp_format() { assert_eq!(anchor.ts, expected_ts); } + +#[tokio::test] +async fn try_init_waits_for_concurrent_startup_backfill() -> anyhow::Result<()> { + let home = TempDir::new().expect("temp dir"); + let runtime = + codex_state::StateRuntime::init(home.path().to_path_buf(), "test-provider".to_string()) + .await?; + let claimed = runtime.try_claim_backfill(/*lease_seconds*/ 60).await?; + assert!(claimed); + let runtime_for_completion = runtime.clone(); + let complete_backfill = tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + runtime_for_completion + .mark_backfill_complete(/*last_watermark*/ None) + .await + }); + + let initialized = try_init_with_roots_and_backfill_lease( + home.path().to_path_buf(), + home.path().to_path_buf(), + "test-provider".to_string(), + /*backfill_lease_seconds*/ 60, + ) + .await?; + complete_backfill.await??; + assert_eq!( + initialized.get_backfill_state().await?.status, + codex_state::BackfillStatus::Complete + ); + + Ok(()) +} + +#[tokio::test] +async fn try_init_times_out_waiting_for_stuck_startup_backfill() -> anyhow::Result<()> { + let home = TempDir::new().expect("temp dir"); + let runtime = + codex_state::StateRuntime::init(home.path().to_path_buf(), "test-provider".to_string()) + .await?; + let claimed = runtime.try_claim_backfill(/*lease_seconds*/ 60).await?; + assert!(claimed); + + let result = try_init_with_roots_and_backfill_lease( + home.path().to_path_buf(), + home.path().to_path_buf(), + "test-provider".to_string(), + /*backfill_lease_seconds*/ 60, + ) + .await; + let err = match result { + Ok(_) => panic!("state db init should not wait forever for incomplete backfill"), + Err(err) => err, + }; + assert!( + err.to_string() + .contains("timed out waiting for state db backfill"), + "unexpected error: {err}" + ); + + Ok(()) +} diff --git a/codex-rs/rollout/src/tests.rs b/codex-rs/rollout/src/tests.rs index fba8a9827a..b5c2790dae 100644 --- a/codex-rs/rollout/src/tests.rs +++ b/codex-rs/rollout/src/tests.rs @@ -58,7 +58,7 @@ async fn insert_state_db_thread( thread_id: ThreadId, rollout_path: &Path, archived: bool, -) { +) -> crate::state_db::StateDbHandle { let runtime = codex_state::StateRuntime::init(home.to_path_buf(), TEST_PROVIDER.to_string()) .await .expect("state db should initialize"); @@ -87,6 +87,7 @@ async fn insert_state_db_thread( .upsert_thread(&metadata) .await .expect("state db upsert should succeed"); + runtime } // TODO(jif) fix @@ -236,7 +237,7 @@ async fn find_thread_path_falls_back_when_db_path_is_stale() { let stale_db_path = home.join(format!( "sessions/2099/01/01/rollout-2099-01-01T00-00-00-{uuid}.jsonl" )); - insert_state_db_thread( + let runtime = insert_state_db_thread( home, thread_id, stale_db_path.as_path(), @@ -244,7 +245,52 @@ async fn find_thread_path_falls_back_when_db_path_is_stale() { ) .await; - let found = find_thread_path_by_id_str(home, &uuid.to_string()) + let found = find_thread_path_by_id_str(home, &uuid.to_string(), Some(runtime.as_ref())) + .await + .expect("lookup should succeed"); + assert_eq!(found, Some(fs_rollout_path.clone())); + assert_state_db_rollout_path(home, thread_id, Some(fs_rollout_path.as_path())).await; +} + +#[tokio::test] +async fn find_thread_path_falls_back_when_db_path_points_to_another_thread() { + let temp = TempDir::new().unwrap(); + let home = temp.path(); + let uuid = Uuid::from_u128(304); + let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); + let ts = "2025-01-03T13-00-00"; + write_session_file( + home, + ts, + uuid, + /*num_records*/ 1, + Some(SessionSource::Cli), + ) + .unwrap(); + let fs_rollout_path = home.join(format!("sessions/2025/01/03/rollout-{ts}-{uuid}.jsonl")); + + let other_uuid = Uuid::from_u128(1304); + let other_ts = "2025-01-04T13-00-00"; + write_session_file( + home, + other_ts, + other_uuid, + /*num_records*/ 1, + Some(SessionSource::Cli), + ) + .unwrap(); + let stale_db_path = home.join(format!( + "sessions/2025/01/04/rollout-{other_ts}-{other_uuid}.jsonl" + )); + let runtime = insert_state_db_thread( + home, + thread_id, + stale_db_path.as_path(), + /*archived*/ false, + ) + .await; + + let found = find_thread_path_by_id_str(home, &uuid.to_string(), Some(runtime.as_ref())) .await .expect("lookup should succeed"); assert_eq!(found, Some(fs_rollout_path.clone())); @@ -269,21 +315,44 @@ async fn find_thread_path_repairs_missing_db_row_after_filesystem_fallback() { let fs_rollout_path = home.join(format!("sessions/2025/01/03/rollout-{ts}-{uuid}.jsonl")); // Create an empty state DB so lookup takes the DB-first path and then falls back to files. - let _runtime = codex_state::StateRuntime::init(home.to_path_buf(), TEST_PROVIDER.to_string()) + let runtime = codex_state::StateRuntime::init(home.to_path_buf(), TEST_PROVIDER.to_string()) .await .expect("state db should initialize"); - _runtime + runtime .mark_backfill_complete(/*last_watermark*/ None) .await .expect("backfill should be complete"); - let found = find_thread_path_by_id_str(home, &uuid.to_string()) + let found = find_thread_path_by_id_str(home, &uuid.to_string(), Some(runtime.as_ref())) .await .expect("lookup should succeed"); assert_eq!(found, Some(fs_rollout_path.clone())); assert_state_db_rollout_path(home, thread_id, Some(fs_rollout_path.as_path())).await; } +#[tokio::test] +async fn find_thread_path_accepts_existing_state_db_path_without_canonical_filename() { + let temp = TempDir::new().unwrap(); + let home = temp.path(); + let uuid = Uuid::from_u128(305); + let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); + let db_rollout_path = home.join("sessions/2025/01/03/custom-rollout-name.jsonl"); + fs::create_dir_all(db_rollout_path.parent().expect("rollout parent")).unwrap(); + fs::write(&db_rollout_path, "").unwrap(); + let runtime = insert_state_db_thread( + home, + thread_id, + db_rollout_path.as_path(), + /*archived*/ false, + ) + .await; + + let found = find_thread_path_by_id_str(home, &uuid.to_string(), Some(runtime.as_ref())) + .await + .expect("lookup should succeed"); + assert_eq!(found, Some(db_rollout_path)); +} + #[test] fn rollout_date_parts_extracts_directory_components() { let file_name = OsStr::new("rollout-2025-03-01T09-00-00-123.jsonl"); diff --git a/codex-rs/thread-manager-sample/src/main.rs b/codex-rs/thread-manager-sample/src/main.rs index cc2262512d..7064d9e4ce 100644 --- a/codex-rs/thread-manager-sample/src/main.rs +++ b/codex-rs/thread-manager-sample/src/main.rs @@ -54,6 +54,7 @@ use codex_core_api::WebSearchMode; use codex_core_api::arg0_dispatch_or_else; use codex_core_api::built_in_model_providers; use codex_core_api::find_codex_home; +use codex_core_api::init_state_db; use codex_core_api::item_event_to_server_notification; use codex_core_api::set_default_originator; use codex_core_api::thread_store_from_config; @@ -102,6 +103,7 @@ async fn run_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> { }; let config = new_config(args.model, arg0_paths)?; + let state_db = init_state_db(&config).await; let auth_manager = AuthManager::shared_from_config(&config, /*enable_codex_api_key_env*/ false).await; @@ -109,7 +111,7 @@ async fn run_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> { config.codex_self_exe.clone(), config.codex_linux_sandbox_exe.clone(), )?; - let thread_store = thread_store_from_config(&config); + let thread_store = thread_store_from_config(&config, state_db.clone()); let environment_manager = Arc::new(EnvironmentManager::new(EnvironmentManagerArgs::new(local_runtime_paths)).await); let thread_manager = ThreadManager::new( @@ -119,6 +121,7 @@ async fn run_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> { environment_manager, /*analytics_events_client*/ None, Arc::clone(&thread_store), + state_db, ); let NewThread { diff --git a/codex-rs/thread-store/src/local/archive_thread.rs b/codex-rs/thread-store/src/local/archive_thread.rs index 5df1d5b761..8fb214e98c 100644 --- a/codex-rs/thread-store/src/local/archive_thread.rs +++ b/codex-rs/thread-store/src/local/archive_thread.rs @@ -13,15 +13,19 @@ pub(super) async fn archive_thread( params: ArchiveThreadParams, ) -> ThreadStoreResult<()> { let thread_id = params.thread_id; - let rollout_path = - find_thread_path_by_id_str(store.config.codex_home.as_path(), &thread_id.to_string()) - .await - .map_err(|err| ThreadStoreError::InvalidRequest { - message: format!("failed to locate thread id {thread_id}: {err}"), - })? - .ok_or_else(|| ThreadStoreError::InvalidRequest { - message: format!("no rollout found for thread id {thread_id}"), - })?; + let state_db_ctx = store.state_db().await; + let rollout_path = find_thread_path_by_id_str( + store.config.codex_home.as_path(), + &thread_id.to_string(), + state_db_ctx.as_deref(), + ) + .await + .map_err(|err| ThreadStoreError::InvalidRequest { + message: format!("failed to locate thread id {thread_id}: {err}"), + })? + .ok_or_else(|| ThreadStoreError::InvalidRequest { + message: format!("no rollout found for thread id {thread_id}"), + })?; let canonical_rollout_path = scoped_rollout_path( store.config.codex_home.join(codex_rollout::SESSIONS_SUBDIR), @@ -48,7 +52,7 @@ pub(super) async fn archive_thread( } })?; - if let Some(ctx) = store.state_db().await { + if let Some(ctx) = state_db_ctx { let _ = ctx .mark_archived(thread_id, archived_path.as_path(), Utc::now()) .await; @@ -77,7 +81,7 @@ mod tests { #[tokio::test] async fn archive_thread_moves_rollout_to_archived_collection() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(201); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let active_path = @@ -123,7 +127,6 @@ mod tests { async fn archive_thread_updates_sqlite_metadata_when_present() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(202); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let active_path = @@ -134,6 +137,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); runtime .mark_backfill_complete(/*last_watermark*/ None) .await diff --git a/codex-rs/thread-store/src/local/list_threads.rs b/codex-rs/thread-store/src/local/list_threads.rs index 037bd25085..e470ad2be9 100644 --- a/codex-rs/thread-store/src/local/list_threads.rs +++ b/codex-rs/thread-store/src/local/list_threads.rs @@ -39,6 +39,7 @@ pub(super) async fn list_threads( SortDirection::Asc => codex_rollout::SortDirection::Asc, SortDirection::Desc => codex_rollout::SortDirection::Desc, }; + let state_db = store.state_db().await; let rollout_config = RolloutConfig { codex_home: store.config.codex_home.clone(), sqlite_home: store.config.sqlite_home.clone(), @@ -47,6 +48,7 @@ pub(super) async fn list_threads( generate_memories: false, }; let page = list_rollout_threads( + state_db, &rollout_config, store.config.default_model_provider_id.as_str(), ¶ms, @@ -106,6 +108,7 @@ pub(super) async fn list_threads( } async fn list_rollout_threads( + state_db: Option, config: &RolloutConfig, default_model_provider_id: &str, params: &ListThreadsParams, @@ -115,6 +118,7 @@ async fn list_rollout_threads( ) -> ThreadStoreResult { let page = if params.use_state_db_only && params.archived { RolloutRecorder::list_archived_threads_from_state_db( + state_db, config, params.page_size, cursor, @@ -129,6 +133,7 @@ async fn list_rollout_threads( .await } else if params.use_state_db_only { RolloutRecorder::list_threads_from_state_db( + state_db, config, params.page_size, cursor, @@ -143,6 +148,7 @@ async fn list_rollout_threads( .await } else if params.archived { RolloutRecorder::list_archived_threads( + state_db, config, params.page_size, cursor, @@ -157,6 +163,7 @@ async fn list_rollout_threads( .await } else { RolloutRecorder::list_threads( + state_db, config, params.page_size, cursor, @@ -196,7 +203,7 @@ mod tests { #[tokio::test] async fn list_threads_uses_default_provider_when_rollout_omits_provider() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); write_session_file_with( home.path(), home.path().join("sessions/2025/01/03"), @@ -231,7 +238,6 @@ mod tests { async fn list_threads_preserves_sqlite_title_search_results() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(103); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let rollout_path = home.path().join("rollout-title-search.jsonl"); @@ -243,6 +249,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); runtime .mark_backfill_complete(/*last_watermark*/ None) .await @@ -296,7 +303,7 @@ mod tests { #[tokio::test] async fn list_threads_selects_active_or_archived_collection() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let active_uuid = Uuid::from_u128(105); let archived_uuid = Uuid::from_u128(106); write_session_file(home.path(), "2025-01-03T12-00-00", active_uuid) @@ -365,7 +372,7 @@ mod tests { async fn list_threads_returns_local_rollout_summary() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config); + let store = LocalThreadStore::new(config, /*state_db*/ None); let uuid = Uuid::from_u128(101); let path = write_session_file(home.path(), "2025-01-03T12-00-00", uuid).expect("session file"); @@ -404,7 +411,7 @@ mod tests { #[tokio::test] async fn list_threads_rejects_invalid_cursor() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let err = store .list_threads(ListThreadsParams { diff --git a/codex-rs/thread-store/src/local/mod.rs b/codex-rs/thread-store/src/local/mod.rs index 04dd8b2490..058c7b2309 100644 --- a/codex-rs/thread-store/src/local/mod.rs +++ b/codex-rs/thread-store/src/local/mod.rs @@ -19,7 +19,6 @@ use std::collections::hash_map::Entry; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::Mutex; -use tokio::sync::OnceCell; use crate::AppendThreadItemsParams; use crate::ArchiveThreadParams; @@ -42,7 +41,7 @@ use crate::UpdateThreadMetadataParams; pub struct LocalThreadStore { pub(super) config: LocalThreadStoreConfig, live_recorders: Arc>>, - state_db: Arc>, + state_db: Option, } /// Process-scoped configuration for local thread storage. @@ -76,30 +75,18 @@ impl std::fmt::Debug for LocalThreadStore { } impl LocalThreadStore { - /// Create a local store from process-scoped local storage configuration. - pub fn new(config: LocalThreadStoreConfig) -> Self { + /// Create a local store using an already initialized state DB handle. + pub fn new(config: LocalThreadStoreConfig, state_db: Option) -> Self { Self { config, live_recorders: Arc::new(Mutex::new(HashMap::new())), - state_db: Arc::new(OnceCell::new()), + state_db, } } /// Return the state DB handle used by local rollout writers. pub async fn state_db(&self) -> Option { - self.state_db - .get_or_try_init(|| async { - codex_rollout::state_db::init_with_roots( - self.config.codex_home.clone(), - self.config.sqlite_home.clone(), - self.config.default_model_provider_id.clone(), - ) - .await - .ok_or(()) - }) - .await - .ok() - .cloned() + self.state_db.clone() } /// Read a local rollout-backed thread by path. @@ -302,7 +289,7 @@ mod tests { #[tokio::test] async fn live_writer_lifecycle_writes_and_closes() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let thread_id = ThreadId::default(); store @@ -351,7 +338,7 @@ mod tests { #[tokio::test] async fn create_thread_rejects_missing_cwd() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let thread_id = ThreadId::default(); let mut params = create_thread_params(thread_id); params.metadata.cwd = None; @@ -371,7 +358,7 @@ mod tests { #[tokio::test] async fn discard_thread_drops_unmaterialized_live_writer() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let thread_id = ThreadId::default(); store @@ -410,7 +397,7 @@ mod tests { let config = test_config(home.path()); let thread_id = ThreadId::default(); - let first_store = LocalThreadStore::new(config.clone()); + let first_store = LocalThreadStore::new(config.clone(), /*state_db*/ None); first_store .create_thread(create_thread_params(thread_id)) .await @@ -439,7 +426,7 @@ mod tests { .await .expect("shutdown initial writer"); - let resumed_store = LocalThreadStore::new(config); + let resumed_store = LocalThreadStore::new(config, /*state_db*/ None); resumed_store .resume_thread(ResumeThreadParams { thread_id, @@ -470,7 +457,7 @@ mod tests { #[tokio::test] async fn create_thread_rejects_duplicate_live_writer() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let thread_id = ThreadId::default(); store @@ -490,7 +477,7 @@ mod tests { #[tokio::test] async fn resume_thread_rejects_duplicate_live_writer() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let thread_id = ThreadId::default(); store @@ -519,7 +506,7 @@ mod tests { #[tokio::test] async fn resume_thread_rejects_missing_cwd() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = uuid::Uuid::from_u128(407); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let rollout_path = @@ -548,7 +535,7 @@ mod tests { async fn load_history_uses_live_writer_rollout_path() { let home = TempDir::new().expect("temp dir"); let external_home = TempDir::new().expect("external temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = uuid::Uuid::from_u128(404); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let rollout_path = write_session_file(external_home.path(), "2025-01-04T10-00-00", uuid) @@ -597,7 +584,7 @@ mod tests { async fn read_thread_uses_live_writer_rollout_path_for_external_resume() { let home = TempDir::new().expect("temp dir"); let external_home = TempDir::new().expect("external temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = uuid::Uuid::from_u128(406); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let rollout_path = write_session_file(external_home.path(), "2025-01-04T11-00-00", uuid) @@ -636,7 +623,7 @@ mod tests { #[tokio::test] async fn load_history_uses_live_writer_rollout_path_for_archived_source() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = uuid::Uuid::from_u128(405); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let rollout_path = write_archived_session_file(home.path(), "2025-01-04T10-30-00", uuid) @@ -704,7 +691,7 @@ mod tests { #[tokio::test] async fn read_thread_by_rollout_path_includes_history() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let thread_id = ThreadId::default(); store diff --git a/codex-rs/thread-store/src/local/read_thread.rs b/codex-rs/thread-store/src/local/read_thread.rs index 8b3d3160db..097277705f 100644 --- a/codex-rs/thread-store/src/local/read_thread.rs +++ b/codex-rs/thread-store/src/local/read_thread.rs @@ -10,7 +10,6 @@ use codex_rollout::find_thread_name_by_id; use codex_rollout::find_thread_path_by_id_str; use codex_rollout::read_session_meta_line; use codex_rollout::read_thread_item_from_rollout; -use codex_state::StateRuntime; use codex_state::ThreadMetadata; use super::LocalThreadStore; @@ -172,16 +171,22 @@ async fn resolve_rollout_path( return Ok(Some(path)); } + let state_db_ctx = store.state_db().await; if include_archived { - match find_thread_path_by_id_str(store.config.codex_home.as_path(), &thread_id.to_string()) - .await - .map_err(|err| ThreadStoreError::InvalidRequest { - message: format!("failed to locate thread id {thread_id}: {err}"), - })? { + match find_thread_path_by_id_str( + store.config.codex_home.as_path(), + &thread_id.to_string(), + state_db_ctx.as_deref(), + ) + .await + .map_err(|err| ThreadStoreError::InvalidRequest { + message: format!("failed to locate thread id {thread_id}: {err}"), + })? { Some(path) => Ok(Some(path)), None => find_archived_thread_path_by_id_str( store.config.codex_home.as_path(), &thread_id.to_string(), + state_db_ctx.as_deref(), ) .await .map_err(|err| ThreadStoreError::InvalidRequest { @@ -189,11 +194,15 @@ async fn resolve_rollout_path( }), } } else { - find_thread_path_by_id_str(store.config.codex_home.as_path(), &thread_id.to_string()) - .await - .map_err(|err| ThreadStoreError::InvalidRequest { - message: format!("failed to locate thread id {thread_id}: {err}"), - }) + find_thread_path_by_id_str( + store.config.codex_home.as_path(), + &thread_id.to_string(), + state_db_ctx.as_deref(), + ) + .await + .map_err(|err| ThreadStoreError::InvalidRequest { + message: format!("failed to locate thread id {thread_id}: {err}"), + }) } } @@ -246,12 +255,7 @@ async fn read_sqlite_metadata( store: &LocalThreadStore, thread_id: codex_protocol::ThreadId, ) -> Option { - let runtime = StateRuntime::init( - store.config.sqlite_home.clone(), - store.config.default_model_provider_id.clone(), - ) - .await - .ok()?; + let runtime = store.state_db().await?; runtime.get_thread(thread_id).await.ok().flatten() } @@ -411,7 +415,7 @@ mod tests { #[tokio::test] async fn read_thread_returns_active_rollout_summary() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(205); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let active_path = @@ -439,7 +443,7 @@ mod tests { #[tokio::test] async fn read_thread_returns_rollout_path_summary() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(211); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let active_path = @@ -470,7 +474,6 @@ mod tests { async fn read_thread_by_rollout_path_prefers_sqlite_git_info() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(223); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let active_path = @@ -481,6 +484,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); let mut builder = ThreadMetadataBuilder::new( thread_id, active_path.clone(), @@ -518,7 +522,7 @@ mod tests { #[tokio::test] async fn read_thread_returns_archived_rollout_when_requested() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(207); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let archived_path = write_archived_session_file(home.path(), "2025-01-03T12-00-00", uuid) @@ -559,7 +563,7 @@ mod tests { #[tokio::test] async fn read_thread_prefers_active_rollout_over_archived() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(208); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let active_path = @@ -584,7 +588,7 @@ mod tests { #[tokio::test] async fn read_thread_returns_forked_from_id() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(209); let parent_uuid = Uuid::from_u128(210); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); @@ -617,7 +621,6 @@ mod tests { async fn read_thread_applies_sqlite_thread_name() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(212); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let rollout_path = @@ -628,6 +631,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); let mut builder = ThreadMetadataBuilder::new(thread_id, rollout_path, Utc::now(), SessionSource::Cli); builder.model_provider = Some(config.default_model_provider_id.clone()); @@ -657,7 +661,13 @@ mod tests { async fn read_thread_preserves_rollout_cwd_when_sqlite_metadata_exists() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); + let runtime = codex_state::StateRuntime::init( + config.sqlite_home.clone(), + config.default_model_provider_id.clone(), + ) + .await + .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); let uuid = Uuid::from_u128(224); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let day_dir = home.path().join("sessions/2025/01/03"); @@ -690,12 +700,6 @@ mod tests { }); writeln!(file, "{user_event}").expect("write user event"); - let runtime = codex_state::StateRuntime::init( - config.sqlite_home.clone(), - config.default_model_provider_id.clone(), - ) - .await - .expect("state db should initialize"); let mut builder = ThreadMetadataBuilder::new( thread_id, rollout_path.clone(), @@ -732,7 +736,7 @@ mod tests { #[tokio::test] async fn read_thread_uses_legacy_thread_name_when_sqlite_title_is_missing() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(213); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); write_session_file(home.path(), "2025-01-03T12-00-00", uuid).expect("session file"); @@ -756,7 +760,6 @@ mod tests { async fn read_thread_uses_sqlite_metadata_for_rollout_without_user_preview() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(217); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let day_dir = home.path().join("sessions/2025/01/03"); @@ -784,6 +787,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); let mut builder = ThreadMetadataBuilder::new( thread_id, rollout_path.clone(), @@ -826,7 +830,6 @@ mod tests { let home = TempDir::new().expect("temp dir"); let external = TempDir::new().expect("external temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(220); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let rollout_path = @@ -838,6 +841,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); let mut builder = ThreadMetadataBuilder::new( thread_id, stale_path.clone(), @@ -875,7 +879,6 @@ mod tests { let home = TempDir::new().expect("temp dir"); let external = TempDir::new().expect("external temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(221); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let rollout_path = @@ -889,6 +892,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); let mut builder = ThreadMetadataBuilder::new(thread_id, stale_path, Utc::now(), SessionSource::Cli); builder.model_provider = Some("wrong-sqlite-provider".to_string()); @@ -920,7 +924,7 @@ mod tests { #[tokio::test] async fn read_thread_uses_session_meta_for_rollout_without_user_preview_or_sqlite_metadata() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(218); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let day_dir = home.path().join("sessions/2025/01/03"); @@ -975,7 +979,6 @@ mod tests { let home = TempDir::new().expect("temp dir"); let external = TempDir::new().expect("external temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(214); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let rollout_path = external @@ -987,6 +990,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); let mut builder = ThreadMetadataBuilder::new( thread_id, rollout_path.clone(), @@ -1033,7 +1037,6 @@ mod tests { let home = TempDir::new().expect("temp dir"); let external = TempDir::new().expect("external temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(216); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let rollout_path = external @@ -1047,6 +1050,7 @@ mod tests { .expect("state db should initialize"); let mut builder = ThreadMetadataBuilder::new(thread_id, rollout_path, Utc::now(), SessionSource::Cli); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); builder.archived_at = Some(Utc::now()); let mut metadata = builder.build(config.default_model_provider_id.as_str()); metadata.first_user_message = Some("Archived SQLite preview".to_string()); @@ -1089,7 +1093,6 @@ mod tests { async fn read_thread_sqlite_fallback_loads_archived_history() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(219); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let archived_path = write_archived_session_file(home.path(), "2025-01-03T12-00-00", uuid) @@ -1100,6 +1103,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); let mut builder = ThreadMetadataBuilder::new( thread_id, archived_path.clone(), @@ -1135,7 +1139,7 @@ mod tests { #[tokio::test] async fn read_thread_fails_without_rollout() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(206); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); diff --git a/codex-rs/thread-store/src/local/unarchive_thread.rs b/codex-rs/thread-store/src/local/unarchive_thread.rs index 8a3ab2960a..ad41db69ac 100644 --- a/codex-rs/thread-store/src/local/unarchive_thread.rs +++ b/codex-rs/thread-store/src/local/unarchive_thread.rs @@ -17,9 +17,11 @@ pub(super) async fn unarchive_thread( params: ArchiveThreadParams, ) -> ThreadStoreResult { let thread_id = params.thread_id; + let state_db_ctx = store.state_db().await; let archived_path = find_archived_thread_path_by_id_str( store.config.codex_home.as_path(), &thread_id.to_string(), + state_db_ctx.as_deref(), ) .await .map_err(|err| ThreadStoreError::InvalidRequest { @@ -71,7 +73,7 @@ pub(super) async fn unarchive_thread( message: format!("failed to update unarchived thread timestamp: {err}"), })?; - if let Some(ctx) = store.state_db().await { + if let Some(ctx) = state_db_ctx { let _ = ctx .mark_unarchived(thread_id, restored_path.as_path()) .await; @@ -116,7 +118,7 @@ mod tests { #[tokio::test] async fn unarchive_thread_restores_rollout_and_returns_updated_thread() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(203); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let archived_path = write_archived_session_file(home.path(), "2025-01-03T13-00-00", uuid) @@ -147,7 +149,6 @@ mod tests { async fn unarchive_thread_updates_sqlite_metadata_when_present() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(204); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let archived_path = write_archived_session_file(home.path(), "2025-01-03T13-00-00", uuid) @@ -158,6 +159,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); runtime .mark_backfill_complete(/*last_watermark*/ None) .await diff --git a/codex-rs/thread-store/src/local/update_thread_metadata.rs b/codex-rs/thread-store/src/local/update_thread_metadata.rs index fba0172525..677aa3fdce 100644 --- a/codex-rs/thread-store/src/local/update_thread_metadata.rs +++ b/codex-rs/thread-store/src/local/update_thread_metadata.rs @@ -157,12 +157,16 @@ async fn resolve_rollout_path( return Ok(ResolvedRolloutPath { path, archived }); } - let active_path = - find_thread_path_by_id_str(store.config.codex_home.as_path(), &thread_id.to_string()) - .await - .map_err(|err| ThreadStoreError::InvalidRequest { - message: format!("failed to locate thread id {thread_id}: {err}"), - })?; + let state_db_ctx = store.state_db().await; + let active_path = find_thread_path_by_id_str( + store.config.codex_home.as_path(), + &thread_id.to_string(), + state_db_ctx.as_deref(), + ) + .await + .map_err(|err| ThreadStoreError::InvalidRequest { + message: format!("failed to locate thread id {thread_id}: {err}"), + })?; if let Some(path) = active_path { return Ok(ResolvedRolloutPath { path, @@ -174,18 +178,22 @@ async fn resolve_rollout_path( message: format!("thread not found: {thread_id}"), }); } - find_archived_thread_path_by_id_str(store.config.codex_home.as_path(), &thread_id.to_string()) - .await - .map_err(|err| ThreadStoreError::InvalidRequest { - message: format!("failed to locate archived thread id {thread_id}: {err}"), - })? - .map(|path| ResolvedRolloutPath { - path, - archived: true, - }) - .ok_or_else(|| ThreadStoreError::InvalidRequest { - message: format!("thread not found: {thread_id}"), - }) + find_archived_thread_path_by_id_str( + store.config.codex_home.as_path(), + &thread_id.to_string(), + state_db_ctx.as_deref(), + ) + .await + .map_err(|err| ThreadStoreError::InvalidRequest { + message: format!("failed to locate archived thread id {thread_id}: {err}"), + })? + .map(|path| ResolvedRolloutPath { + path, + archived: true, + }) + .ok_or_else(|| ThreadStoreError::InvalidRequest { + message: format!("thread not found: {thread_id}"), + }) } fn rollout_path_is_archived(store: &LocalThreadStore, path: &std::path::Path) -> bool { @@ -213,7 +221,7 @@ mod tests { #[tokio::test] async fn update_thread_metadata_sets_name_on_active_rollout_and_indexes_name() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(301); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let path = @@ -248,7 +256,6 @@ mod tests { async fn update_thread_metadata_sets_memory_mode_on_active_rollout() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(302); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let path = @@ -259,6 +266,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); let thread = store .update_thread_metadata(UpdateThreadMetadataParams { @@ -288,7 +296,7 @@ mod tests { async fn update_thread_metadata_uses_live_rollout_path_for_external_resume() { let home = TempDir::new().expect("temp dir"); let external_home = TempDir::new().expect("external temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(307); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let path = write_session_file(external_home.path(), "2025-01-03T14-45-00", uuid) @@ -328,7 +336,7 @@ mod tests { #[tokio::test] async fn update_thread_metadata_rejects_mismatched_session_meta_id() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let filename_uuid = Uuid::from_u128(303); let metadata_uuid = Uuid::from_u128(304); let thread_id = ThreadId::from_string(&filename_uuid.to_string()).expect("valid thread id"); @@ -360,7 +368,7 @@ mod tests { #[tokio::test] async fn update_thread_metadata_rejects_multi_field_patch_without_partial_write() { let home = TempDir::new().expect("temp dir"); - let store = LocalThreadStore::new(test_config(home.path())); + let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None); let uuid = Uuid::from_u128(305); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let path = @@ -395,7 +403,6 @@ mod tests { async fn update_thread_metadata_keeps_archived_thread_archived_in_sqlite() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(306); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let archived_path = write_archived_session_file(home.path(), "2025-01-03T16-00-00", uuid) @@ -406,6 +413,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); runtime .mark_backfill_complete(/*last_watermark*/ None) .await @@ -458,7 +466,6 @@ mod tests { async fn update_thread_metadata_keeps_live_archived_thread_archived_in_sqlite() { let home = TempDir::new().expect("temp dir"); let config = test_config(home.path()); - let store = LocalThreadStore::new(config.clone()); let uuid = Uuid::from_u128(308); let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); let archived_path = write_archived_session_file(home.path(), "2025-01-03T16-30-00", uuid) @@ -469,6 +476,7 @@ mod tests { ) .await .expect("state db should initialize"); + let store = LocalThreadStore::new(config.clone(), Some(runtime.clone())); runtime .mark_backfill_complete(/*last_watermark*/ None) .await diff --git a/codex-rs/tui/src/app.rs b/codex-rs/tui/src/app.rs index 13140df3d7..2f3599da64 100644 --- a/codex-rs/tui/src/app.rs +++ b/codex-rs/tui/src/app.rs @@ -144,6 +144,7 @@ use codex_protocol::openai_models::ModelUpgrade; use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig; #[cfg(target_os = "windows")] use codex_protocol::permissions::FileSystemSandboxKind; +use codex_rollout::StateDbHandle; use codex_terminal_detection::user_agent; use codex_utils_absolute_path::AbsolutePathBuf; use color_eyre::eyre::Result; @@ -433,6 +434,7 @@ pub(crate) struct App { pub(crate) chat_widget: ChatWidget, /// Config is stored here so we can recreate ChatWidgets as needed. pub(crate) config: Config, + pub(crate) state_db: Option, pub(crate) active_profile: Option, cli_kv_overrides: Vec<(String, TomlValue)>, harness_overrides: ConfigOverrides, @@ -609,6 +611,7 @@ impl App { should_prompt_windows_sandbox_nux_at_startup: bool, remote_app_server_url: Option, remote_app_server_auth_token: Option, + state_db: Option, environment_manager: Arc, ) -> Result { use tokio_stream::StreamExt; @@ -854,6 +857,7 @@ See the Codex keymap documentation for supported actions and examples." app_event_tx, chat_widget, config, + state_db, active_profile, cli_kv_overrides, harness_overrides, diff --git a/codex-rs/tui/src/app/event_dispatch.rs b/codex-rs/tui/src/app/event_dispatch.rs index dfe7ceb4f9..8686377a5d 100644 --- a/codex-rs/tui/src/app/event_dispatch.rs +++ b/codex-rs/tui/src/app/event_dispatch.rs @@ -61,6 +61,7 @@ impl App { }, None => crate::AppServerTarget::Embedded, }, + self.state_db.clone(), self.environment_manager.clone(), ) .await diff --git a/codex-rs/tui/src/app/session_lifecycle.rs b/codex-rs/tui/src/app/session_lifecycle.rs index e83abcd0fb..05aba144c8 100644 --- a/codex-rs/tui/src/app/session_lifecycle.rs +++ b/codex-rs/tui/src/app/session_lifecycle.rs @@ -638,7 +638,7 @@ impl App { } else { match crate::session_resume::resolve_cwd_for_resume_or_fork( tui, - &self.config, + self.state_db.as_deref(), ¤t_cwd, target_session.thread_id, target_session.path.as_deref(), diff --git a/codex-rs/tui/src/app/test_support.rs b/codex-rs/tui/src/app/test_support.rs index eade7bf60e..3d3956c657 100644 --- a/codex-rs/tui/src/app/test_support.rs +++ b/codex-rs/tui/src/app/test_support.rs @@ -20,6 +20,7 @@ pub(super) async fn make_test_app() -> App { app_event_tx, chat_widget, config, + state_db: None, active_profile: None, cli_kv_overrides: Vec::new(), harness_overrides: ConfigOverrides::default(), diff --git a/codex-rs/tui/src/app/tests.rs b/codex-rs/tui/src/app/tests.rs index 799ac69e37..787afc7a15 100644 --- a/codex-rs/tui/src/app/tests.rs +++ b/codex-rs/tui/src/app/tests.rs @@ -3771,6 +3771,7 @@ async fn make_test_app() -> App { app_event_tx, chat_widget, config, + state_db: None, active_profile: None, cli_kv_overrides: Vec::new(), harness_overrides: ConfigOverrides::default(), @@ -3832,6 +3833,7 @@ async fn make_test_app_with_channels() -> ( app_event_tx, chat_widget, config, + state_db: None, active_profile: None, cli_kv_overrides: Vec::new(), harness_overrides: ConfigOverrides::default(), diff --git a/codex-rs/tui/src/app/thread_routing.rs b/codex-rs/tui/src/app/thread_routing.rs index acd79989b6..df6f01e8bd 100644 --- a/codex-rs/tui/src/app/thread_routing.rs +++ b/codex-rs/tui/src/app/thread_routing.rs @@ -916,7 +916,7 @@ impl App { session.cwd = notification.thread.cwd.clone(); let rollout_path = notification.thread.path.clone(); if let Some(model) = - read_session_model(&self.config, thread_id, rollout_path.as_deref()).await + read_session_model(self.state_db.as_deref(), thread_id, rollout_path.as_deref()).await { session.model = model; } else if rollout_path.is_some() { diff --git a/codex-rs/tui/src/app/thread_session_state.rs b/codex-rs/tui/src/app/thread_session_state.rs index 25ee6cd14a..3a898b82a3 100644 --- a/codex-rs/tui/src/app/thread_session_state.rs +++ b/codex-rs/tui/src/app/thread_session_state.rs @@ -87,7 +87,7 @@ impl App { session.instruction_source_paths = Vec::new(); session.rollout_path = thread.path.clone(); if let Some(model) = - read_session_model(&self.config, thread_id, thread.path.as_deref()).await + read_session_model(self.state_db.as_deref(), thread_id, thread.path.as_deref()).await { session.model = model; } else if thread.path.is_some() { diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index 466b7b9e6e..5128d492fb 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -49,7 +49,8 @@ use codex_protocol::ThreadId; use codex_protocol::config_types::AltScreenMode; use codex_protocol::config_types::SandboxMode; use codex_protocol::config_types::WindowsSandboxLevel; -use codex_rollout::state_db::get_state_db; +use codex_rollout::StateDbHandle; +use codex_rollout::state_db; use codex_state::log_db; use codex_terminal_detection::terminal_info; use codex_utils_absolute_path::AbsolutePathBuf; @@ -271,6 +272,7 @@ async fn start_embedded_app_server( cloud_requirements: CloudRequirementsLoader, feedback: codex_feedback::CodexFeedback, log_db: Option, + state_db: Option, environment_manager: Arc, ) -> color_eyre::Result { start_embedded_app_server_with( @@ -281,6 +283,7 @@ async fn start_embedded_app_server( cloud_requirements, feedback, log_db, + state_db, environment_manager, InProcessAppServerClient::start, ) @@ -398,6 +401,7 @@ async fn start_app_server( cloud_requirements: CloudRequirementsLoader, feedback: codex_feedback::CodexFeedback, log_db: Option, + state_db: Option, environment_manager: Arc, ) -> color_eyre::Result { match target { @@ -409,6 +413,7 @@ async fn start_app_server( cloud_requirements, feedback, log_db, + state_db, environment_manager, ) .await @@ -423,6 +428,7 @@ async fn start_app_server( pub(crate) async fn start_app_server_for_picker( config: &Config, target: &AppServerTarget, + state_db: Option, environment_manager: Arc, ) -> color_eyre::Result { let app_server = start_app_server( @@ -434,6 +440,7 @@ pub(crate) async fn start_app_server_for_picker( CloudRequirementsLoader::default(), codex_feedback::CodexFeedback::new(), /*log_db*/ None, + state_db, environment_manager, ) .await?; @@ -444,9 +451,11 @@ pub(crate) async fn start_app_server_for_picker( pub(crate) async fn start_embedded_app_server_for_picker( config: &Config, ) -> color_eyre::Result { + let state_db = state_db::init(config).await; start_app_server_for_picker( config, &AppServerTarget::Embedded, + state_db, Arc::new(EnvironmentManager::default_for_tests()), ) .await @@ -461,6 +470,7 @@ async fn start_embedded_app_server_with( cloud_requirements: CloudRequirementsLoader, feedback: codex_feedback::CodexFeedback, log_db: Option, + state_db: Option, environment_manager: Arc, start_client: F, ) -> color_eyre::Result @@ -486,6 +496,7 @@ where cloud_requirements, feedback, log_db, + state_db, environment_manager, config_warnings, session_source: serde_json::from_value(serde_json::json!("cli")) @@ -787,15 +798,6 @@ pub async fn run_main( } }; - if let Err(err) = crate::legacy_core::personality_migration::maybe_migrate_personality( - &codex_home, - &config_toml, - ) - .await - { - tracing::warn!(error = %err, "failed to run personality migration"); - } - let chatgpt_base_url = config_toml .chatgpt_base_url .clone() @@ -865,13 +867,53 @@ pub async fn run_main( ..Default::default() }; - let config = load_config_or_exit( + let mut config = load_config_or_exit( cli_kv_overrides.clone(), overrides.clone(), cloud_requirements.clone(), ) .await; + let state_db = match &app_server_target { + AppServerTarget::Embedded => state_db::init(&config).await, + AppServerTarget::Remote { .. } => state_db::get_state_db(&config).await, + }; + + let effective_toml = config.config_layer_stack.effective_config(); + match effective_toml.try_into() { + Ok(config_toml) => { + match crate::legacy_core::personality_migration::maybe_migrate_personality( + &config.codex_home, + &config_toml, + state_db.clone(), + ) + .await + { + Ok( + crate::legacy_core::personality_migration::PersonalityMigrationStatus::Applied, + ) => { + config = load_config_or_exit( + cli_kv_overrides.clone(), + overrides.clone(), + cloud_requirements.clone(), + ) + .await; + } + Ok( + crate::legacy_core::personality_migration::PersonalityMigrationStatus::SkippedMarker + | crate::legacy_core::personality_migration::PersonalityMigrationStatus::SkippedExplicitPersonality + | crate::legacy_core::personality_migration::PersonalityMigrationStatus::SkippedNoSessions, + ) => {} + Err(err) => { + tracing::warn!(error = %err, "failed to run personality migration"); + } + } + } + Err(err) => { + tracing::warn!(error = %err, "failed to deserialize config for personality migration"); + } + } + #[allow(clippy::print_stderr)] match check_execpolicy_for_warnings(&config.config_layer_stack).await { Ok(None) => {} @@ -1003,7 +1045,7 @@ pub async fn run_main( let otel_tracing_layer = otel.as_ref().and_then(|o| o.tracing_layer()); - let log_db = get_state_db(&config).await.map(log_db::start); + let log_db = state_db.clone().map(log_db::start); let log_db_layer = log_db .clone() .map(|layer| layer.with_filter(Targets::new().with_default(Level::TRACE))); @@ -1029,6 +1071,7 @@ pub async fn run_main( cloud_requirements, feedback, log_db, + state_db, remote_url, remote_auth_token, environment_manager, @@ -1050,6 +1093,7 @@ async fn run_ratatui_app( mut cloud_requirements: CloudRequirementsLoader, feedback: codex_feedback::CodexFeedback, log_db: Option, + state_db: Option, remote_url: Option, remote_auth_token: Option, environment_manager: Arc, @@ -1109,6 +1153,7 @@ async fn run_ratatui_app( cloud_requirements.clone(), feedback.clone(), log_db.clone(), + state_db.clone(), environment_manager.clone(), ) .await @@ -1360,7 +1405,7 @@ async fn run_ratatui_app( } else { match resolve_cwd_for_resume_or_fork( &mut tui, - &config, + state_db.as_deref(), ¤t_cwd, target_session.thread_id, target_session.path.as_deref(), @@ -1438,6 +1483,7 @@ async fn run_ratatui_app( cloud_requirements.clone(), feedback.clone(), log_db.clone(), + state_db.clone(), environment_manager.clone(), ) .await @@ -1468,6 +1514,7 @@ async fn run_ratatui_app( should_prompt_windows_sandbox_nux_at_startup, remote_url, remote_auth_token, + state_db, environment_manager, ) .await; @@ -1672,6 +1719,7 @@ mod tests { async fn start_test_embedded_app_server( config: Config, ) -> color_eyre::Result { + let state_db = state_db::init(&config).await; start_embedded_app_server( Arg0DispatchPaths::default(), config, @@ -1680,6 +1728,7 @@ mod tests { CloudRequirementsLoader::default(), codex_feedback::CodexFeedback::new(), /*log_db*/ None, + state_db, Arc::new(EnvironmentManager::default_for_tests()), ) .await @@ -2026,6 +2075,7 @@ mod tests { CloudRequirementsLoader::default(), codex_feedback::CodexFeedback::new(), /*log_db*/ None, + /*state_db*/ None, Arc::new(EnvironmentManager::default_for_tests()), |_args| async { Err(std::io::Error::other("boom")) }, ) diff --git a/codex-rs/tui/src/onboarding/auth.rs b/codex-rs/tui/src/onboarding/auth.rs index 9ceef56dbc..6f226b5cfe 100644 --- a/codex-rs/tui/src/onboarding/auth.rs +++ b/codex-rs/tui/src/onboarding/auth.rs @@ -1048,6 +1048,7 @@ mod tests { .await, feedback: codex_feedback::CodexFeedback::new(), log_db: None, + state_db: None, environment_manager: Arc::new( codex_app_server_client::EnvironmentManager::default_for_tests(), ), diff --git a/codex-rs/tui/src/session_resume.rs b/codex-rs/tui/src/session_resume.rs index 169a096d1e..9b47599edc 100644 --- a/codex-rs/tui/src/session_resume.rs +++ b/codex-rs/tui/src/session_resume.rs @@ -12,10 +12,9 @@ use crate::cwd_prompt; use crate::cwd_prompt::CwdPromptAction; use crate::cwd_prompt::CwdPromptOutcome; use crate::cwd_prompt::CwdSelection; -use crate::legacy_core::config::Config; use crate::tui::Tui; use codex_protocol::ThreadId; -use codex_rollout::state_db::get_state_db; +use codex_state::StateRuntime; use codex_utils_path as path_utils; use serde::Deserialize; use serde_json::Value; @@ -66,11 +65,11 @@ pub(crate) async fn resolve_session_thread_id( } pub(crate) async fn read_session_model( - config: &Config, + state_db_ctx: Option<&StateRuntime>, thread_id: ThreadId, path: Option<&Path>, ) -> Option { - if let Some(state_db_ctx) = get_state_db(config).await + if let Some(state_db_ctx) = state_db_ctx && let Ok(Some(metadata)) = state_db_ctx.get_thread(thread_id).await && let Some(model) = metadata.model { @@ -86,14 +85,14 @@ pub(crate) async fn read_session_model( pub(crate) async fn resolve_cwd_for_resume_or_fork( tui: &mut Tui, - config: &Config, + state_db_ctx: Option<&StateRuntime>, current_cwd: &Path, thread_id: ThreadId, path: Option<&Path>, action: CwdPromptAction, allow_prompt: bool, ) -> color_eyre::Result { - let Some(history_cwd) = read_session_cwd(config, thread_id, path).await else { + let Some(history_cwd) = read_session_cwd(state_db_ctx, thread_id, path).await else { return Ok(ResolveCwdOutcome::Continue(None)); }; if allow_prompt && cwds_differ(current_cwd, &history_cwd) { @@ -113,11 +112,11 @@ pub(crate) async fn resolve_cwd_for_resume_or_fork( } async fn read_session_cwd( - config: &Config, + state_db_ctx: Option<&StateRuntime>, thread_id: ThreadId, path: Option<&Path>, ) -> Option { - if let Some(state_db_ctx) = get_state_db(config).await + if let Some(state_db_ctx) = state_db_ctx && let Ok(Some(metadata)) = state_db_ctx.get_thread(thread_id).await { return Some(metadata.cwd);