mirror of
https://github.com/openai/codex.git
synced 2026-04-30 17:36:40 +00:00
Extract rollout into its own crate (#15548)
This commit is contained in:
100
codex-rs/rollout/src/config.rs
Normal file
100
codex-rs/rollout/src/config.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub trait RolloutConfigView {
|
||||
fn codex_home(&self) -> &Path;
|
||||
fn sqlite_home(&self) -> &Path;
|
||||
fn cwd(&self) -> &Path;
|
||||
fn model_provider_id(&self) -> &str;
|
||||
fn generate_memories(&self) -> bool;
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct RolloutConfig {
|
||||
pub codex_home: PathBuf,
|
||||
pub sqlite_home: PathBuf,
|
||||
pub cwd: PathBuf,
|
||||
pub model_provider_id: String,
|
||||
pub generate_memories: bool,
|
||||
}
|
||||
|
||||
pub type Config = RolloutConfig;
|
||||
|
||||
impl RolloutConfig {
|
||||
pub fn from_view(view: &impl RolloutConfigView) -> Self {
|
||||
Self {
|
||||
codex_home: view.codex_home().to_path_buf(),
|
||||
sqlite_home: view.sqlite_home().to_path_buf(),
|
||||
cwd: view.cwd().to_path_buf(),
|
||||
model_provider_id: view.model_provider_id().to_string(),
|
||||
generate_memories: view.generate_memories(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RolloutConfigView for RolloutConfig {
|
||||
fn codex_home(&self) -> &Path {
|
||||
self.codex_home.as_path()
|
||||
}
|
||||
|
||||
fn sqlite_home(&self) -> &Path {
|
||||
self.sqlite_home.as_path()
|
||||
}
|
||||
|
||||
fn cwd(&self) -> &Path {
|
||||
self.cwd.as_path()
|
||||
}
|
||||
|
||||
fn model_provider_id(&self) -> &str {
|
||||
self.model_provider_id.as_str()
|
||||
}
|
||||
|
||||
fn generate_memories(&self) -> bool {
|
||||
self.generate_memories
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RolloutConfigView + ?Sized> RolloutConfigView for &T {
|
||||
fn codex_home(&self) -> &Path {
|
||||
(*self).codex_home()
|
||||
}
|
||||
|
||||
fn sqlite_home(&self) -> &Path {
|
||||
(*self).sqlite_home()
|
||||
}
|
||||
|
||||
fn cwd(&self) -> &Path {
|
||||
(*self).cwd()
|
||||
}
|
||||
|
||||
fn model_provider_id(&self) -> &str {
|
||||
(*self).model_provider_id()
|
||||
}
|
||||
|
||||
fn generate_memories(&self) -> bool {
|
||||
(*self).generate_memories()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RolloutConfigView + ?Sized> RolloutConfigView for Arc<T> {
|
||||
fn codex_home(&self) -> &Path {
|
||||
self.as_ref().codex_home()
|
||||
}
|
||||
|
||||
fn sqlite_home(&self) -> &Path {
|
||||
self.as_ref().sqlite_home()
|
||||
}
|
||||
|
||||
fn cwd(&self) -> &Path {
|
||||
self.as_ref().cwd()
|
||||
}
|
||||
|
||||
fn model_provider_id(&self) -> &str {
|
||||
self.as_ref().model_provider_id()
|
||||
}
|
||||
|
||||
fn generate_memories(&self) -> bool {
|
||||
self.as_ref().generate_memories()
|
||||
}
|
||||
}
|
||||
50
codex-rs/rollout/src/lib.rs
Normal file
50
codex-rs/rollout/src/lib.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
//! Rollout persistence and discovery for Codex session files.
|
||||
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
|
||||
pub mod config;
|
||||
pub mod list;
|
||||
pub mod metadata;
|
||||
pub mod policy;
|
||||
pub mod recorder;
|
||||
pub mod session_index;
|
||||
pub mod state_db;
|
||||
|
||||
pub(crate) mod default_client {
|
||||
pub use codex_login::default_client::*;
|
||||
}
|
||||
|
||||
pub(crate) use codex_protocol::protocol;
|
||||
|
||||
pub const SESSIONS_SUBDIR: &str = "sessions";
|
||||
pub const ARCHIVED_SESSIONS_SUBDIR: &str = "archived_sessions";
|
||||
pub static INTERACTIVE_SESSION_SOURCES: LazyLock<Vec<SessionSource>> = LazyLock::new(|| {
|
||||
vec![
|
||||
SessionSource::Cli,
|
||||
SessionSource::VSCode,
|
||||
SessionSource::Custom("atlas".to_string()),
|
||||
SessionSource::Custom("chatgpt".to_string()),
|
||||
]
|
||||
});
|
||||
|
||||
pub use codex_protocol::protocol::SessionMeta;
|
||||
pub use config::RolloutConfig;
|
||||
pub use config::RolloutConfigView;
|
||||
pub use list::find_archived_thread_path_by_id_str;
|
||||
pub use list::find_thread_path_by_id_str;
|
||||
#[deprecated(note = "use find_thread_path_by_id_str")]
|
||||
pub use list::find_thread_path_by_id_str as find_conversation_path_by_id_str;
|
||||
pub use list::rollout_date_parts;
|
||||
pub use policy::EventPersistenceMode;
|
||||
pub use recorder::RolloutRecorder;
|
||||
pub use recorder::RolloutRecorderParams;
|
||||
pub use session_index::append_thread_name;
|
||||
pub use session_index::find_thread_name_by_id;
|
||||
pub use session_index::find_thread_names_by_ids;
|
||||
pub use session_index::find_thread_path_by_name_str;
|
||||
pub use state_db::StateDbHandle;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
1275
codex-rs/rollout/src/list.rs
Normal file
1275
codex-rs/rollout/src/list.rs
Normal file
File diff suppressed because it is too large
Load Diff
443
codex-rs/rollout/src/metadata.rs
Normal file
443
codex-rs/rollout/src/metadata.rs
Normal file
@@ -0,0 +1,443 @@
|
||||
use crate::ARCHIVED_SESSIONS_SUBDIR;
|
||||
use crate::SESSIONS_SUBDIR;
|
||||
use crate::config::RolloutConfigView;
|
||||
use crate::list;
|
||||
use crate::list::parse_timestamp_uuid_from_filename;
|
||||
use crate::recorder::RolloutRecorder;
|
||||
use crate::state_db::normalize_cwd_for_state_db;
|
||||
use chrono::DateTime;
|
||||
use chrono::NaiveDateTime;
|
||||
use chrono::Timelike;
|
||||
use chrono::Utc;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_protocol::protocol::SessionMetaLine;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_state::BackfillState;
|
||||
use codex_state::BackfillStats;
|
||||
use codex_state::BackfillStatus;
|
||||
use codex_state::DB_ERROR_METRIC;
|
||||
use codex_state::DB_METRIC_BACKFILL;
|
||||
use codex_state::DB_METRIC_BACKFILL_DURATION_MS;
|
||||
use codex_state::ExtractionOutcome;
|
||||
use codex_state::ThreadMetadataBuilder;
|
||||
use codex_state::apply_rollout_item;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
const ROLLOUT_PREFIX: &str = "rollout-";
|
||||
const ROLLOUT_SUFFIX: &str = ".jsonl";
|
||||
const BACKFILL_BATCH_SIZE: usize = 200;
|
||||
#[cfg(not(test))]
|
||||
const BACKFILL_LEASE_SECONDS: i64 = 900;
|
||||
#[cfg(test)]
|
||||
const BACKFILL_LEASE_SECONDS: i64 = 1;
|
||||
|
||||
pub(crate) fn builder_from_session_meta(
|
||||
session_meta: &SessionMetaLine,
|
||||
rollout_path: &Path,
|
||||
) -> Option<ThreadMetadataBuilder> {
|
||||
let created_at = parse_timestamp_to_utc(session_meta.meta.timestamp.as_str())?;
|
||||
let mut builder = ThreadMetadataBuilder::new(
|
||||
session_meta.meta.id,
|
||||
rollout_path.to_path_buf(),
|
||||
created_at,
|
||||
session_meta.meta.source.clone(),
|
||||
);
|
||||
builder.model_provider = session_meta.meta.model_provider.clone();
|
||||
builder.agent_nickname = session_meta.meta.agent_nickname.clone();
|
||||
builder.agent_role = session_meta.meta.agent_role.clone();
|
||||
builder.agent_path = session_meta.meta.agent_path.clone();
|
||||
builder.cwd = session_meta.meta.cwd.clone();
|
||||
builder.cli_version = Some(session_meta.meta.cli_version.clone());
|
||||
builder.sandbox_policy = SandboxPolicy::new_read_only_policy();
|
||||
builder.approval_mode = AskForApproval::OnRequest;
|
||||
if let Some(git) = session_meta.git.as_ref() {
|
||||
builder.git_sha = git.commit_hash.as_ref().map(|sha| sha.0.clone());
|
||||
builder.git_branch = git.branch.clone();
|
||||
builder.git_origin_url = git.repository_url.clone();
|
||||
}
|
||||
Some(builder)
|
||||
}
|
||||
|
||||
pub fn builder_from_items(
|
||||
items: &[RolloutItem],
|
||||
rollout_path: &Path,
|
||||
) -> Option<ThreadMetadataBuilder> {
|
||||
if let Some(session_meta) = items.iter().find_map(|item| match item {
|
||||
RolloutItem::SessionMeta(meta_line) => Some(meta_line),
|
||||
RolloutItem::ResponseItem(_)
|
||||
| RolloutItem::Compacted(_)
|
||||
| RolloutItem::TurnContext(_)
|
||||
| RolloutItem::EventMsg(_) => None,
|
||||
}) && let Some(builder) = builder_from_session_meta(session_meta, rollout_path)
|
||||
{
|
||||
return Some(builder);
|
||||
}
|
||||
|
||||
let file_name = rollout_path.file_name()?.to_str()?;
|
||||
if !file_name.starts_with(ROLLOUT_PREFIX) || !file_name.ends_with(ROLLOUT_SUFFIX) {
|
||||
return None;
|
||||
}
|
||||
let (created_ts, uuid) = parse_timestamp_uuid_from_filename(file_name)?;
|
||||
let created_at =
|
||||
DateTime::<Utc>::from_timestamp(created_ts.unix_timestamp(), 0)?.with_nanosecond(0)?;
|
||||
let id = ThreadId::from_string(&uuid.to_string()).ok()?;
|
||||
Some(ThreadMetadataBuilder::new(
|
||||
id,
|
||||
rollout_path.to_path_buf(),
|
||||
created_at,
|
||||
SessionSource::default(),
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn extract_metadata_from_rollout(
|
||||
rollout_path: &Path,
|
||||
default_provider: &str,
|
||||
) -> anyhow::Result<ExtractionOutcome> {
|
||||
let (items, _thread_id, parse_errors) =
|
||||
RolloutRecorder::load_rollout_items(rollout_path).await?;
|
||||
if items.is_empty() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"empty session file: {}",
|
||||
rollout_path.display()
|
||||
));
|
||||
}
|
||||
let builder = builder_from_items(items.as_slice(), rollout_path).ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"rollout missing metadata builder: {}",
|
||||
rollout_path.display()
|
||||
)
|
||||
})?;
|
||||
let mut metadata = builder.build(default_provider);
|
||||
for item in &items {
|
||||
apply_rollout_item(&mut metadata, item, default_provider);
|
||||
}
|
||||
if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
|
||||
metadata.updated_at = updated_at;
|
||||
}
|
||||
Ok(ExtractionOutcome {
|
||||
metadata,
|
||||
memory_mode: items.iter().rev().find_map(|item| match item {
|
||||
RolloutItem::SessionMeta(meta_line) => meta_line.meta.memory_mode.clone(),
|
||||
RolloutItem::ResponseItem(_)
|
||||
| RolloutItem::Compacted(_)
|
||||
| RolloutItem::TurnContext(_)
|
||||
| RolloutItem::EventMsg(_) => None,
|
||||
}),
|
||||
parse_errors,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn backfill_sessions(
|
||||
runtime: &codex_state::StateRuntime,
|
||||
config: &impl RolloutConfigView,
|
||||
) {
|
||||
let metric_client = codex_otel::metrics::global();
|
||||
let timer = metric_client
|
||||
.as_ref()
|
||||
.and_then(|otel| otel.start_timer(DB_METRIC_BACKFILL_DURATION_MS, &[]).ok());
|
||||
let backfill_state = match runtime.get_backfill_state().await {
|
||||
Ok(state) => state,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"failed to read backfill state at {}: {err}",
|
||||
config.codex_home().display()
|
||||
);
|
||||
BackfillState::default()
|
||||
}
|
||||
};
|
||||
if backfill_state.status == BackfillStatus::Complete {
|
||||
return;
|
||||
}
|
||||
let claimed = match runtime.try_claim_backfill(BACKFILL_LEASE_SECONDS).await {
|
||||
Ok(claimed) => claimed,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"failed to claim backfill worker at {}: {err}",
|
||||
config.codex_home().display()
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
if !claimed {
|
||||
info!(
|
||||
"state db backfill already running at {}; skipping duplicate worker",
|
||||
config.codex_home().display()
|
||||
);
|
||||
return;
|
||||
}
|
||||
let mut backfill_state = match runtime.get_backfill_state().await {
|
||||
Ok(state) => state,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"failed to read claimed backfill state at {}: {err}",
|
||||
config.codex_home().display()
|
||||
);
|
||||
BackfillState {
|
||||
status: BackfillStatus::Running,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
};
|
||||
if backfill_state.status != BackfillStatus::Running {
|
||||
if let Err(err) = runtime.mark_backfill_running().await {
|
||||
warn!(
|
||||
"failed to mark backfill running at {}: {err}",
|
||||
config.codex_home().display()
|
||||
);
|
||||
} else {
|
||||
backfill_state.status = BackfillStatus::Running;
|
||||
}
|
||||
}
|
||||
|
||||
let sessions_root = config.codex_home().join(SESSIONS_SUBDIR);
|
||||
let archived_root = config.codex_home().join(ARCHIVED_SESSIONS_SUBDIR);
|
||||
let mut rollout_paths: Vec<BackfillRolloutPath> = Vec::new();
|
||||
for (root, archived) in [(sessions_root, false), (archived_root, true)] {
|
||||
if !tokio::fs::try_exists(&root).await.unwrap_or(false) {
|
||||
continue;
|
||||
}
|
||||
match collect_rollout_paths(&root).await {
|
||||
Ok(paths) => {
|
||||
rollout_paths.extend(paths.into_iter().map(|path| BackfillRolloutPath {
|
||||
watermark: backfill_watermark_for_path(config.codex_home(), &path),
|
||||
path,
|
||||
archived,
|
||||
}));
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"failed to collect rollout paths under {}: {err}",
|
||||
root.display()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
rollout_paths.sort_by(|a, b| a.watermark.cmp(&b.watermark));
|
||||
if let Some(last_watermark) = backfill_state.last_watermark.as_deref() {
|
||||
rollout_paths.retain(|entry| entry.watermark.as_str() > last_watermark);
|
||||
}
|
||||
|
||||
let mut stats = BackfillStats {
|
||||
scanned: 0,
|
||||
upserted: 0,
|
||||
failed: 0,
|
||||
};
|
||||
let mut last_watermark = backfill_state.last_watermark.clone();
|
||||
for batch in rollout_paths.chunks(BACKFILL_BATCH_SIZE) {
|
||||
for rollout in batch {
|
||||
stats.scanned = stats.scanned.saturating_add(1);
|
||||
match extract_metadata_from_rollout(&rollout.path, config.model_provider_id()).await {
|
||||
Ok(outcome) => {
|
||||
if outcome.parse_errors > 0
|
||||
&& let Some(ref metric_client) = metric_client
|
||||
{
|
||||
let _ = metric_client.counter(
|
||||
DB_ERROR_METRIC,
|
||||
outcome.parse_errors as i64,
|
||||
&[("stage", "backfill_sessions")],
|
||||
);
|
||||
}
|
||||
let mut metadata = outcome.metadata;
|
||||
metadata.cwd = normalize_cwd_for_state_db(&metadata.cwd);
|
||||
let memory_mode = outcome.memory_mode.unwrap_or_else(|| "enabled".to_string());
|
||||
if let Ok(Some(existing_metadata)) = runtime.get_thread(metadata.id).await {
|
||||
metadata.prefer_existing_git_info(&existing_metadata);
|
||||
}
|
||||
if rollout.archived && metadata.archived_at.is_none() {
|
||||
let fallback_archived_at = metadata.updated_at;
|
||||
metadata.archived_at = file_modified_time_utc(&rollout.path)
|
||||
.await
|
||||
.or(Some(fallback_archived_at));
|
||||
}
|
||||
if let Err(err) = runtime.upsert_thread(&metadata).await {
|
||||
stats.failed = stats.failed.saturating_add(1);
|
||||
warn!("failed to upsert rollout {}: {err}", rollout.path.display());
|
||||
} else {
|
||||
if let Err(err) = runtime
|
||||
.set_thread_memory_mode(metadata.id, memory_mode.as_str())
|
||||
.await
|
||||
{
|
||||
stats.failed = stats.failed.saturating_add(1);
|
||||
warn!(
|
||||
"failed to restore memory mode for {}: {err}",
|
||||
rollout.path.display()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
stats.upserted = stats.upserted.saturating_add(1);
|
||||
if let Ok(meta_line) = list::read_session_meta_line(&rollout.path).await {
|
||||
if let Err(err) = runtime
|
||||
.persist_dynamic_tools(
|
||||
meta_line.meta.id,
|
||||
meta_line.meta.dynamic_tools.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
"failed to backfill dynamic tools {}: {err}",
|
||||
rollout.path.display()
|
||||
);
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
"failed to read session meta for dynamic tools {}",
|
||||
rollout.path.display()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
stats.failed = stats.failed.saturating_add(1);
|
||||
warn!(
|
||||
"failed to extract rollout {}: {err}",
|
||||
rollout.path.display()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(last_entry) = batch.last() {
|
||||
if let Err(err) = runtime
|
||||
.checkpoint_backfill(last_entry.watermark.as_str())
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
"failed to checkpoint backfill at {}: {err}",
|
||||
config.codex_home().display()
|
||||
);
|
||||
} else {
|
||||
last_watermark = Some(last_entry.watermark.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Err(err) = runtime
|
||||
.mark_backfill_complete(last_watermark.as_deref())
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
"failed to mark backfill complete at {}: {err}",
|
||||
config.codex_home().display()
|
||||
);
|
||||
}
|
||||
|
||||
info!(
|
||||
"state db backfill scanned={}, upserted={}, failed={}",
|
||||
stats.scanned, stats.upserted, stats.failed
|
||||
);
|
||||
if let Some(metric_client) = metric_client {
|
||||
let _ = metric_client.counter(
|
||||
DB_METRIC_BACKFILL,
|
||||
stats.upserted as i64,
|
||||
&[("status", "upserted")],
|
||||
);
|
||||
let _ = metric_client.counter(
|
||||
DB_METRIC_BACKFILL,
|
||||
stats.failed as i64,
|
||||
&[("status", "failed")],
|
||||
);
|
||||
}
|
||||
if let Some(timer) = timer.as_ref() {
|
||||
let status = if stats.failed == 0 {
|
||||
"success"
|
||||
} else if stats.upserted == 0 {
|
||||
"failed"
|
||||
} else {
|
||||
"partial_failure"
|
||||
};
|
||||
let _ = timer.record(&[("status", status)]);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct BackfillRolloutPath {
|
||||
watermark: String,
|
||||
path: PathBuf,
|
||||
archived: bool,
|
||||
}
|
||||
|
||||
fn backfill_watermark_for_path(codex_home: &Path, path: &Path) -> String {
|
||||
path.strip_prefix(codex_home)
|
||||
.unwrap_or(path)
|
||||
.to_string_lossy()
|
||||
.replace('\\', "/")
|
||||
}
|
||||
|
||||
async fn file_modified_time_utc(path: &Path) -> Option<DateTime<Utc>> {
|
||||
let modified = tokio::fs::metadata(path).await.ok()?.modified().ok()?;
|
||||
let updated_at: DateTime<Utc> = modified.into();
|
||||
updated_at.with_nanosecond(0)
|
||||
}
|
||||
|
||||
fn parse_timestamp_to_utc(ts: &str) -> Option<DateTime<Utc>> {
|
||||
const FILENAME_TS_FORMAT: &str = "%Y-%m-%dT%H-%M-%S";
|
||||
if let Ok(naive) = NaiveDateTime::parse_from_str(ts, FILENAME_TS_FORMAT) {
|
||||
let dt = DateTime::<Utc>::from_naive_utc_and_offset(naive, Utc);
|
||||
return dt.with_nanosecond(0);
|
||||
}
|
||||
if let Ok(dt) = DateTime::parse_from_rfc3339(ts) {
|
||||
return dt.with_timezone(&Utc).with_nanosecond(0);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
async fn collect_rollout_paths(root: &Path) -> std::io::Result<Vec<PathBuf>> {
|
||||
let mut stack = vec![root.to_path_buf()];
|
||||
let mut paths = Vec::new();
|
||||
while let Some(dir) = stack.pop() {
|
||||
let mut read_dir = match tokio::fs::read_dir(&dir).await {
|
||||
Ok(read_dir) => read_dir,
|
||||
Err(err) => {
|
||||
warn!("failed to read directory {}: {err}", dir.display());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
loop {
|
||||
let next_entry = match read_dir.next_entry().await {
|
||||
Ok(next_entry) => next_entry,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"failed to read directory entry under {}: {err}",
|
||||
dir.display()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let Some(entry) = next_entry else {
|
||||
break;
|
||||
};
|
||||
let path = entry.path();
|
||||
let file_type = match entry.file_type().await {
|
||||
Ok(file_type) => file_type,
|
||||
Err(err) => {
|
||||
warn!("failed to read file type for {}: {err}", path.display());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if file_type.is_dir() {
|
||||
stack.push(path);
|
||||
continue;
|
||||
}
|
||||
if !file_type.is_file() {
|
||||
continue;
|
||||
}
|
||||
let file_name = entry.file_name();
|
||||
let Some(name) = file_name.to_str() else {
|
||||
continue;
|
||||
};
|
||||
if name.starts_with(ROLLOUT_PREFIX) && name.ends_with(ROLLOUT_SUFFIX) {
|
||||
paths.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(paths)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "metadata_tests.rs"]
|
||||
mod tests;
|
||||
387
codex-rs/rollout/src/metadata_tests.rs
Normal file
387
codex-rs/rollout/src/metadata_tests.rs
Normal file
@@ -0,0 +1,387 @@
|
||||
#![allow(warnings, clippy::all)]
|
||||
|
||||
use super::*;
|
||||
use crate::config::RolloutConfig;
|
||||
use chrono::DateTime;
|
||||
use chrono::NaiveDateTime;
|
||||
use chrono::Timelike;
|
||||
use chrono::Utc;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::protocol::CompactedItem;
|
||||
use codex_protocol::protocol::GitInfo;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::RolloutLine;
|
||||
use codex_protocol::protocol::SessionMeta;
|
||||
use codex_protocol::protocol::SessionMetaLine;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_state::BackfillStatus;
|
||||
use codex_state::ThreadMetadataBuilder;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::tempdir;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_config(codex_home: PathBuf) -> RolloutConfig {
|
||||
RolloutConfig {
|
||||
sqlite_home: codex_home.clone(),
|
||||
cwd: codex_home.clone(),
|
||||
codex_home,
|
||||
model_provider_id: "test-provider".to_string(),
|
||||
generate_memories: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extract_metadata_from_rollout_uses_session_meta() {
|
||||
let dir = tempdir().expect("tempdir");
|
||||
let uuid = Uuid::new_v4();
|
||||
let id = ThreadId::from_string(&uuid.to_string()).expect("thread id");
|
||||
let path = dir
|
||||
.path()
|
||||
.join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl"));
|
||||
|
||||
let session_meta = SessionMeta {
|
||||
id,
|
||||
forked_from_id: None,
|
||||
timestamp: "2026-01-27T12:34:56Z".to_string(),
|
||||
cwd: dir.path().to_path_buf(),
|
||||
originator: "cli".to_string(),
|
||||
cli_version: "0.0.0".to_string(),
|
||||
source: SessionSource::default(),
|
||||
agent_path: None,
|
||||
agent_nickname: None,
|
||||
agent_role: None,
|
||||
model_provider: Some("openai".to_string()),
|
||||
base_instructions: None,
|
||||
dynamic_tools: None,
|
||||
memory_mode: None,
|
||||
};
|
||||
let session_meta_line = SessionMetaLine {
|
||||
meta: session_meta,
|
||||
git: None,
|
||||
};
|
||||
let rollout_line = RolloutLine {
|
||||
timestamp: "2026-01-27T12:34:56Z".to_string(),
|
||||
item: RolloutItem::SessionMeta(session_meta_line.clone()),
|
||||
};
|
||||
let json = serde_json::to_string(&rollout_line).expect("rollout json");
|
||||
let mut file = File::create(&path).expect("create rollout");
|
||||
writeln!(file, "{json}").expect("write rollout");
|
||||
|
||||
let outcome = extract_metadata_from_rollout(&path, "openai")
|
||||
.await
|
||||
.expect("extract");
|
||||
|
||||
let builder = builder_from_session_meta(&session_meta_line, path.as_path()).expect("builder");
|
||||
let mut expected = builder.build("openai");
|
||||
apply_rollout_item(&mut expected, &rollout_line.item, "openai");
|
||||
expected.updated_at = file_modified_time_utc(&path).await.expect("mtime");
|
||||
|
||||
assert_eq!(outcome.metadata, expected);
|
||||
assert_eq!(outcome.memory_mode, None);
|
||||
assert_eq!(outcome.parse_errors, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extract_metadata_from_rollout_returns_latest_memory_mode() {
|
||||
let dir = tempdir().expect("tempdir");
|
||||
let uuid = Uuid::new_v4();
|
||||
let id = ThreadId::from_string(&uuid.to_string()).expect("thread id");
|
||||
let path = dir
|
||||
.path()
|
||||
.join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl"));
|
||||
|
||||
let session_meta = SessionMeta {
|
||||
id,
|
||||
forked_from_id: None,
|
||||
timestamp: "2026-01-27T12:34:56Z".to_string(),
|
||||
cwd: dir.path().to_path_buf(),
|
||||
originator: "cli".to_string(),
|
||||
cli_version: "0.0.0".to_string(),
|
||||
source: SessionSource::default(),
|
||||
agent_path: None,
|
||||
agent_nickname: None,
|
||||
agent_role: None,
|
||||
model_provider: Some("openai".to_string()),
|
||||
base_instructions: None,
|
||||
dynamic_tools: None,
|
||||
memory_mode: None,
|
||||
};
|
||||
let polluted_meta = SessionMeta {
|
||||
memory_mode: Some("polluted".to_string()),
|
||||
..session_meta.clone()
|
||||
};
|
||||
let lines = vec![
|
||||
RolloutLine {
|
||||
timestamp: "2026-01-27T12:34:56Z".to_string(),
|
||||
item: RolloutItem::SessionMeta(SessionMetaLine {
|
||||
meta: session_meta,
|
||||
git: None,
|
||||
}),
|
||||
},
|
||||
RolloutLine {
|
||||
timestamp: "2026-01-27T12:35:00Z".to_string(),
|
||||
item: RolloutItem::SessionMeta(SessionMetaLine {
|
||||
meta: polluted_meta,
|
||||
git: None,
|
||||
}),
|
||||
},
|
||||
];
|
||||
let mut file = File::create(&path).expect("create rollout");
|
||||
for line in lines {
|
||||
writeln!(
|
||||
file,
|
||||
"{}",
|
||||
serde_json::to_string(&line).expect("serialize rollout line")
|
||||
)
|
||||
.expect("write rollout line");
|
||||
}
|
||||
|
||||
let outcome = extract_metadata_from_rollout(&path, "openai")
|
||||
.await
|
||||
.expect("extract");
|
||||
|
||||
assert_eq!(outcome.memory_mode.as_deref(), Some("polluted"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builder_from_items_falls_back_to_filename() {
|
||||
let dir = tempdir().expect("tempdir");
|
||||
let uuid = Uuid::new_v4();
|
||||
let path = dir
|
||||
.path()
|
||||
.join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl"));
|
||||
let items = vec![RolloutItem::Compacted(CompactedItem {
|
||||
message: "noop".to_string(),
|
||||
replacement_history: None,
|
||||
})];
|
||||
|
||||
let builder = builder_from_items(items.as_slice(), path.as_path()).expect("builder");
|
||||
let naive = NaiveDateTime::parse_from_str("2026-01-27T12-34-56", "%Y-%m-%dT%H-%M-%S")
|
||||
.expect("timestamp");
|
||||
let created_at = DateTime::<Utc>::from_naive_utc_and_offset(naive, Utc)
|
||||
.with_nanosecond(0)
|
||||
.expect("nanosecond");
|
||||
let expected = ThreadMetadataBuilder::new(
|
||||
ThreadId::from_string(&uuid.to_string()).expect("thread id"),
|
||||
path,
|
||||
created_at,
|
||||
SessionSource::default(),
|
||||
);
|
||||
|
||||
assert_eq!(builder, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn backfill_sessions_resumes_from_watermark_and_marks_complete() {
|
||||
let dir = tempdir().expect("tempdir");
|
||||
let codex_home = dir.path().to_path_buf();
|
||||
let first_uuid = Uuid::new_v4();
|
||||
let second_uuid = Uuid::new_v4();
|
||||
let first_path = write_rollout_in_sessions(
|
||||
codex_home.as_path(),
|
||||
"2026-01-27T12-34-56",
|
||||
"2026-01-27T12:34:56Z",
|
||||
first_uuid,
|
||||
None,
|
||||
);
|
||||
let second_path = write_rollout_in_sessions(
|
||||
codex_home.as_path(),
|
||||
"2026-01-27T12-35-56",
|
||||
"2026-01-27T12:35:56Z",
|
||||
second_uuid,
|
||||
None,
|
||||
);
|
||||
|
||||
let runtime = codex_state::StateRuntime::init(codex_home.clone(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
let first_watermark = backfill_watermark_for_path(codex_home.as_path(), first_path.as_path());
|
||||
runtime.mark_backfill_running().await.expect("mark running");
|
||||
runtime
|
||||
.checkpoint_backfill(first_watermark.as_str())
|
||||
.await
|
||||
.expect("checkpoint first watermark");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(
|
||||
(BACKFILL_LEASE_SECONDS + 1) as u64,
|
||||
))
|
||||
.await;
|
||||
|
||||
let config = test_config(codex_home.clone());
|
||||
backfill_sessions(runtime.as_ref(), &config).await;
|
||||
|
||||
let first_id = ThreadId::from_string(&first_uuid.to_string()).expect("first thread id");
|
||||
let second_id = ThreadId::from_string(&second_uuid.to_string()).expect("second thread id");
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_thread(first_id)
|
||||
.await
|
||||
.expect("get first thread"),
|
||||
None
|
||||
);
|
||||
assert!(
|
||||
runtime
|
||||
.get_thread(second_id)
|
||||
.await
|
||||
.expect("get second thread")
|
||||
.is_some()
|
||||
);
|
||||
|
||||
let state = runtime
|
||||
.get_backfill_state()
|
||||
.await
|
||||
.expect("get backfill state");
|
||||
assert_eq!(state.status, BackfillStatus::Complete);
|
||||
assert_eq!(
|
||||
state.last_watermark,
|
||||
Some(backfill_watermark_for_path(
|
||||
codex_home.as_path(),
|
||||
second_path.as_path()
|
||||
))
|
||||
);
|
||||
assert!(state.last_success_at.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn backfill_sessions_preserves_existing_git_branch_and_fills_missing_git_fields() {
|
||||
let dir = tempdir().expect("tempdir");
|
||||
let codex_home = dir.path().to_path_buf();
|
||||
let thread_uuid = Uuid::new_v4();
|
||||
let rollout_path = write_rollout_in_sessions(
|
||||
codex_home.as_path(),
|
||||
"2026-01-27T12-34-56",
|
||||
"2026-01-27T12:34:56Z",
|
||||
thread_uuid,
|
||||
Some(GitInfo {
|
||||
commit_hash: Some(codex_git_utils::GitSha::new("rollout-sha")),
|
||||
branch: Some("rollout-branch".to_string()),
|
||||
repository_url: Some("git@example.com:openai/codex.git".to_string()),
|
||||
}),
|
||||
);
|
||||
|
||||
let runtime = codex_state::StateRuntime::init(codex_home.clone(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
let thread_id = ThreadId::from_string(&thread_uuid.to_string()).expect("thread id");
|
||||
let mut existing = extract_metadata_from_rollout(&rollout_path, "test-provider")
|
||||
.await
|
||||
.expect("extract")
|
||||
.metadata;
|
||||
existing.git_sha = None;
|
||||
existing.git_branch = Some("sqlite-branch".to_string());
|
||||
existing.git_origin_url = None;
|
||||
runtime
|
||||
.upsert_thread(&existing)
|
||||
.await
|
||||
.expect("existing metadata upsert");
|
||||
|
||||
let config = test_config(codex_home.clone());
|
||||
backfill_sessions(runtime.as_ref(), &config).await;
|
||||
|
||||
let persisted = runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("get thread")
|
||||
.expect("thread exists");
|
||||
assert_eq!(persisted.git_sha.as_deref(), Some("rollout-sha"));
|
||||
assert_eq!(persisted.git_branch.as_deref(), Some("sqlite-branch"));
|
||||
assert_eq!(
|
||||
persisted.git_origin_url.as_deref(),
|
||||
Some("git@example.com:openai/codex.git")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn backfill_sessions_normalizes_cwd_before_upsert() {
|
||||
let dir = tempdir().expect("tempdir");
|
||||
let codex_home = dir.path().to_path_buf();
|
||||
let thread_uuid = Uuid::new_v4();
|
||||
let session_cwd = codex_home.join(".");
|
||||
let rollout_path = write_rollout_in_sessions_with_cwd(
|
||||
codex_home.as_path(),
|
||||
"2026-01-27T12-34-56",
|
||||
"2026-01-27T12:34:56Z",
|
||||
thread_uuid,
|
||||
session_cwd.clone(),
|
||||
None,
|
||||
);
|
||||
|
||||
let runtime = codex_state::StateRuntime::init(codex_home.clone(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
let config = test_config(codex_home.clone());
|
||||
backfill_sessions(runtime.as_ref(), &config).await;
|
||||
|
||||
let thread_id = ThreadId::from_string(&thread_uuid.to_string()).expect("thread id");
|
||||
let stored = runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("get thread")
|
||||
.expect("thread should be backfilled");
|
||||
|
||||
assert_eq!(stored.rollout_path, rollout_path);
|
||||
assert_eq!(stored.cwd, normalize_cwd_for_state_db(&session_cwd));
|
||||
}
|
||||
|
||||
fn write_rollout_in_sessions(
|
||||
codex_home: &Path,
|
||||
filename_ts: &str,
|
||||
event_ts: &str,
|
||||
thread_uuid: Uuid,
|
||||
git: Option<GitInfo>,
|
||||
) -> PathBuf {
|
||||
write_rollout_in_sessions_with_cwd(
|
||||
codex_home,
|
||||
filename_ts,
|
||||
event_ts,
|
||||
thread_uuid,
|
||||
codex_home.to_path_buf(),
|
||||
git,
|
||||
)
|
||||
}
|
||||
|
||||
fn write_rollout_in_sessions_with_cwd(
|
||||
codex_home: &Path,
|
||||
filename_ts: &str,
|
||||
event_ts: &str,
|
||||
thread_uuid: Uuid,
|
||||
cwd: PathBuf,
|
||||
git: Option<GitInfo>,
|
||||
) -> PathBuf {
|
||||
let id = ThreadId::from_string(&thread_uuid.to_string()).expect("thread id");
|
||||
let sessions_dir = codex_home.join("sessions");
|
||||
std::fs::create_dir_all(sessions_dir.as_path()).expect("create sessions dir");
|
||||
let path = sessions_dir.join(format!("rollout-{filename_ts}-{thread_uuid}.jsonl"));
|
||||
let session_meta = SessionMeta {
|
||||
id,
|
||||
forked_from_id: None,
|
||||
timestamp: event_ts.to_string(),
|
||||
cwd,
|
||||
originator: "cli".to_string(),
|
||||
cli_version: "0.0.0".to_string(),
|
||||
source: SessionSource::default(),
|
||||
agent_path: None,
|
||||
agent_nickname: None,
|
||||
agent_role: None,
|
||||
model_provider: Some("test-provider".to_string()),
|
||||
base_instructions: None,
|
||||
dynamic_tools: None,
|
||||
memory_mode: None,
|
||||
};
|
||||
let session_meta_line = SessionMetaLine {
|
||||
meta: session_meta,
|
||||
git,
|
||||
};
|
||||
let rollout_line = RolloutLine {
|
||||
timestamp: event_ts.to_string(),
|
||||
item: RolloutItem::SessionMeta(session_meta_line),
|
||||
};
|
||||
let json = serde_json::to_string(&rollout_line).expect("serialize rollout");
|
||||
let mut file = File::create(&path).expect("create rollout");
|
||||
writeln!(file, "{json}").expect("write rollout");
|
||||
path
|
||||
}
|
||||
208
codex-rs/rollout/src/policy.rs
Normal file
208
codex-rs/rollout/src/policy.rs
Normal file
@@ -0,0 +1,208 @@
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::RolloutItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub enum EventPersistenceMode {
|
||||
#[default]
|
||||
Limited,
|
||||
Extended,
|
||||
}
|
||||
|
||||
/// Whether a rollout `item` should be persisted in rollout files for the
|
||||
/// provided persistence `mode`.
|
||||
pub fn is_persisted_response_item(item: &RolloutItem, mode: EventPersistenceMode) -> bool {
|
||||
match item {
|
||||
RolloutItem::ResponseItem(item) => should_persist_response_item(item),
|
||||
RolloutItem::EventMsg(ev) => should_persist_event_msg(ev, mode),
|
||||
// Persist Codex executive markers so we can analyze flows (e.g., compaction, API turns).
|
||||
RolloutItem::Compacted(_) | RolloutItem::TurnContext(_) | RolloutItem::SessionMeta(_) => {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether a `ResponseItem` should be persisted in rollout files.
|
||||
#[inline]
|
||||
pub fn should_persist_response_item(item: &ResponseItem) -> bool {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::ToolSearchCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::ToolSearchOutput { .. }
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::ImageGenerationCall { .. }
|
||||
| ResponseItem::GhostSnapshot { .. }
|
||||
| ResponseItem::Compaction { .. } => true,
|
||||
ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether a `ResponseItem` should be persisted for the memories.
|
||||
#[inline]
|
||||
pub fn should_persist_response_item_for_memories(item: &ResponseItem) -> bool {
|
||||
match item {
|
||||
ResponseItem::Message { role, .. } => role != "developer",
|
||||
ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::ToolSearchCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::ToolSearchOutput { .. }
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => true,
|
||||
ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::ImageGenerationCall { .. }
|
||||
| ResponseItem::GhostSnapshot { .. }
|
||||
| ResponseItem::Compaction { .. }
|
||||
| ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether an `EventMsg` should be persisted in rollout files for the
|
||||
/// provided persistence `mode`.
|
||||
#[inline]
|
||||
pub fn should_persist_event_msg(ev: &EventMsg, mode: EventPersistenceMode) -> bool {
|
||||
match mode {
|
||||
EventPersistenceMode::Limited => should_persist_event_msg_limited(ev),
|
||||
EventPersistenceMode::Extended => should_persist_event_msg_extended(ev),
|
||||
}
|
||||
}
|
||||
|
||||
fn should_persist_event_msg_limited(ev: &EventMsg) -> bool {
|
||||
matches!(
|
||||
event_msg_persistence_mode(ev),
|
||||
Some(EventPersistenceMode::Limited)
|
||||
)
|
||||
}
|
||||
|
||||
fn should_persist_event_msg_extended(ev: &EventMsg) -> bool {
|
||||
matches!(
|
||||
event_msg_persistence_mode(ev),
|
||||
Some(EventPersistenceMode::Limited) | Some(EventPersistenceMode::Extended)
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the minimum persistence mode that includes this event.
|
||||
/// `None` means the event should never be persisted.
|
||||
fn event_msg_persistence_mode(ev: &EventMsg) -> Option<EventPersistenceMode> {
|
||||
match ev {
|
||||
EventMsg::UserMessage(_)
|
||||
| EventMsg::AgentMessage(_)
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::AgentReasoningRawContent(_)
|
||||
| EventMsg::TokenCount(_)
|
||||
| EventMsg::ContextCompacted(_)
|
||||
| EventMsg::EnteredReviewMode(_)
|
||||
| EventMsg::ExitedReviewMode(_)
|
||||
| EventMsg::ThreadRolledBack(_)
|
||||
| EventMsg::UndoCompleted(_)
|
||||
| EventMsg::TurnAborted(_)
|
||||
| EventMsg::TurnStarted(_)
|
||||
| EventMsg::TurnComplete(_)
|
||||
| EventMsg::ImageGenerationEnd(_) => Some(EventPersistenceMode::Limited),
|
||||
EventMsg::ItemCompleted(event) => {
|
||||
// Plan items are derived from streaming tags and are not part of the
|
||||
// raw ResponseItem history, so we persist their completion to replay
|
||||
// them on resume without bloating rollouts with every item lifecycle.
|
||||
if matches!(event.item, codex_protocol::items::TurnItem::Plan(_)) {
|
||||
Some(EventPersistenceMode::Limited)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
EventMsg::Error(_)
|
||||
| EventMsg::GuardianAssessment(_)
|
||||
| EventMsg::WebSearchEnd(_)
|
||||
| EventMsg::ExecCommandEnd(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::ViewImageToolCall(_)
|
||||
| EventMsg::CollabAgentSpawnEnd(_)
|
||||
| EventMsg::CollabAgentInteractionEnd(_)
|
||||
| EventMsg::CollabWaitingEnd(_)
|
||||
| EventMsg::CollabCloseEnd(_)
|
||||
| EventMsg::CollabResumeEnd(_)
|
||||
| EventMsg::DynamicToolCallRequest(_)
|
||||
| EventMsg::DynamicToolCallResponse(_) => Some(EventPersistenceMode::Extended),
|
||||
EventMsg::Warning(_)
|
||||
| EventMsg::RealtimeConversationStarted(_)
|
||||
| EventMsg::RealtimeConversationRealtime(_)
|
||||
| EventMsg::RealtimeConversationClosed(_)
|
||||
| EventMsg::ModelReroute(_)
|
||||
| EventMsg::AgentMessageDelta(_)
|
||||
| EventMsg::AgentReasoningDelta(_)
|
||||
| EventMsg::AgentReasoningRawContentDelta(_)
|
||||
| EventMsg::AgentReasoningSectionBreak(_)
|
||||
| EventMsg::RawResponseItem(_)
|
||||
| EventMsg::SessionConfigured(_)
|
||||
| EventMsg::ThreadNameUpdated(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::WebSearchBegin(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::TerminalInteraction(_)
|
||||
| EventMsg::ExecCommandOutputDelta(_)
|
||||
| EventMsg::ExecApprovalRequest(_)
|
||||
| EventMsg::RequestPermissions(_)
|
||||
| EventMsg::RequestUserInput(_)
|
||||
| EventMsg::ElicitationRequest(_)
|
||||
| EventMsg::ApplyPatchApprovalRequest(_)
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::StreamError(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::UndoStarted(_)
|
||||
| EventMsg::McpListToolsResponse(_)
|
||||
| EventMsg::McpStartupUpdate(_)
|
||||
| EventMsg::McpStartupComplete(_)
|
||||
| EventMsg::ListCustomPromptsResponse(_)
|
||||
| EventMsg::ListSkillsResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::ShutdownComplete
|
||||
| EventMsg::DeprecationNotice(_)
|
||||
| EventMsg::ItemStarted(_)
|
||||
| EventMsg::HookStarted(_)
|
||||
| EventMsg::HookCompleted(_)
|
||||
| EventMsg::AgentMessageContentDelta(_)
|
||||
| EventMsg::PlanDelta(_)
|
||||
| EventMsg::ReasoningContentDelta(_)
|
||||
| EventMsg::ReasoningRawContentDelta(_)
|
||||
| EventMsg::SkillsUpdateAvailable
|
||||
| EventMsg::CollabAgentSpawnBegin(_)
|
||||
| EventMsg::CollabAgentInteractionBegin(_)
|
||||
| EventMsg::CollabWaitingBegin(_)
|
||||
| EventMsg::CollabCloseBegin(_)
|
||||
| EventMsg::CollabResumeBegin(_)
|
||||
| EventMsg::ImageGenerationBegin(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::EventPersistenceMode;
|
||||
use super::should_persist_event_msg;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::ImageGenerationEndEvent;
|
||||
|
||||
#[test]
|
||||
fn persists_image_generation_end_events_in_limited_mode() {
|
||||
let event = EventMsg::ImageGenerationEnd(ImageGenerationEndEvent {
|
||||
call_id: "ig_123".into(),
|
||||
status: "completed".into(),
|
||||
revised_prompt: Some("final prompt".into()),
|
||||
result: "Zm9v".into(),
|
||||
saved_path: None,
|
||||
});
|
||||
|
||||
assert!(should_persist_event_msg(
|
||||
&event,
|
||||
EventPersistenceMode::Limited
|
||||
));
|
||||
}
|
||||
}
|
||||
1111
codex-rs/rollout/src/recorder.rs
Normal file
1111
codex-rs/rollout/src/recorder.rs
Normal file
File diff suppressed because it is too large
Load Diff
485
codex-rs/rollout/src/recorder_tests.rs
Normal file
485
codex-rs/rollout/src/recorder_tests.rs
Normal file
@@ -0,0 +1,485 @@
|
||||
#![allow(warnings, clippy::all)]
|
||||
|
||||
use super::*;
|
||||
use crate::config::RolloutConfig;
|
||||
use chrono::TimeZone;
|
||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use codex_protocol::protocol::AgentMessageEvent;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_protocol::protocol::TurnContextItem;
|
||||
use codex_protocol::protocol::UserMessageEvent;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
use tempfile::TempDir;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_config(codex_home: &Path) -> RolloutConfig {
|
||||
RolloutConfig {
|
||||
codex_home: codex_home.to_path_buf(),
|
||||
sqlite_home: codex_home.to_path_buf(),
|
||||
cwd: codex_home.to_path_buf(),
|
||||
model_provider_id: "test-provider".to_string(),
|
||||
generate_memories: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn write_session_file(root: &Path, ts: &str, uuid: Uuid) -> std::io::Result<PathBuf> {
|
||||
let day_dir = root.join("sessions/2025/01/03");
|
||||
fs::create_dir_all(&day_dir)?;
|
||||
let path = day_dir.join(format!("rollout-{ts}-{uuid}.jsonl"));
|
||||
let mut file = File::create(&path)?;
|
||||
let meta = serde_json::json!({
|
||||
"timestamp": ts,
|
||||
"type": "session_meta",
|
||||
"payload": {
|
||||
"id": uuid,
|
||||
"timestamp": ts,
|
||||
"cwd": ".",
|
||||
"originator": "test_originator",
|
||||
"cli_version": "test_version",
|
||||
"source": "cli",
|
||||
"model_provider": "test-provider",
|
||||
},
|
||||
});
|
||||
writeln!(file, "{meta}")?;
|
||||
let user_event = serde_json::json!({
|
||||
"timestamp": ts,
|
||||
"type": "event_msg",
|
||||
"payload": {
|
||||
"type": "user_message",
|
||||
"message": "Hello from user",
|
||||
"kind": "plain",
|
||||
},
|
||||
});
|
||||
writeln!(file, "{user_event}")?;
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recorder_materializes_only_after_explicit_persist() -> std::io::Result<()> {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let thread_id = ThreadId::new();
|
||||
let recorder = RolloutRecorder::new(
|
||||
&config,
|
||||
RolloutRecorderParams::new(
|
||||
thread_id,
|
||||
None,
|
||||
SessionSource::Exec,
|
||||
BaseInstructions::default(),
|
||||
Vec::new(),
|
||||
EventPersistenceMode::Limited,
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let rollout_path = recorder.rollout_path().to_path_buf();
|
||||
assert!(
|
||||
!rollout_path.exists(),
|
||||
"rollout file should not exist before first user message"
|
||||
);
|
||||
|
||||
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
AgentMessageEvent {
|
||||
message: "buffered-event".to_string(),
|
||||
phase: None,
|
||||
memory_citation: None,
|
||||
},
|
||||
))])
|
||||
.await?;
|
||||
recorder.flush().await?;
|
||||
assert!(
|
||||
!rollout_path.exists(),
|
||||
"rollout file should remain deferred before first user message"
|
||||
);
|
||||
|
||||
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(
|
||||
UserMessageEvent {
|
||||
message: "first-user-message".to_string(),
|
||||
images: None,
|
||||
local_images: Vec::new(),
|
||||
text_elements: Vec::new(),
|
||||
},
|
||||
))])
|
||||
.await?;
|
||||
recorder.flush().await?;
|
||||
assert!(
|
||||
!rollout_path.exists(),
|
||||
"user-message-like items should not materialize without explicit persist"
|
||||
);
|
||||
|
||||
recorder.persist().await?;
|
||||
// Second call verifies `persist()` is idempotent after materialization.
|
||||
recorder.persist().await?;
|
||||
assert!(rollout_path.exists(), "rollout file should be materialized");
|
||||
|
||||
let text = std::fs::read_to_string(&rollout_path)?;
|
||||
assert!(
|
||||
text.contains("\"type\":\"session_meta\""),
|
||||
"expected session metadata in rollout"
|
||||
);
|
||||
let buffered_idx = text
|
||||
.find("buffered-event")
|
||||
.expect("buffered event in rollout");
|
||||
let user_idx = text
|
||||
.find("first-user-message")
|
||||
.expect("first user message in rollout");
|
||||
assert!(
|
||||
buffered_idx < user_idx,
|
||||
"buffered items should preserve ordering"
|
||||
);
|
||||
let text_after_second_persist = std::fs::read_to_string(&rollout_path)?;
|
||||
assert_eq!(text_after_second_persist, text);
|
||||
|
||||
recorder.shutdown().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn metadata_irrelevant_events_touch_state_db_updated_at() -> std::io::Result<()> {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
|
||||
let state_db = StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone())
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
state_db
|
||||
.mark_backfill_complete(None)
|
||||
.await
|
||||
.expect("backfill should be complete");
|
||||
|
||||
let thread_id = ThreadId::new();
|
||||
let recorder = RolloutRecorder::new(
|
||||
&config,
|
||||
RolloutRecorderParams::new(
|
||||
thread_id,
|
||||
None,
|
||||
SessionSource::Cli,
|
||||
BaseInstructions::default(),
|
||||
Vec::new(),
|
||||
EventPersistenceMode::Limited,
|
||||
),
|
||||
Some(state_db.clone()),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(
|
||||
UserMessageEvent {
|
||||
message: "first-user-message".to_string(),
|
||||
images: None,
|
||||
local_images: Vec::new(),
|
||||
text_elements: Vec::new(),
|
||||
},
|
||||
))])
|
||||
.await?;
|
||||
recorder.persist().await?;
|
||||
recorder.flush().await?;
|
||||
let initial_thread = state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("thread should load")
|
||||
.expect("thread should exist");
|
||||
let initial_updated_at = initial_thread.updated_at;
|
||||
let initial_title = initial_thread.title.clone();
|
||||
let initial_first_user_message = initial_thread.first_user_message.clone();
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
|
||||
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
AgentMessageEvent {
|
||||
message: "assistant text".to_string(),
|
||||
phase: None,
|
||||
memory_citation: None,
|
||||
},
|
||||
))])
|
||||
.await?;
|
||||
recorder.flush().await?;
|
||||
|
||||
let updated_thread = state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("thread should load after agent message")
|
||||
.expect("thread should still exist");
|
||||
|
||||
assert!(updated_thread.updated_at > initial_updated_at);
|
||||
assert_eq!(updated_thread.title, initial_title);
|
||||
assert_eq!(
|
||||
updated_thread.first_user_message,
|
||||
initial_first_user_message
|
||||
);
|
||||
|
||||
recorder.shutdown().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn metadata_irrelevant_events_fall_back_to_upsert_when_thread_missing() -> std::io::Result<()>
|
||||
{
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
|
||||
let state_db = StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone())
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let thread_id = ThreadId::new();
|
||||
let rollout_path = home.path().join("rollout.jsonl");
|
||||
let builder = ThreadMetadataBuilder::new(
|
||||
thread_id,
|
||||
rollout_path.clone(),
|
||||
Utc::now(),
|
||||
SessionSource::Cli,
|
||||
);
|
||||
let items = vec![RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
AgentMessageEvent {
|
||||
message: "assistant text".to_string(),
|
||||
phase: None,
|
||||
memory_citation: None,
|
||||
},
|
||||
))];
|
||||
|
||||
sync_thread_state_after_write(
|
||||
Some(state_db.as_ref()),
|
||||
rollout_path.as_path(),
|
||||
Some(&builder),
|
||||
items.as_slice(),
|
||||
config.model_provider_id.as_str(),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
let thread = state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("thread should load after fallback")
|
||||
.expect("thread should be inserted after fallback");
|
||||
assert_eq!(thread.id, thread_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_threads_db_disabled_does_not_skip_paginated_items() -> std::io::Result<()> {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
|
||||
let newest = write_session_file(home.path(), "2025-01-03T12-00-00", Uuid::from_u128(9001))?;
|
||||
let middle = write_session_file(home.path(), "2025-01-02T12-00-00", Uuid::from_u128(9002))?;
|
||||
let _oldest = write_session_file(home.path(), "2025-01-01T12-00-00", Uuid::from_u128(9003))?;
|
||||
|
||||
let default_provider = config.model_provider_id.clone();
|
||||
let page1 = RolloutRecorder::list_threads(
|
||||
&config,
|
||||
1,
|
||||
None,
|
||||
ThreadSortKey::CreatedAt,
|
||||
&[],
|
||||
None,
|
||||
default_provider.as_str(),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
assert_eq!(page1.items.len(), 1);
|
||||
assert_eq!(page1.items[0].path, newest);
|
||||
let cursor = page1.next_cursor.clone().expect("cursor should be present");
|
||||
|
||||
let page2 = RolloutRecorder::list_threads(
|
||||
&config,
|
||||
1,
|
||||
Some(&cursor),
|
||||
ThreadSortKey::CreatedAt,
|
||||
&[],
|
||||
None,
|
||||
default_provider.as_str(),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
assert_eq!(page2.items.len(), 1);
|
||||
assert_eq!(page2.items[0].path, middle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_threads_db_enabled_drops_missing_rollout_paths() -> std::io::Result<()> {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
|
||||
let uuid = Uuid::from_u128(9010);
|
||||
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
|
||||
let stale_path = home.path().join(format!(
|
||||
"sessions/2099/01/01/rollout-2099-01-01T00-00-00-{uuid}.jsonl"
|
||||
));
|
||||
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
home.path().to_path_buf(),
|
||||
config.model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
runtime
|
||||
.mark_backfill_complete(None)
|
||||
.await
|
||||
.expect("backfill should be complete");
|
||||
let created_at = chrono::Utc
|
||||
.with_ymd_and_hms(2025, 1, 3, 13, 0, 0)
|
||||
.single()
|
||||
.expect("valid datetime");
|
||||
let mut builder = codex_state::ThreadMetadataBuilder::new(
|
||||
thread_id,
|
||||
stale_path,
|
||||
created_at,
|
||||
SessionSource::Cli,
|
||||
);
|
||||
builder.model_provider = Some(config.model_provider_id.clone());
|
||||
builder.cwd = home.path().to_path_buf();
|
||||
let mut metadata = builder.build(config.model_provider_id.as_str());
|
||||
metadata.first_user_message = Some("Hello from user".to_string());
|
||||
runtime
|
||||
.upsert_thread(&metadata)
|
||||
.await
|
||||
.expect("state db upsert should succeed");
|
||||
|
||||
let default_provider = config.model_provider_id.clone();
|
||||
let page = RolloutRecorder::list_threads(
|
||||
&config,
|
||||
10,
|
||||
None,
|
||||
ThreadSortKey::CreatedAt,
|
||||
&[],
|
||||
None,
|
||||
default_provider.as_str(),
|
        None,
    )
    .await?;
    assert_eq!(page.items.len(), 0);
    let stored_path = runtime
        .find_rollout_path_by_id(thread_id, Some(false))
        .await
        .expect("state db lookup should succeed");
    assert_eq!(stored_path, None);
    Ok(())
}

#[tokio::test]
async fn list_threads_db_enabled_repairs_stale_rollout_paths() -> std::io::Result<()> {
    let home = TempDir::new().expect("temp dir");
    let config = test_config(home.path());

    let uuid = Uuid::from_u128(9011);
    let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
    let real_path = write_session_file(home.path(), "2025-01-03T13-00-00", uuid)?;
    let stale_path = home.path().join(format!(
        "sessions/2099/01/01/rollout-2099-01-01T00-00-00-{uuid}.jsonl"
    ));

    let runtime = codex_state::StateRuntime::init(
        home.path().to_path_buf(),
        config.model_provider_id.clone(),
    )
    .await
    .expect("state db should initialize");
    runtime
        .mark_backfill_complete(None)
        .await
        .expect("backfill should be complete");
    let created_at = chrono::Utc
        .with_ymd_and_hms(2025, 1, 3, 13, 0, 0)
        .single()
        .expect("valid datetime");
    let mut builder = codex_state::ThreadMetadataBuilder::new(
        thread_id,
        stale_path,
        created_at,
        SessionSource::Cli,
    );
    builder.model_provider = Some(config.model_provider_id.clone());
    builder.cwd = home.path().to_path_buf();
    let mut metadata = builder.build(config.model_provider_id.as_str());
    metadata.first_user_message = Some("Hello from user".to_string());
    runtime
        .upsert_thread(&metadata)
        .await
        .expect("state db upsert should succeed");

    let default_provider = config.model_provider_id.clone();
    let page = RolloutRecorder::list_threads(
        &config,
        1,
        None,
        ThreadSortKey::CreatedAt,
        &[],
        None,
        default_provider.as_str(),
        None,
    )
    .await?;
    assert_eq!(page.items.len(), 1);
    assert_eq!(page.items[0].path, real_path);

    let repaired_path = runtime
        .find_rollout_path_by_id(thread_id, Some(false))
        .await
        .expect("state db lookup should succeed");
    assert_eq!(repaired_path, Some(real_path));
    Ok(())
}

#[tokio::test]
async fn resume_candidate_matches_cwd_reads_latest_turn_context() -> std::io::Result<()> {
    let home = TempDir::new().expect("temp dir");
    let stale_cwd = home.path().join("stale");
    let latest_cwd = home.path().join("latest");
    fs::create_dir_all(&stale_cwd)?;
    fs::create_dir_all(&latest_cwd)?;

    let path = write_session_file(home.path(), "2025-01-03T13-00-00", Uuid::from_u128(9012))?;
    let mut file = std::fs::OpenOptions::new().append(true).open(&path)?;
    let turn_context = RolloutLine {
        timestamp: "2025-01-03T13:00:01Z".to_string(),
        item: RolloutItem::TurnContext(TurnContextItem {
            turn_id: Some("turn-1".to_string()),
            trace_id: None,
            cwd: latest_cwd.clone(),
            current_date: None,
            timezone: None,
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::new_read_only_policy(),
            network: None,
            model: "test-model".to_string(),
            personality: None,
            collaboration_mode: None,
            realtime_active: None,
            effort: None,
            summary: ReasoningSummaryConfig::Auto,
            user_instructions: None,
            developer_instructions: None,
            final_output_json_schema: None,
            truncation_policy: None,
        }),
    };
    writeln!(file, "{}", serde_json::to_string(&turn_context)?)?;

    assert!(
        resume_candidate_matches_cwd(
            path.as_path(),
            Some(stale_cwd.as_path()),
            latest_cwd.as_path(),
            "test-provider",
        )
        .await
    );
    Ok(())
}
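For orientation, the call shape these tests exercise can be summarized in a small sketch. This is a hedged illustration only (not part of the diff): it reuses the `RolloutRecorder::list_threads` argument order visible in the tests above, while `config` and `default_provider` stand in for fixtures that do not appear in this excerpt.

// Hedged sketch: list the first page of recorded threads with the same
// argument shape the tests above use. `config` and `default_provider` are
// assumed fixtures.
let page = RolloutRecorder::list_threads(
    &config,                   // rollout configuration (assumed fixture)
    10,                        // page size
    None,                      // no cursor: start from the newest entry
    ThreadSortKey::CreatedAt,  // sort by creation time
    &[],                       // no SessionSource filter
    None,                      // no model-provider filter
    default_provider.as_str(), // default provider id (assumed fixture)
    None,                      // no search term
)
.await?;
assert!(page.items.iter().all(|item| item.path.exists()));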
233
codex-rs/rollout/src/session_index.rs
Normal file
233
codex-rs/rollout/src/session_index.rs
Normal file
@@ -0,0 +1,233 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs::File;
use std::io::Read;
use std::io::Seek;
use std::io::SeekFrom;
use std::path::Path;
use std::path::PathBuf;

use codex_protocol::ThreadId;
use serde::Deserialize;
use serde::Serialize;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;

const SESSION_INDEX_FILE: &str = "session_index.jsonl";
const READ_CHUNK_SIZE: usize = 8192;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SessionIndexEntry {
    pub id: ThreadId,
    pub thread_name: String,
    pub updated_at: String,
}

/// Append a thread name update to the session index.
/// The index is append-only; the most recent entry wins when resolving names or ids.
pub async fn append_thread_name(
    codex_home: &Path,
    thread_id: ThreadId,
    name: &str,
) -> std::io::Result<()> {
    use time::OffsetDateTime;
    use time::format_description::well_known::Rfc3339;

    let updated_at = OffsetDateTime::now_utc()
        .format(&Rfc3339)
        .unwrap_or_else(|_| "unknown".to_string());
    let entry = SessionIndexEntry {
        id: thread_id,
        thread_name: name.to_string(),
        updated_at,
    };
    append_session_index_entry(codex_home, &entry).await
}

/// Append a raw session index entry to `session_index.jsonl`.
/// The file is append-only; consumers scan from the end to find the newest match.
pub async fn append_session_index_entry(
    codex_home: &Path,
    entry: &SessionIndexEntry,
) -> std::io::Result<()> {
    let path = session_index_path(codex_home);
    let mut file = tokio::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(&path)
        .await?;
    let mut line = serde_json::to_string(entry).map_err(std::io::Error::other)?;
    line.push('\n');
    file.write_all(line.as_bytes()).await?;
    file.flush().await?;
    Ok(())
}

/// Find the latest thread name for a thread id, if any.
pub async fn find_thread_name_by_id(
    codex_home: &Path,
    thread_id: &ThreadId,
) -> std::io::Result<Option<String>> {
    let path = session_index_path(codex_home);
    if !path.exists() {
        return Ok(None);
    }
    let id = *thread_id;
    let entry = tokio::task::spawn_blocking(move || scan_index_from_end_by_id(&path, &id))
        .await
        .map_err(std::io::Error::other)??;
    Ok(entry.map(|entry| entry.thread_name))
}

/// Find the latest thread names for a batch of thread ids.
pub async fn find_thread_names_by_ids(
    codex_home: &Path,
    thread_ids: &HashSet<ThreadId>,
) -> std::io::Result<HashMap<ThreadId, String>> {
    let path = session_index_path(codex_home);
    if thread_ids.is_empty() || !path.exists() {
        return Ok(HashMap::new());
    }

    let file = tokio::fs::File::open(&path).await?;
    let reader = tokio::io::BufReader::new(file);
    let mut lines = reader.lines();
    let mut names = HashMap::with_capacity(thread_ids.len());

    while let Some(line) = lines.next_line().await? {
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }
        let Ok(entry) = serde_json::from_str::<SessionIndexEntry>(trimmed) else {
            continue;
        };
        let name = entry.thread_name.trim();
        if !name.is_empty() && thread_ids.contains(&entry.id) {
            names.insert(entry.id, name.to_string());
        }
    }

    Ok(names)
}

/// Find the most recently updated thread id for a thread name, if any.
pub async fn find_thread_id_by_name(
    codex_home: &Path,
    name: &str,
) -> std::io::Result<Option<ThreadId>> {
    if name.trim().is_empty() {
        return Ok(None);
    }
    let path = session_index_path(codex_home);
    if !path.exists() {
        return Ok(None);
    }
    let name = name.to_string();
    let entry = tokio::task::spawn_blocking(move || scan_index_from_end_by_name(&path, &name))
        .await
        .map_err(std::io::Error::other)??;
    Ok(entry.map(|entry| entry.id))
}

/// Locate a recorded thread rollout file by thread name using newest-first ordering.
/// Returns `Ok(Some(path))` if found, `Ok(None)` if not present.
pub async fn find_thread_path_by_name_str(
    codex_home: &Path,
    name: &str,
) -> std::io::Result<Option<PathBuf>> {
    let Some(thread_id) = find_thread_id_by_name(codex_home, name).await? else {
        return Ok(None);
    };
    super::list::find_thread_path_by_id_str(codex_home, &thread_id.to_string()).await
}

fn session_index_path(codex_home: &Path) -> PathBuf {
    codex_home.join(SESSION_INDEX_FILE)
}

fn scan_index_from_end_by_id(
    path: &Path,
    thread_id: &ThreadId,
) -> std::io::Result<Option<SessionIndexEntry>> {
    scan_index_from_end(path, |entry| entry.id == *thread_id)
}

fn scan_index_from_end_by_name(
    path: &Path,
    name: &str,
) -> std::io::Result<Option<SessionIndexEntry>> {
    scan_index_from_end(path, |entry| entry.thread_name == name)
}

fn scan_index_from_end<F>(
    path: &Path,
    mut predicate: F,
) -> std::io::Result<Option<SessionIndexEntry>>
where
    F: FnMut(&SessionIndexEntry) -> bool,
{
    let mut file = File::open(path)?;
    let mut remaining = file.metadata()?.len();
    let mut line_rev: Vec<u8> = Vec::new();
    let mut buf = vec![0u8; READ_CHUNK_SIZE];

    while remaining > 0 {
        let read_size = usize::try_from(remaining.min(READ_CHUNK_SIZE as u64))
            .map_err(std::io::Error::other)?;
        remaining -= read_size as u64;
        file.seek(SeekFrom::Start(remaining))?;
        file.read_exact(&mut buf[..read_size])?;

        for &byte in buf[..read_size].iter().rev() {
            if byte == b'\n' {
                if let Some(entry) = parse_line_from_rev(&mut line_rev, &mut predicate)? {
                    return Ok(Some(entry));
                }
                continue;
            }
            line_rev.push(byte);
        }
    }

    if let Some(entry) = parse_line_from_rev(&mut line_rev, &mut predicate)? {
        return Ok(Some(entry));
    }

    Ok(None)
}

fn parse_line_from_rev<F>(
    line_rev: &mut Vec<u8>,
    predicate: &mut F,
) -> std::io::Result<Option<SessionIndexEntry>>
where
    F: FnMut(&SessionIndexEntry) -> bool,
{
    if line_rev.is_empty() {
        return Ok(None);
    }
    line_rev.reverse();
    let line = std::mem::take(line_rev);
    let Ok(mut line) = String::from_utf8(line) else {
        return Ok(None);
    };
    if line.ends_with('\r') {
        line.pop();
    }
    let trimmed = line.trim();
    if trimmed.is_empty() {
        return Ok(None);
    }
    let Ok(entry) = serde_json::from_str::<SessionIndexEntry>(trimmed) else {
        return Ok(None);
    };
    if predicate(&entry) {
        return Ok(Some(entry));
    }
    Ok(None)
}

#[cfg(test)]
#[path = "session_index_tests.rs"]
mod tests;
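A minimal usage sketch for the session index above, assuming the functions are imported from this module; the thread name is illustrative and not part of the diff.

// Hedged sketch: append a name for a thread, then resolve it back by name and by id.
async fn demo(codex_home: &std::path::Path) -> std::io::Result<()> {
    let thread_id = codex_protocol::ThreadId::new();
    append_thread_name(codex_home, thread_id, "extract-rollout-crate").await?;

    let by_name = find_thread_id_by_name(codex_home, "extract-rollout-crate").await?;
    assert_eq!(by_name, Some(thread_id));

    let by_id = find_thread_name_by_id(codex_home, &thread_id).await?;
    assert_eq!(by_id.as_deref(), Some("extract-rollout-crate"));
    Ok(())
}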
169
codex-rs/rollout/src/session_index_tests.rs
Normal file
169
codex-rs/rollout/src/session_index_tests.rs
Normal file
@@ -0,0 +1,169 @@
#![allow(warnings, clippy::all)]

use super::*;
use pretty_assertions::assert_eq;
use std::collections::HashMap;
use std::collections::HashSet;
use tempfile::TempDir;

fn write_index(path: &Path, lines: &[SessionIndexEntry]) -> std::io::Result<()> {
    let mut out = String::new();
    for entry in lines {
        out.push_str(&serde_json::to_string(entry).unwrap());
        out.push('\n');
    }
    std::fs::write(path, out)
}

#[test]
fn find_thread_id_by_name_prefers_latest_entry() -> std::io::Result<()> {
    let temp = TempDir::new()?;
    let path = session_index_path(temp.path());
    let id1 = ThreadId::new();
    let id2 = ThreadId::new();
    let lines = vec![
        SessionIndexEntry {
            id: id1,
            thread_name: "same".to_string(),
            updated_at: "2024-01-01T00:00:00Z".to_string(),
        },
        SessionIndexEntry {
            id: id2,
            thread_name: "same".to_string(),
            updated_at: "2024-01-02T00:00:00Z".to_string(),
        },
    ];
    write_index(&path, &lines)?;

    let found = scan_index_from_end_by_name(&path, "same")?;
    assert_eq!(found.map(|entry| entry.id), Some(id2));
    Ok(())
}

#[test]
fn find_thread_name_by_id_prefers_latest_entry() -> std::io::Result<()> {
    let temp = TempDir::new()?;
    let path = session_index_path(temp.path());
    let id = ThreadId::new();
    let lines = vec![
        SessionIndexEntry {
            id,
            thread_name: "first".to_string(),
            updated_at: "2024-01-01T00:00:00Z".to_string(),
        },
        SessionIndexEntry {
            id,
            thread_name: "second".to_string(),
            updated_at: "2024-01-02T00:00:00Z".to_string(),
        },
    ];
    write_index(&path, &lines)?;

    let found = scan_index_from_end_by_id(&path, &id)?;
    assert_eq!(
        found.map(|entry| entry.thread_name),
        Some("second".to_string())
    );
    Ok(())
}

#[test]
fn scan_index_returns_none_when_entry_missing() -> std::io::Result<()> {
    let temp = TempDir::new()?;
    let path = session_index_path(temp.path());
    let id = ThreadId::new();
    let lines = vec![SessionIndexEntry {
        id,
        thread_name: "present".to_string(),
        updated_at: "2024-01-01T00:00:00Z".to_string(),
    }];
    write_index(&path, &lines)?;

    let missing_name = scan_index_from_end_by_name(&path, "missing")?;
    assert_eq!(missing_name, None);

    let missing_id = scan_index_from_end_by_id(&path, &ThreadId::new())?;
    assert_eq!(missing_id, None);
    Ok(())
}

#[tokio::test]
async fn find_thread_names_by_ids_prefers_latest_entry() -> std::io::Result<()> {
    let temp = TempDir::new()?;
    let path = session_index_path(temp.path());
    let id1 = ThreadId::new();
    let id2 = ThreadId::new();
    let lines = vec![
        SessionIndexEntry {
            id: id1,
            thread_name: "first".to_string(),
            updated_at: "2024-01-01T00:00:00Z".to_string(),
        },
        SessionIndexEntry {
            id: id2,
            thread_name: "other".to_string(),
            updated_at: "2024-01-01T00:00:00Z".to_string(),
        },
        SessionIndexEntry {
            id: id1,
            thread_name: "latest".to_string(),
            updated_at: "2024-01-02T00:00:00Z".to_string(),
        },
    ];
    write_index(&path, &lines)?;

    let mut ids = HashSet::new();
    ids.insert(id1);
    ids.insert(id2);

    let mut expected = HashMap::new();
    expected.insert(id1, "latest".to_string());
    expected.insert(id2, "other".to_string());

    let found = find_thread_names_by_ids(temp.path(), &ids).await?;
    assert_eq!(found, expected);
    Ok(())
}

#[test]
fn scan_index_finds_latest_match_among_mixed_entries() -> std::io::Result<()> {
    let temp = TempDir::new()?;
    let path = session_index_path(temp.path());
    let id_target = ThreadId::new();
    let id_other = ThreadId::new();
    let expected = SessionIndexEntry {
        id: id_target,
        thread_name: "target".to_string(),
        updated_at: "2024-01-03T00:00:00Z".to_string(),
    };
    let expected_other = SessionIndexEntry {
        id: id_other,
        thread_name: "target".to_string(),
        updated_at: "2024-01-02T00:00:00Z".to_string(),
    };
    // Resolution is based on append order (scan from end), not updated_at.
    let lines = vec![
        SessionIndexEntry {
            id: id_target,
            thread_name: "target".to_string(),
            updated_at: "2024-01-01T00:00:00Z".to_string(),
        },
        expected_other.clone(),
        expected.clone(),
        SessionIndexEntry {
            id: ThreadId::new(),
            thread_name: "another".to_string(),
            updated_at: "2024-01-04T00:00:00Z".to_string(),
        },
    ];
    write_index(&path, &lines)?;

    let found_by_name = scan_index_from_end_by_name(&path, "target")?;
    assert_eq!(found_by_name, Some(expected.clone()));

    let found_by_id = scan_index_from_end_by_id(&path, &id_target)?;
    assert_eq!(found_by_id, Some(expected));

    let found_other_by_id = scan_index_from_end_by_id(&path, &id_other)?;
    assert_eq!(found_other_by_id, Some(expected_other));
    Ok(())
}
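For reference, each line of `session_index.jsonl` is one serialized `SessionIndexEntry`; a hedged sketch of the shape the tests above round-trip through `write_index` (the id value shown in the comment is schematic):

// Hedged sketch: one JSONL line as written by the tests above.
let entry = SessionIndexEntry {
    id: ThreadId::new(),
    thread_name: "demo".to_string(),
    updated_at: "2024-01-01T00:00:00Z".to_string(),
};
// Serializes to something like:
// {"id":"<thread id>","thread_name":"demo","updated_at":"2024-01-01T00:00:00Z"}
let line = serde_json::to_string(&entry).expect("serialize");
assert!(line.contains("\"thread_name\":\"demo\""));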
548
codex-rs/rollout/src/state_db.rs
Normal file
548
codex-rs/rollout/src/state_db.rs
Normal file
@@ -0,0 +1,548 @@
use crate::config::RolloutConfig;
use crate::config::RolloutConfigView;
use crate::list::Cursor;
use crate::list::ThreadSortKey;
use crate::metadata;
use chrono::DateTime;
use chrono::NaiveDateTime;
use chrono::Timelike;
use chrono::Utc;
use codex_protocol::ThreadId;
use codex_protocol::dynamic_tools::DynamicToolSpec;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SessionSource;
pub use codex_state::LogEntry;
use codex_state::ThreadMetadataBuilder;
use codex_utils_path::normalize_for_path_comparison;
use serde_json::Value;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use tracing::warn;
use uuid::Uuid;

/// Core-facing handle to the SQLite-backed state runtime.
pub type StateDbHandle = Arc<codex_state::StateRuntime>;

/// Initialize the state runtime for thread state persistence and backfill checks.
pub async fn init(config: &impl RolloutConfigView) -> Option<StateDbHandle> {
    let config = RolloutConfig::from_view(config);
    let runtime = match codex_state::StateRuntime::init(
        config.sqlite_home.clone(),
        config.model_provider_id.clone(),
    )
    .await
    {
        Ok(runtime) => runtime,
        Err(err) => {
            warn!(
                "failed to initialize state runtime at {}: {err}",
                config.sqlite_home.display()
            );
            return None;
        }
    };
    let backfill_state = match runtime.get_backfill_state().await {
        Ok(state) => state,
        Err(err) => {
            warn!(
                "failed to read backfill state at {}: {err}",
                config.codex_home.display()
            );
            return None;
        }
    };
    if backfill_state.status != codex_state::BackfillStatus::Complete {
        let runtime_for_backfill = runtime.clone();
        let config = config.clone();
        tokio::spawn(async move {
            metadata::backfill_sessions(runtime_for_backfill.as_ref(), &config).await;
        });
    }
    Some(runtime)
}

/// Get the DB if the feature is enabled and the DB exists.
pub async fn get_state_db(config: &impl RolloutConfigView) -> Option<StateDbHandle> {
    let state_path = codex_state::state_db_path(config.sqlite_home());
    if !tokio::fs::try_exists(&state_path).await.unwrap_or(false) {
        return None;
    }
    let runtime = codex_state::StateRuntime::init(
        config.sqlite_home().to_path_buf(),
        config.model_provider_id().to_string(),
    )
    .await
    .ok()?;
    require_backfill_complete(runtime, config.sqlite_home()).await
}

/// Open the state runtime when the SQLite file exists, without feature gating.
///
/// This is used for parity checks during the SQLite migration phase.
pub async fn open_if_present(codex_home: &Path, default_provider: &str) -> Option<StateDbHandle> {
    let db_path = codex_state::state_db_path(codex_home);
    if !tokio::fs::try_exists(&db_path).await.unwrap_or(false) {
        return None;
    }
    let runtime =
        codex_state::StateRuntime::init(codex_home.to_path_buf(), default_provider.to_string())
            .await
            .ok()?;
    require_backfill_complete(runtime, codex_home).await
}

async fn require_backfill_complete(
    runtime: StateDbHandle,
    codex_home: &Path,
) -> Option<StateDbHandle> {
    match runtime.get_backfill_state().await {
        Ok(state) if state.status == codex_state::BackfillStatus::Complete => Some(runtime),
        Ok(state) => {
            warn!(
                "state db backfill not complete at {} (status: {})",
                codex_home.display(),
                state.status.as_str()
            );
            None
        }
        Err(err) => {
            warn!(
                "failed to read backfill state at {}: {err}",
                codex_home.display()
            );
            None
        }
    }
}

fn cursor_to_anchor(cursor: Option<&Cursor>) -> Option<codex_state::Anchor> {
    let cursor = cursor?;
    let value = serde_json::to_value(cursor).ok()?;
    let cursor_str = value.as_str()?;
    let (ts_str, id_str) = cursor_str.split_once('|')?;
    if id_str.contains('|') {
        return None;
    }
    let id = Uuid::parse_str(id_str).ok()?;
    let ts = if let Ok(naive) = NaiveDateTime::parse_from_str(ts_str, "%Y-%m-%dT%H-%M-%S") {
        DateTime::<Utc>::from_naive_utc_and_offset(naive, Utc)
    } else if let Ok(dt) = DateTime::parse_from_rfc3339(ts_str) {
        dt.with_timezone(&Utc)
    } else {
        return None;
    }
    .with_nanosecond(0)?;
    Some(codex_state::Anchor { ts, id })
}

pub fn normalize_cwd_for_state_db(cwd: &Path) -> PathBuf {
    normalize_for_path_comparison(cwd).unwrap_or_else(|_| cwd.to_path_buf())
}

/// List thread ids from SQLite for parity checks without rollout scanning.
#[allow(clippy::too_many_arguments)]
pub async fn list_thread_ids_db(
    context: Option<&codex_state::StateRuntime>,
    codex_home: &Path,
    page_size: usize,
    cursor: Option<&Cursor>,
    sort_key: ThreadSortKey,
    allowed_sources: &[SessionSource],
    model_providers: Option<&[String]>,
    archived_only: bool,
    stage: &str,
) -> Option<Vec<ThreadId>> {
    let ctx = context?;
    if ctx.codex_home() != codex_home {
        warn!(
            "state db codex_home mismatch: expected {}, got {}",
            ctx.codex_home().display(),
            codex_home.display()
        );
    }

    let anchor = cursor_to_anchor(cursor);
    let allowed_sources: Vec<String> = allowed_sources
        .iter()
        .map(|value| match serde_json::to_value(value) {
            Ok(Value::String(s)) => s,
            Ok(other) => other.to_string(),
            Err(_) => String::new(),
        })
        .collect();
    let model_providers = model_providers.map(<[String]>::to_vec);
    match ctx
        .list_thread_ids(
            page_size,
            anchor.as_ref(),
            match sort_key {
                ThreadSortKey::CreatedAt => codex_state::SortKey::CreatedAt,
                ThreadSortKey::UpdatedAt => codex_state::SortKey::UpdatedAt,
            },
            allowed_sources.as_slice(),
            model_providers.as_deref(),
            archived_only,
        )
        .await
    {
        Ok(ids) => Some(ids),
        Err(err) => {
            warn!("state db list_thread_ids failed during {stage}: {err}");
            None
        }
    }
}

/// List thread metadata from SQLite without rollout directory traversal.
#[allow(clippy::too_many_arguments)]
pub async fn list_threads_db(
    context: Option<&codex_state::StateRuntime>,
    codex_home: &Path,
    page_size: usize,
    cursor: Option<&Cursor>,
    sort_key: ThreadSortKey,
    allowed_sources: &[SessionSource],
    model_providers: Option<&[String]>,
    archived: bool,
    search_term: Option<&str>,
) -> Option<codex_state::ThreadsPage> {
    let ctx = context?;
    if ctx.codex_home() != codex_home {
        warn!(
            "state db codex_home mismatch: expected {}, got {}",
            ctx.codex_home().display(),
            codex_home.display()
        );
    }

    let anchor = cursor_to_anchor(cursor);
    let allowed_sources: Vec<String> = allowed_sources
        .iter()
        .map(|value| match serde_json::to_value(value) {
            Ok(Value::String(s)) => s,
            Ok(other) => other.to_string(),
            Err(_) => String::new(),
        })
        .collect();
    let model_providers = model_providers.map(<[String]>::to_vec);
    match ctx
        .list_threads(
            page_size,
            anchor.as_ref(),
            match sort_key {
                ThreadSortKey::CreatedAt => codex_state::SortKey::CreatedAt,
                ThreadSortKey::UpdatedAt => codex_state::SortKey::UpdatedAt,
            },
            allowed_sources.as_slice(),
            model_providers.as_deref(),
            archived,
            search_term,
        )
        .await
    {
        Ok(mut page) => {
            let mut valid_items = Vec::with_capacity(page.items.len());
            for item in page.items {
                if tokio::fs::try_exists(&item.rollout_path)
                    .await
                    .unwrap_or(false)
                {
                    valid_items.push(item);
                } else {
                    warn!(
                        "state db list_threads returned stale rollout path for thread {}: {}",
                        item.id,
                        item.rollout_path.display()
                    );
                    warn!("state db discrepancy during list_threads_db: stale_db_path_dropped");
                    let _ = ctx.delete_thread(item.id).await;
                }
            }
            page.items = valid_items;
            Some(page)
        }
        Err(err) => {
            warn!("state db list_threads failed: {err}");
            None
        }
    }
}

/// Look up the rollout path for a thread id using SQLite.
pub async fn find_rollout_path_by_id(
    context: Option<&codex_state::StateRuntime>,
    thread_id: ThreadId,
    archived_only: Option<bool>,
    stage: &str,
) -> Option<PathBuf> {
    let ctx = context?;
    ctx.find_rollout_path_by_id(thread_id, archived_only)
        .await
        .unwrap_or_else(|err| {
            warn!("state db find_rollout_path_by_id failed during {stage}: {err}");
            None
        })
}

/// Get dynamic tools for a thread id using SQLite.
pub async fn get_dynamic_tools(
    context: Option<&codex_state::StateRuntime>,
    thread_id: ThreadId,
    stage: &str,
) -> Option<Vec<DynamicToolSpec>> {
    let ctx = context?;
    match ctx.get_dynamic_tools(thread_id).await {
        Ok(tools) => tools,
        Err(err) => {
            warn!("state db get_dynamic_tools failed during {stage}: {err}");
            None
        }
    }
}

/// Persist dynamic tools for a thread id using SQLite, if none exist yet.
pub async fn persist_dynamic_tools(
    context: Option<&codex_state::StateRuntime>,
    thread_id: ThreadId,
    tools: Option<&[DynamicToolSpec]>,
    stage: &str,
) {
    let Some(ctx) = context else {
        return;
    };
    if let Err(err) = ctx.persist_dynamic_tools(thread_id, tools).await {
        warn!("state db persist_dynamic_tools failed during {stage}: {err}");
    }
}

pub async fn mark_thread_memory_mode_polluted(
    context: Option<&codex_state::StateRuntime>,
    thread_id: ThreadId,
    stage: &str,
) {
    let Some(ctx) = context else {
        return;
    };
    if let Err(err) = ctx.mark_thread_memory_mode_polluted(thread_id).await {
        warn!("state db mark_thread_memory_mode_polluted failed during {stage}: {err}");
    }
}

/// Reconcile rollout items into SQLite, falling back to scanning the rollout file.
pub async fn reconcile_rollout(
    context: Option<&codex_state::StateRuntime>,
    rollout_path: &Path,
    default_provider: &str,
    builder: Option<&ThreadMetadataBuilder>,
    items: &[RolloutItem],
    archived_only: Option<bool>,
    new_thread_memory_mode: Option<&str>,
) {
    let Some(ctx) = context else {
        return;
    };
    if builder.is_some() || !items.is_empty() {
        apply_rollout_items(
            Some(ctx),
            rollout_path,
            default_provider,
            builder,
            items,
            "reconcile_rollout",
            new_thread_memory_mode,
            /*updated_at_override*/ None,
        )
        .await;
        return;
    }
    let outcome =
        match metadata::extract_metadata_from_rollout(rollout_path, default_provider).await {
            Ok(outcome) => outcome,
            Err(err) => {
                warn!(
                    "state db reconcile_rollout extraction failed {}: {err}",
                    rollout_path.display()
                );
                return;
            }
        };
    let mut metadata = outcome.metadata;
    let memory_mode = outcome.memory_mode.unwrap_or_else(|| "enabled".to_string());
    metadata.cwd = normalize_cwd_for_state_db(&metadata.cwd);
    if let Ok(Some(existing_metadata)) = ctx.get_thread(metadata.id).await {
        metadata.prefer_existing_git_info(&existing_metadata);
    }
    match archived_only {
        Some(true) if metadata.archived_at.is_none() => {
            metadata.archived_at = Some(metadata.updated_at);
        }
        Some(false) => {
            metadata.archived_at = None;
        }
        Some(true) | None => {}
    }
    if let Err(err) = ctx.upsert_thread(&metadata).await {
        warn!(
            "state db reconcile_rollout upsert failed {}: {err}",
            rollout_path.display()
        );
        return;
    }
    if let Err(err) = ctx
        .set_thread_memory_mode(metadata.id, memory_mode.as_str())
        .await
    {
        warn!(
            "state db reconcile_rollout memory_mode update failed {}: {err}",
            rollout_path.display()
        );
        return;
    }
    if let Ok(meta_line) = crate::list::read_session_meta_line(rollout_path).await {
        persist_dynamic_tools(
            Some(ctx),
            meta_line.meta.id,
            meta_line.meta.dynamic_tools.as_deref(),
            "reconcile_rollout",
        )
        .await;
    } else {
        warn!(
            "state db reconcile_rollout missing session meta {}",
            rollout_path.display()
        );
    }
}

/// Repair a thread's rollout path after filesystem fallback succeeds.
pub async fn read_repair_rollout_path(
    context: Option<&codex_state::StateRuntime>,
    thread_id: Option<ThreadId>,
    archived_only: Option<bool>,
    rollout_path: &Path,
) {
    let Some(ctx) = context else {
        return;
    };

    // Fast path: update an existing metadata row in place, but avoid writes when
    // read-repair computes no effective change.
    let mut saw_existing_metadata = false;
    if let Some(thread_id) = thread_id
        && let Ok(Some(metadata)) = ctx.get_thread(thread_id).await
    {
        saw_existing_metadata = true;
        let mut repaired = metadata.clone();
        repaired.rollout_path = rollout_path.to_path_buf();
        repaired.cwd = normalize_cwd_for_state_db(&repaired.cwd);
        match archived_only {
            Some(true) if repaired.archived_at.is_none() => {
                repaired.archived_at = Some(repaired.updated_at);
            }
            Some(false) => {
                repaired.archived_at = None;
            }
            Some(true) | None => {}
        }
        if repaired == metadata {
            return;
        }
        warn!("state db discrepancy during read_repair_rollout_path: upsert_needed (fast path)");
        if let Err(err) = ctx.upsert_thread(&repaired).await {
            warn!(
                "state db read-repair upsert failed for {}: {err}",
                rollout_path.display()
            );
        } else {
            return;
        }
    }

    // Slow path: when the row is missing/unreadable (or direct upsert failed),
    // rebuild metadata from rollout contents and reconcile it into SQLite.
    if !saw_existing_metadata {
        warn!("state db discrepancy during read_repair_rollout_path: upsert_needed (slow path)");
    }
    let default_provider = crate::list::read_session_meta_line(rollout_path)
        .await
        .ok()
        .and_then(|meta| meta.meta.model_provider)
        .unwrap_or_default();
    reconcile_rollout(
        Some(ctx),
        rollout_path,
        default_provider.as_str(),
        /*builder*/ None,
        &[],
        archived_only,
        /*new_thread_memory_mode*/ None,
    )
    .await;
}

/// Apply rollout items incrementally to SQLite.
#[allow(clippy::too_many_arguments)]
pub async fn apply_rollout_items(
    context: Option<&codex_state::StateRuntime>,
    rollout_path: &Path,
    _default_provider: &str,
    builder: Option<&ThreadMetadataBuilder>,
    items: &[RolloutItem],
    stage: &str,
    new_thread_memory_mode: Option<&str>,
    updated_at_override: Option<DateTime<Utc>>,
) {
    let Some(ctx) = context else {
        return;
    };
    let mut builder = match builder {
        Some(builder) => builder.clone(),
        None => match metadata::builder_from_items(items, rollout_path) {
            Some(builder) => builder,
            None => {
                warn!(
                    "state db apply_rollout_items missing builder during {stage}: {}",
                    rollout_path.display()
                );
                warn!("state db discrepancy during apply_rollout_items: {stage}, missing_builder");
                return;
            }
        },
    };
    builder.rollout_path = rollout_path.to_path_buf();
    builder.cwd = normalize_cwd_for_state_db(&builder.cwd);
    if let Err(err) = ctx
        .apply_rollout_items(&builder, items, new_thread_memory_mode, updated_at_override)
        .await
    {
        warn!(
            "state db apply_rollout_items failed during {stage} for {}: {err}",
            rollout_path.display()
        );
    }
}

pub async fn touch_thread_updated_at(
    context: Option<&codex_state::StateRuntime>,
    thread_id: Option<ThreadId>,
    updated_at: DateTime<Utc>,
    stage: &str,
) -> bool {
    let Some(ctx) = context else {
        return false;
    };
    let Some(thread_id) = thread_id else {
        return false;
    };
    ctx.touch_thread_updated_at(thread_id, updated_at)
        .await
        .unwrap_or_else(|err| {
            warn!("state db touch_thread_updated_at failed during {stage} for {thread_id}: {err}");
            false
        })
}

#[cfg(test)]
#[path = "state_db_tests.rs"]
mod tests;
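A hedged sketch of how a caller might combine the helpers above: open the DB only when it exists and backfill is complete, then resolve a rollout path. `config` and `thread_id` are assumed inputs, not part of this diff.

// Hedged sketch: resolve a thread's rollout path via the SQLite-backed state DB.
async fn resolve_path(
    config: &impl RolloutConfigView,
    thread_id: ThreadId,
) -> Option<std::path::PathBuf> {
    // get_state_db returns None unless the DB file exists and backfill finished.
    let db = get_state_db(config).await?;
    // The stage string is only used for log attribution in the warn! messages above.
    find_rollout_path_by_id(Some(db.as_ref()), thread_id, Some(false), "example").await
}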
28
codex-rs/rollout/src/state_db_tests.rs
Normal file
28
codex-rs/rollout/src/state_db_tests.rs
Normal file
@@ -0,0 +1,28 @@
#![allow(warnings, clippy::all)]

use super::*;
use crate::list::parse_cursor;
use chrono::DateTime;
use chrono::NaiveDateTime;
use chrono::Timelike;
use chrono::Utc;
use pretty_assertions::assert_eq;
use uuid::Uuid;

#[test]
fn cursor_to_anchor_normalizes_timestamp_format() {
    let uuid = Uuid::new_v4();
    let ts_str = "2026-01-27T12-34-56";
    let token = format!("{ts_str}|{uuid}");
    let cursor = parse_cursor(token.as_str()).expect("cursor should parse");
    let anchor = cursor_to_anchor(Some(&cursor)).expect("anchor should parse");

    let naive =
        NaiveDateTime::parse_from_str(ts_str, "%Y-%m-%dT%H-%M-%S").expect("ts should parse");
    let expected_ts = DateTime::<Utc>::from_naive_utc_and_offset(naive, Utc)
        .with_nanosecond(0)
        .expect("nanosecond");

    assert_eq!(anchor.id, uuid);
    assert_eq!(anchor.ts, expected_ts);
}
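The cursor token the test above builds has the shape `<timestamp>|<uuid>`; a small hedged sketch of the timestamp spellings `cursor_to_anchor` accepts (values illustrative only):

// Hedged sketch: token shapes understood once a Cursor serializes to a "ts|uuid" string.
let uuid = Uuid::new_v4();
let file_style = format!("2026-01-27T12-34-56|{uuid}"); // %Y-%m-%dT%H-%M-%S form
let rfc3339 = format!("2026-01-27T12:34:56Z|{uuid}"); // RFC3339 fallback
let invalid = format!("2026-01-27T12-34-56|{uuid}|extra"); // a second '|' is rejected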
1474
codex-rs/rollout/src/tests.rs
Normal file
1474
codex-rs/rollout/src/tests.rs
Normal file
File diff suppressed because it is too large