mirror of
https://github.com/openai/codex.git
synced 2026-06-01 19:02:59 +00:00
chore: improve DB flushing (#13620)
This branch: * Avoids flushing the DB when not necessary * Filters the events for which we perform an `upsert` into the DB * Adds a dedicated, lighter update function for `thread:updated_at`. This should significantly reduce DB lock contention. If it is not sufficient, we can de-sync the DB flush for `updated_at`.
This commit is contained in:
@@ -1694,7 +1694,7 @@ impl Session {
|
||||
self.services.state_db.clone()
|
||||
}
|
||||
|
||||
/// Ensure all rollout writes are durably flushed.
|
||||
/// Ensure rollout file writes are durably flushed.
|
||||
pub(crate) async fn flush_rollout(&self) {
|
||||
let recorder = {
|
||||
let guard = self.services.rollout.lock().await;
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use chrono::SecondsFormat;
|
||||
use chrono::Utc;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::dynamic_tools::DynamicToolSpec;
|
||||
use codex_protocol::models::BaseInstructions;
|
||||
@@ -448,7 +449,6 @@ impl RolloutRecorder {
|
||||
// future will yield, which is fine – we only need to ensure we do not
|
||||
// perform *blocking* I/O on the caller's thread.
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
|
||||
// Spawn a Tokio task that owns the file handle and performs async
|
||||
// writes. Using `tokio::fs::File` keeps everything on the async I/O
|
||||
// driver instead of blocking the runtime.
|
||||
@@ -614,14 +614,15 @@ impl RolloutRecorder {
|
||||
match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
|
||||
Ok(_) => rx_done
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
|
||||
.map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}")))?,
|
||||
Err(e) => {
|
||||
warn!("failed to send rollout shutdown command: {e}");
|
||||
Err(IoError::other(format!(
|
||||
return Err(IoError::other(format!(
|
||||
"failed to send rollout shutdown command: {e}"
|
||||
)))
|
||||
)));
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -744,25 +745,21 @@ async fn rollout_writer(
|
||||
while let Some(cmd) = rx.recv().await {
|
||||
match cmd {
|
||||
RolloutCmd::AddItems(items) => {
|
||||
let mut persisted_items = Vec::new();
|
||||
for item in items {
|
||||
persisted_items.push(item);
|
||||
}
|
||||
if persisted_items.is_empty() {
|
||||
if items.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if writer.is_none() {
|
||||
buffered_items.extend(persisted_items);
|
||||
buffered_items.extend(items);
|
||||
continue;
|
||||
}
|
||||
|
||||
write_and_reconcile_items(
|
||||
writer.as_mut(),
|
||||
persisted_items.as_slice(),
|
||||
items.as_slice(),
|
||||
&rollout_path,
|
||||
state_db_ctx.as_deref(),
|
||||
&mut state_builder,
|
||||
state_builder.as_ref(),
|
||||
default_provider.as_str(),
|
||||
)
|
||||
.await?;
|
||||
@@ -800,7 +797,7 @@ async fn rollout_writer(
|
||||
buffered_items.as_slice(),
|
||||
&rollout_path,
|
||||
state_db_ctx.as_deref(),
|
||||
&mut state_builder,
|
||||
state_builder.as_ref(),
|
||||
default_provider.as_str(),
|
||||
)
|
||||
.await?;
|
||||
@@ -861,13 +858,12 @@ async fn write_session_meta(
|
||||
if let Some(writer) = writer.as_mut() {
|
||||
writer.write_rollout_item(&rollout_item).await?;
|
||||
}
|
||||
state_db::reconcile_rollout(
|
||||
sync_thread_state_after_write(
|
||||
state_db_ctx,
|
||||
rollout_path,
|
||||
default_provider,
|
||||
state_builder.as_ref(),
|
||||
std::slice::from_ref(&rollout_item),
|
||||
None,
|
||||
default_provider,
|
||||
(!generate_memories).then_some("disabled"),
|
||||
)
|
||||
.await;
|
||||
@@ -879,7 +875,7 @@ async fn write_and_reconcile_items(
|
||||
items: &[RolloutItem],
|
||||
rollout_path: &Path,
|
||||
state_db_ctx: Option<&StateRuntime>,
|
||||
state_builder: &mut Option<ThreadMetadataBuilder>,
|
||||
state_builder: Option<&ThreadMetadataBuilder>,
|
||||
default_provider: &str,
|
||||
) -> std::io::Result<()> {
|
||||
if let Some(writer) = writer.as_mut() {
|
||||
@@ -887,20 +883,65 @@ async fn write_and_reconcile_items(
|
||||
writer.write_rollout_item(item).await?;
|
||||
}
|
||||
}
|
||||
if let Some(builder) = state_builder.as_mut() {
|
||||
builder.rollout_path = rollout_path.to_path_buf();
|
||||
sync_thread_state_after_write(
|
||||
state_db_ctx,
|
||||
rollout_path,
|
||||
state_builder,
|
||||
items,
|
||||
default_provider,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn sync_thread_state_after_write(
|
||||
state_db_ctx: Option<&StateRuntime>,
|
||||
rollout_path: &Path,
|
||||
state_builder: Option<&ThreadMetadataBuilder>,
|
||||
items: &[RolloutItem],
|
||||
default_provider: &str,
|
||||
new_thread_memory_mode: Option<&str>,
|
||||
) {
|
||||
let updated_at = Utc::now();
|
||||
if new_thread_memory_mode.is_some()
|
||||
|| items
|
||||
.iter()
|
||||
.any(codex_state::rollout_item_affects_thread_metadata)
|
||||
{
|
||||
state_db::apply_rollout_items(
|
||||
state_db_ctx,
|
||||
rollout_path,
|
||||
default_provider,
|
||||
state_builder,
|
||||
items,
|
||||
"rollout_writer",
|
||||
new_thread_memory_mode,
|
||||
Some(updated_at),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let thread_id = state_builder
|
||||
.map(|builder| builder.id)
|
||||
.or_else(|| metadata::builder_from_items(items, rollout_path).map(|builder| builder.id));
|
||||
if state_db::touch_thread_updated_at(state_db_ctx, thread_id, updated_at, "rollout_writer")
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
state_db::apply_rollout_items(
|
||||
state_db_ctx,
|
||||
rollout_path,
|
||||
default_provider,
|
||||
state_builder.as_ref(),
|
||||
state_builder,
|
||||
items,
|
||||
"rollout_writer",
|
||||
None,
|
||||
new_thread_memory_mode,
|
||||
Some(updated_at),
|
||||
)
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct JsonlWriter {
|
||||
@@ -1079,6 +1120,7 @@ mod tests {
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
use tempfile::TempDir;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -1200,6 +1242,151 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks that recording an `AgentMessage` event after the thread row already
// exists moves the thread's `updated_at` forward in the state DB while leaving
// `title` and `first_user_message` unchanged.
#[tokio::test]
|
||||
async fn metadata_irrelevant_events_touch_state_db_updated_at() -> std::io::Result<()> {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let mut config = ConfigBuilder::default()
|
||||
.codex_home(home.path().to_path_buf())
|
||||
.build()
|
||||
.await?;
|
||||
config
|
||||
.features
|
||||
.enable(Feature::Sqlite)
|
||||
.expect("test config should allow sqlite");
|
||||
|
||||
let state_db = codex_state::StateRuntime::init(
|
||||
home.path().to_path_buf(),
|
||||
config.model_provider_id.clone(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
state_db
|
||||
.mark_backfill_complete(None)
|
||||
.await
|
||||
.expect("backfill should be complete");
|
||||
|
||||
let thread_id = ThreadId::new();
|
||||
let recorder = RolloutRecorder::new(
|
||||
&config,
|
||||
RolloutRecorderParams::new(
|
||||
thread_id,
|
||||
None,
|
||||
SessionSource::Cli,
|
||||
BaseInstructions::default(),
|
||||
Vec::new(),
|
||||
EventPersistenceMode::Limited,
|
||||
),
|
||||
Some(state_db.clone()),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Step 1: record a user message, then persist and flush so the thread row
|
||||
// can be read back and used as the baseline below.
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(
|
||||
UserMessageEvent {
|
||||
message: "first-user-message".to_string(),
|
||||
images: None,
|
||||
local_images: Vec::new(),
|
||||
text_elements: Vec::new(),
|
||||
},
|
||||
))])
|
||||
.await?;
|
||||
recorder.persist().await?;
|
||||
recorder.flush().await?;
|
||||
let initial_thread = state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("thread should load")
|
||||
.expect("thread should exist");
|
||||
let initial_updated_at = initial_thread.updated_at;
|
||||
let initial_title = initial_thread.title.clone();
|
||||
let initial_first_user_message = initial_thread.first_user_message.clone();
|
||||
|
||||
// NOTE(review): the one-second pause presumably exists because `updated_at`
|
||||
// is stored with second granularity, so the strict `>` assertion below
|
||||
// needs a visible gap; confirm against the state DB schema.
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
|
||||
// Step 2: record an agent message only. The test name treats this event as
|
||||
// irrelevant to thread metadata.
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
AgentMessageEvent {
|
||||
message: "assistant text".to_string(),
|
||||
phase: None,
|
||||
},
|
||||
))])
|
||||
.await?;
|
||||
recorder.flush().await?;
|
||||
|
||||
let updated_thread = state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("thread should load after agent message")
|
||||
.expect("thread should still exist");
|
||||
|
||||
// The timestamp must have advanced; the other metadata must be untouched.
assert!(updated_thread.updated_at > initial_updated_at);
|
||||
assert_eq!(updated_thread.title, initial_title);
|
||||
assert_eq!(
|
||||
updated_thread.first_user_message,
|
||||
initial_first_user_message
|
||||
);
|
||||
|
||||
recorder.shutdown().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks that `sync_thread_state_after_write` still produces a thread row when
// it is called with only an `AgentMessage` item for a thread id that was never
// written to the state DB beforehand.
#[tokio::test]
|
||||
async fn metadata_irrelevant_events_fall_back_to_upsert_when_thread_missing()
|
||||
-> std::io::Result<()> {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let mut config = ConfigBuilder::default()
|
||||
.codex_home(home.path().to_path_buf())
|
||||
.build()
|
||||
.await?;
|
||||
config
|
||||
.features
|
||||
.enable(Feature::Sqlite)
|
||||
.expect("test config should allow sqlite");
|
||||
|
||||
// A freshly initialized state DB: this test never creates a recorder, so no
|
||||
// row for `thread_id` is written before the call under test.
let state_db = codex_state::StateRuntime::init(
|
||||
home.path().to_path_buf(),
|
||||
config.model_provider_id.clone(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let thread_id = ThreadId::new();
|
||||
let rollout_path = home.path().join("rollout.jsonl");
|
||||
let builder = ThreadMetadataBuilder::new(
|
||||
thread_id,
|
||||
rollout_path.clone(),
|
||||
Utc::now(),
|
||||
SessionSource::Cli,
|
||||
);
|
||||
let items = vec![RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
AgentMessageEvent {
|
||||
message: "assistant text".to_string(),
|
||||
phase: None,
|
||||
},
|
||||
))];
|
||||
|
||||
// Call the sync helper directly with no memory-mode override (`None`).
sync_thread_state_after_write(
|
||||
Some(state_db.as_ref()),
|
||||
rollout_path.as_path(),
|
||||
Some(&builder),
|
||||
items.as_slice(),
|
||||
config.model_provider_id.as_str(),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
// The thread must now exist and carry the id supplied through the builder.
let thread = state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("thread should load after fallback")
|
||||
.expect("thread should be inserted after fallback");
|
||||
assert_eq!(thread.id, thread_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_threads_db_disabled_does_not_skip_paginated_items() -> std::io::Result<()> {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
|
||||
@@ -361,6 +361,7 @@ pub async fn reconcile_rollout(
|
||||
items,
|
||||
"reconcile_rollout",
|
||||
new_thread_memory_mode,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
@@ -491,6 +492,7 @@ pub async fn read_repair_rollout_path(
|
||||
}
|
||||
|
||||
/// Apply rollout items incrementally to SQLite.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn apply_rollout_items(
|
||||
context: Option<&codex_state::StateRuntime>,
|
||||
rollout_path: &Path,
|
||||
@@ -499,6 +501,7 @@ pub async fn apply_rollout_items(
|
||||
items: &[RolloutItem],
|
||||
stage: &str,
|
||||
new_thread_memory_mode: Option<&str>,
|
||||
updated_at_override: Option<DateTime<Utc>>,
|
||||
) {
|
||||
let Some(ctx) = context else {
|
||||
return;
|
||||
@@ -520,7 +523,13 @@ pub async fn apply_rollout_items(
|
||||
builder.rollout_path = rollout_path.to_path_buf();
|
||||
builder.cwd = normalize_cwd_for_state_db(&builder.cwd);
|
||||
if let Err(err) = ctx
|
||||
.apply_rollout_items(&builder, items, None, new_thread_memory_mode)
|
||||
.apply_rollout_items(
|
||||
&builder,
|
||||
items,
|
||||
None,
|
||||
new_thread_memory_mode,
|
||||
updated_at_override,
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
@@ -530,6 +539,26 @@ pub async fn apply_rollout_items(
|
||||
}
|
||||
}
|
||||
|
||||
/// Update only the `updated_at` timestamp of a thread in the state DB.
///
/// Returns `false` without touching the DB when `context` or `thread_id` is
/// `None`. Otherwise forwards to `StateRuntime::touch_thread_updated_at` and
/// returns its result; if that call fails, the error is logged with `warn!`
/// (including `stage` and the thread id) and `false` is returned.
///
/// NOTE(review): a `true` result presumably means an existing thread row was
/// updated; confirm against `StateRuntime::touch_thread_updated_at`.
pub async fn touch_thread_updated_at(
|
||||
context: Option<&codex_state::StateRuntime>,
|
||||
thread_id: Option<ThreadId>,
|
||||
updated_at: DateTime<Utc>,
|
||||
stage: &str,
|
||||
) -> bool {
|
||||
// No state DB available: nothing to touch.
let Some(ctx) = context else {
|
||||
return false;
|
||||
};
|
||||
// Thread id unknown: nothing to touch.
let Some(thread_id) = thread_id else {
|
||||
return false;
|
||||
};
|
||||
ctx.touch_thread_updated_at(thread_id, updated_at)
|
||||
.await
|
||||
// Best-effort: a DB error is logged and reported as `false` instead of
|
||||
// being propagated to the caller.
.unwrap_or_else(|err| {
|
||||
warn!("state db touch_thread_updated_at failed during {stage} for {thread_id}: {err}");
|
||||
false
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user