chore: improve DB flushing (#13620)

This branch:
* Avoids flushing the DB when it is not necessary
* Filters the events for which we perform an `upsert` into the DB
* Adds a dedicated, lighter-weight update function for
`thread:updated_at`

This should significantly reduce DB lock contention. If that is not
sufficient, we can de-sync the DB flush for `updated_at`.
This commit is contained in:
jif-oai
2026-03-06 18:58:14 +00:00
committed by GitHub
parent 4e6c6193a1
commit 0e41a5c4a8
6 changed files with 366 additions and 28 deletions

View File

@@ -1694,7 +1694,7 @@ impl Session {
self.services.state_db.clone()
}
/// Ensure all rollout writes are durably flushed.
/// Ensure rollout file writes are durably flushed.
pub(crate) async fn flush_rollout(&self) {
let recorder = {
let guard = self.services.rollout.lock().await;

View File

@@ -7,6 +7,7 @@ use std::path::Path;
use std::path::PathBuf;
use chrono::SecondsFormat;
use chrono::Utc;
use codex_protocol::ThreadId;
use codex_protocol::dynamic_tools::DynamicToolSpec;
use codex_protocol::models::BaseInstructions;
@@ -448,7 +449,6 @@ impl RolloutRecorder {
// future will yield, which is fine we only need to ensure we do not
// perform *blocking* I/O on the caller's thread.
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
// Spawn a Tokio task that owns the file handle and performs async
// writes. Using `tokio::fs::File` keeps everything on the async I/O
// driver instead of blocking the runtime.
@@ -614,14 +614,15 @@ impl RolloutRecorder {
match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
Ok(_) => rx_done
.await
.map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
.map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}")))?,
Err(e) => {
warn!("failed to send rollout shutdown command: {e}");
Err(IoError::other(format!(
return Err(IoError::other(format!(
"failed to send rollout shutdown command: {e}"
)))
)));
}
}
};
Ok(())
}
}
@@ -744,25 +745,21 @@ async fn rollout_writer(
while let Some(cmd) = rx.recv().await {
match cmd {
RolloutCmd::AddItems(items) => {
let mut persisted_items = Vec::new();
for item in items {
persisted_items.push(item);
}
if persisted_items.is_empty() {
if items.is_empty() {
continue;
}
if writer.is_none() {
buffered_items.extend(persisted_items);
buffered_items.extend(items);
continue;
}
write_and_reconcile_items(
writer.as_mut(),
persisted_items.as_slice(),
items.as_slice(),
&rollout_path,
state_db_ctx.as_deref(),
&mut state_builder,
state_builder.as_ref(),
default_provider.as_str(),
)
.await?;
@@ -800,7 +797,7 @@ async fn rollout_writer(
buffered_items.as_slice(),
&rollout_path,
state_db_ctx.as_deref(),
&mut state_builder,
state_builder.as_ref(),
default_provider.as_str(),
)
.await?;
@@ -861,13 +858,12 @@ async fn write_session_meta(
if let Some(writer) = writer.as_mut() {
writer.write_rollout_item(&rollout_item).await?;
}
state_db::reconcile_rollout(
sync_thread_state_after_write(
state_db_ctx,
rollout_path,
default_provider,
state_builder.as_ref(),
std::slice::from_ref(&rollout_item),
None,
default_provider,
(!generate_memories).then_some("disabled"),
)
.await;
@@ -879,7 +875,7 @@ async fn write_and_reconcile_items(
items: &[RolloutItem],
rollout_path: &Path,
state_db_ctx: Option<&StateRuntime>,
state_builder: &mut Option<ThreadMetadataBuilder>,
state_builder: Option<&ThreadMetadataBuilder>,
default_provider: &str,
) -> std::io::Result<()> {
if let Some(writer) = writer.as_mut() {
@@ -887,20 +883,65 @@ async fn write_and_reconcile_items(
writer.write_rollout_item(item).await?;
}
}
if let Some(builder) = state_builder.as_mut() {
builder.rollout_path = rollout_path.to_path_buf();
sync_thread_state_after_write(
state_db_ctx,
rollout_path,
state_builder,
items,
default_provider,
None,
)
.await;
Ok(())
}
/// Bring the state DB in line with rollout items that were just written.
///
/// A single `Utc::now()` timestamp is captured up front and used for every
/// state DB call made by this function.
///
/// * When `new_thread_memory_mode` is set, or at least one item satisfies
///   `codex_state::rollout_item_affects_thread_metadata`, the items are
///   forwarded to `state_db::apply_rollout_items` and the function returns.
/// * Otherwise only `state_db::touch_thread_updated_at` is attempted; if it
///   reports `false`, the function falls back to
///   `state_db::apply_rollout_items`.
///
/// Nothing is returned: the awaited state DB helpers are fire-and-forget from
/// this function's point of view.
async fn sync_thread_state_after_write(
state_db_ctx: Option<&StateRuntime>,
rollout_path: &Path,
state_builder: Option<&ThreadMetadataBuilder>,
items: &[RolloutItem],
default_provider: &str,
new_thread_memory_mode: Option<&str>,
) {
// One timestamp shared by whichever path runs below.
let updated_at = Utc::now();
// Metadata-relevant writes (or an explicit memory-mode change) take the
// full apply path.
if new_thread_memory_mode.is_some()
|| items
.iter()
.any(codex_state::rollout_item_affects_thread_metadata)
{
state_db::apply_rollout_items(
state_db_ctx,
rollout_path,
default_provider,
state_builder,
items,
"rollout_writer",
new_thread_memory_mode,
Some(updated_at),
)
.await;
return;
}
// Prefer the id carried by the supplied builder; otherwise try to derive a
// builder (and thus an id) from the items themselves.
let thread_id = state_builder
.map(|builder| builder.id)
.or_else(|| metadata::builder_from_items(items, rollout_path).map(|builder| builder.id));
// Lightweight path: only bump `updated_at`. A `true` result means there is
// nothing more to do.
if state_db::touch_thread_updated_at(state_db_ctx, thread_id, updated_at, "rollout_writer")
.await
{
return;
}
// The touch did not succeed, so fall back to the full apply path.
state_db::apply_rollout_items(
state_db_ctx,
rollout_path,
default_provider,
state_builder.as_ref(),
state_builder,
items,
"rollout_writer",
None,
new_thread_memory_mode,
Some(updated_at),
)
.await;
Ok(())
}
struct JsonlWriter {
@@ -1079,6 +1120,7 @@ mod tests {
use std::io::Write;
use std::path::Path;
use std::path::PathBuf;
use std::time::Duration;
use tempfile::TempDir;
use uuid::Uuid;
@@ -1200,6 +1242,151 @@ mod tests {
Ok(())
}
/// Recording an event that does not affect thread metadata (an agent message)
/// must still advance the thread's `updated_at` in the state DB, while leaving
/// `title` and `first_user_message` exactly as they were.
#[tokio::test]
async fn metadata_irrelevant_events_touch_state_db_updated_at() -> std::io::Result<()> {
    // Isolated codex home with the SQLite feature switched on.
    let codex_home = TempDir::new().expect("temp dir");
    let mut config = ConfigBuilder::default()
        .codex_home(codex_home.path().to_path_buf())
        .build()
        .await?;
    config
        .features
        .enable(Feature::Sqlite)
        .expect("test config should allow sqlite");

    let runtime = codex_state::StateRuntime::init(
        codex_home.path().to_path_buf(),
        config.model_provider_id.clone(),
        None,
    )
    .await
    .expect("state db should initialize");
    runtime
        .mark_backfill_complete(None)
        .await
        .expect("backfill should be complete");

    let thread_id = ThreadId::new();
    let params = RolloutRecorderParams::new(
        thread_id,
        None,
        SessionSource::Cli,
        BaseInstructions::default(),
        Vec::new(),
        EventPersistenceMode::Limited,
    );
    let recorder = RolloutRecorder::new(&config, params, Some(runtime.clone()), None).await?;

    // Seed the thread with a user message, then persist and flush it.
    let user_message = UserMessageEvent {
        message: "first-user-message".to_string(),
        images: None,
        local_images: Vec::new(),
        text_elements: Vec::new(),
    };
    let seed_items = [RolloutItem::EventMsg(EventMsg::UserMessage(user_message))];
    recorder.record_items(&seed_items).await?;
    recorder.persist().await?;
    recorder.flush().await?;

    let before = runtime
        .get_thread(thread_id)
        .await
        .expect("thread should load")
        .expect("thread should exist");

    // Give the clock time to advance so the strict `>` comparison below is
    // meaningful.
    tokio::time::sleep(Duration::from_secs(1)).await;

    // Now record the agent message and flush again.
    let agent_message = AgentMessageEvent {
        message: "assistant text".to_string(),
        phase: None,
    };
    let follow_up_items = [RolloutItem::EventMsg(EventMsg::AgentMessage(agent_message))];
    recorder.record_items(&follow_up_items).await?;
    recorder.flush().await?;

    let after = runtime
        .get_thread(thread_id)
        .await
        .expect("thread should load after agent message")
        .expect("thread should still exist");

    assert!(after.updated_at > before.updated_at);
    assert_eq!(after.title, before.title);
    assert_eq!(after.first_user_message, before.first_user_message);

    recorder.shutdown().await?;
    Ok(())
}
/// When no thread row exists yet, syncing an event that does not affect
/// thread metadata must still leave a row for that thread in the state DB,
/// rather than silently doing nothing.
#[tokio::test]
async fn metadata_irrelevant_events_fall_back_to_upsert_when_thread_missing()
-> std::io::Result<()> {
    // Isolated codex home with the SQLite feature switched on.
    let codex_home = TempDir::new().expect("temp dir");
    let mut config = ConfigBuilder::default()
        .codex_home(codex_home.path().to_path_buf())
        .build()
        .await?;
    config
        .features
        .enable(Feature::Sqlite)
        .expect("test config should allow sqlite");

    let runtime = codex_state::StateRuntime::init(
        codex_home.path().to_path_buf(),
        config.model_provider_id.clone(),
        None,
    )
    .await
    .expect("state db should initialize");

    // Build metadata for a thread that has never been written to the DB.
    let thread_id = ThreadId::new();
    let rollout_path = codex_home.path().join("rollout.jsonl");
    let builder = ThreadMetadataBuilder::new(
        thread_id,
        rollout_path.clone(),
        Utc::now(),
        SessionSource::Cli,
    );

    let agent_message = AgentMessageEvent {
        message: "assistant text".to_string(),
        phase: None,
    };
    let items = vec![RolloutItem::EventMsg(EventMsg::AgentMessage(agent_message))];

    // Call the sync helper directly, with no memory-mode override.
    sync_thread_state_after_write(
        Some(runtime.as_ref()),
        rollout_path.as_path(),
        Some(&builder),
        items.as_slice(),
        config.model_provider_id.as_str(),
        None,
    )
    .await;

    let inserted = runtime
        .get_thread(thread_id)
        .await
        .expect("thread should load after fallback")
        .expect("thread should be inserted after fallback");
    assert_eq!(inserted.id, thread_id);

    Ok(())
}
#[tokio::test]
async fn list_threads_db_disabled_does_not_skip_paginated_items() -> std::io::Result<()> {
let home = TempDir::new().expect("temp dir");

View File

@@ -361,6 +361,7 @@ pub async fn reconcile_rollout(
items,
"reconcile_rollout",
new_thread_memory_mode,
None,
)
.await;
return;
@@ -491,6 +492,7 @@ pub async fn read_repair_rollout_path(
}
/// Apply rollout items incrementally to SQLite.
#[allow(clippy::too_many_arguments)]
pub async fn apply_rollout_items(
context: Option<&codex_state::StateRuntime>,
rollout_path: &Path,
@@ -499,6 +501,7 @@ pub async fn apply_rollout_items(
items: &[RolloutItem],
stage: &str,
new_thread_memory_mode: Option<&str>,
updated_at_override: Option<DateTime<Utc>>,
) {
let Some(ctx) = context else {
return;
@@ -520,7 +523,13 @@ pub async fn apply_rollout_items(
builder.rollout_path = rollout_path.to_path_buf();
builder.cwd = normalize_cwd_for_state_db(&builder.cwd);
if let Err(err) = ctx
.apply_rollout_items(&builder, items, None, new_thread_memory_mode)
.apply_rollout_items(
&builder,
items,
None,
new_thread_memory_mode,
updated_at_override,
)
.await
{
warn!(
@@ -530,6 +539,26 @@ pub async fn apply_rollout_items(
}
}
/// Update only the `updated_at` timestamp of a thread in the state DB.
///
/// Returns `false` without touching anything when either `context` or
/// `thread_id` is `None`. Otherwise the call is delegated to the runtime's
/// `touch_thread_updated_at` and its boolean result is returned as-is. If the
/// runtime reports an error, a warning naming `stage` and the thread id is
/// logged and `false` is returned.
pub async fn touch_thread_updated_at(
    context: Option<&codex_state::StateRuntime>,
    thread_id: Option<ThreadId>,
    updated_at: DateTime<Utc>,
    stage: &str,
) -> bool {
    // Both a runtime and a thread id are required.
    let (Some(ctx), Some(thread_id)) = (context, thread_id) else {
        return false;
    };

    match ctx.touch_thread_updated_at(thread_id, updated_at).await {
        Ok(touched) => touched,
        Err(err) => {
            // Errors are logged and reported to the caller as `false`.
            warn!("state db touch_thread_updated_at failed during {stage} for {thread_id}: {err}");
            false
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;