Files
codex/codex-rs/thread-store/src/in_memory.rs
Tom c51c65ad09 Unify thread metadata updates above store (#22236)
- make ThreadStore::update_thread_metadata accept a broad range of
metadata patches
- keep ThreadStore::append_items as raw canonical history append (no
metadata side effects)
- in the local store, write these metadata updates to a combination of
sqlite and rollout jsonl files for backwards-compat. It special cases
which fields need to go into jsonl vs sqlite vs whatever, confining the
awkwardness to just this implementation
- in remote stores we can simply persist the metadata directly to a
database, no special casing required.
- move the "implicit metadata updates triggered by appending rollout
items" from the RolloutRecorder (which is local-threadstore-specific) to
the LiveThread layer above the ThreadStore, inside of a private helper
utility called ThreadMetadataSync. LiveThread calls ThreadStore
append_items and update_metadata separately.
- Add a generic update metadata method to ThreadManager that works on
both live threads and "cold" threads
- Call that ThreadManager method from app server code, so app server
doesn't need to worry about whether the thread is live or not
2026-05-13 00:28:15 +00:00

389 lines
13 KiB
Rust

use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::MutexGuard;
use std::sync::OnceLock;
use async_trait::async_trait;
use chrono::Utc;
use codex_protocol::ThreadId;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SandboxPolicy;
use crate::AppendThreadItemsParams;
use crate::ArchiveThreadParams;
use crate::CreateThreadParams;
use crate::ListThreadsParams;
use crate::LoadThreadHistoryParams;
use crate::ReadThreadByRolloutPathParams;
use crate::ReadThreadParams;
use crate::ResumeThreadParams;
use crate::StoredThread;
use crate::StoredThreadHistory;
use crate::ThreadMetadataPatch;
use crate::ThreadPage;
use crate::ThreadStore;
use crate::ThreadStoreError;
use crate::ThreadStoreResult;
use crate::UpdateThreadMetadataParams;
static IN_MEMORY_THREAD_STORES: OnceLock<Mutex<HashMap<String, Arc<InMemoryThreadStore>>>> =
OnceLock::new();
fn stores() -> &'static Mutex<HashMap<String, Arc<InMemoryThreadStore>>> {
IN_MEMORY_THREAD_STORES.get_or_init(|| Mutex::new(HashMap::new()))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ListItemsParams;
use crate::ListTurnsParams;
use crate::SortDirection;
use crate::StoredTurnItemsView;
#[tokio::test]
async fn default_turn_pagination_methods_return_unsupported() {
let store = InMemoryThreadStore::default();
let thread_id = ThreadId::default();
let turns_err = store
.list_turns(ListTurnsParams {
thread_id,
include_archived: true,
cursor: None,
page_size: 10,
sort_direction: SortDirection::Asc,
items_view: StoredTurnItemsView::Summary,
})
.await
.expect_err("default list_turns should be unsupported");
assert!(matches!(
turns_err,
ThreadStoreError::Unsupported {
operation: "list_turns"
}
));
let items_err = store
.list_items(ListItemsParams {
thread_id,
turn_id: "turn_1".to_string(),
include_archived: true,
cursor: None,
page_size: 10,
sort_direction: SortDirection::Asc,
})
.await
.expect_err("default list_items should be unsupported");
assert!(matches!(
items_err,
ThreadStoreError::Unsupported {
operation: "list_items"
}
));
}
}
fn stores_guard() -> MutexGuard<'static, HashMap<String, Arc<InMemoryThreadStore>>> {
match stores().lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
}
}
/// Recorded call counts for [`InMemoryThreadStore`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct InMemoryThreadStoreCalls {
pub create_thread: usize,
pub resume_thread: usize,
pub append_items: usize,
pub persist_thread: usize,
pub flush_thread: usize,
pub shutdown_thread: usize,
pub discard_thread: usize,
pub load_history: usize,
pub read_thread: usize,
pub read_thread_by_rollout_path: usize,
pub list_threads: usize,
pub update_thread_metadata: usize,
pub archive_thread: usize,
pub unarchive_thread: usize,
}
/// In-memory [`ThreadStore`] implementation for tests and debug configs.
///
/// Test and debug configs can select this store by id, letting tests exercise
/// config-driven non-local persistence without requiring the real remote gRPC
/// service.
#[derive(Default)]
pub struct InMemoryThreadStore {
state: tokio::sync::Mutex<InMemoryThreadStoreState>,
}
#[derive(Default)]
struct InMemoryThreadStoreState {
calls: InMemoryThreadStoreCalls,
created_threads: HashMap<ThreadId, CreateThreadParams>,
histories: HashMap<ThreadId, Vec<RolloutItem>>,
metadata_updates: HashMap<ThreadId, ThreadMetadataPatch>,
names: HashMap<ThreadId, Option<String>>,
rollout_paths: HashMap<PathBuf, ThreadId>,
}
impl InMemoryThreadStore {
/// Returns the store associated with `id`, creating it if needed.
pub fn for_id(id: impl Into<String>) -> Arc<Self> {
let id = id.into();
let mut stores = stores_guard();
stores
.entry(id)
.or_insert_with(|| Arc::new(Self::default()))
.clone()
}
/// Removes a shared in-memory store for `id`.
pub fn remove_id(id: &str) -> Option<Arc<Self>> {
stores_guard().remove(id)
}
/// Returns the calls observed by this store.
pub async fn calls(&self) -> InMemoryThreadStoreCalls {
self.state.lock().await.calls.clone()
}
}
#[async_trait]
impl ThreadStore for InMemoryThreadStore {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn create_thread(&self, params: CreateThreadParams) -> ThreadStoreResult<()> {
let mut state = self.state.lock().await;
state.calls.create_thread += 1;
state.histories.entry(params.thread_id).or_default();
state.created_threads.insert(params.thread_id, params);
Ok(())
}
async fn resume_thread(&self, params: ResumeThreadParams) -> ThreadStoreResult<()> {
let mut state = self.state.lock().await;
state.calls.resume_thread += 1;
state.histories.entry(params.thread_id).or_default();
if let Some(rollout_path) = params.rollout_path {
state.rollout_paths.insert(rollout_path, params.thread_id);
}
Ok(())
}
async fn append_items(&self, params: AppendThreadItemsParams) -> ThreadStoreResult<()> {
let mut state = self.state.lock().await;
state.calls.append_items += 1;
state
.histories
.entry(params.thread_id)
.or_default()
.extend(params.items);
Ok(())
}
async fn persist_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
self.state.lock().await.calls.persist_thread += 1;
Ok(())
}
async fn flush_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
self.state.lock().await.calls.flush_thread += 1;
Ok(())
}
async fn shutdown_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
self.state.lock().await.calls.shutdown_thread += 1;
Ok(())
}
async fn discard_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
self.state.lock().await.calls.discard_thread += 1;
Ok(())
}
async fn load_history(
&self,
params: LoadThreadHistoryParams,
) -> ThreadStoreResult<StoredThreadHistory> {
let mut state = self.state.lock().await;
state.calls.load_history += 1;
let items = state.histories.get(&params.thread_id).cloned().ok_or(
ThreadStoreError::ThreadNotFound {
thread_id: params.thread_id,
},
)?;
Ok(StoredThreadHistory {
thread_id: params.thread_id,
items,
})
}
async fn read_thread(&self, params: ReadThreadParams) -> ThreadStoreResult<StoredThread> {
let mut state = self.state.lock().await;
state.calls.read_thread += 1;
stored_thread_from_state(&state, params.thread_id, params.include_history)
}
async fn read_thread_by_rollout_path(
&self,
params: ReadThreadByRolloutPathParams,
) -> ThreadStoreResult<StoredThread> {
let mut state = self.state.lock().await;
state.calls.read_thread_by_rollout_path += 1;
let Some(thread_id) = state.rollout_paths.get(&params.rollout_path).copied() else {
return Err(ThreadStoreError::InvalidRequest {
message: format!(
"in-memory thread store does not know rollout path {}",
params.rollout_path.display()
),
});
};
stored_thread_from_state(&state, thread_id, params.include_history)
}
async fn list_threads(&self, _params: ListThreadsParams) -> ThreadStoreResult<ThreadPage> {
let mut state = self.state.lock().await;
state.calls.list_threads += 1;
let mut items = state
.created_threads
.keys()
.map(|thread_id| {
stored_thread_from_state(&state, *thread_id, /*include_history*/ false)
})
.collect::<ThreadStoreResult<Vec<_>>>()?;
items.sort_by_key(|item| item.thread_id.to_string());
Ok(ThreadPage {
items,
next_cursor: None,
})
}
async fn update_thread_metadata(
&self,
params: UpdateThreadMetadataParams,
) -> ThreadStoreResult<StoredThread> {
let mut state = self.state.lock().await;
state.calls.update_thread_metadata += 1;
if let Some(name) = params.patch.name.clone() {
state.names.insert(params.thread_id, name);
}
state
.metadata_updates
.entry(params.thread_id)
.or_default()
.merge(params.patch);
stored_thread_from_state(&state, params.thread_id, /*include_history*/ false)
}
async fn archive_thread(&self, _params: ArchiveThreadParams) -> ThreadStoreResult<()> {
self.state.lock().await.calls.archive_thread += 1;
Ok(())
}
async fn unarchive_thread(
&self,
params: ArchiveThreadParams,
) -> ThreadStoreResult<StoredThread> {
let mut state = self.state.lock().await;
state.calls.unarchive_thread += 1;
stored_thread_from_state(&state, params.thread_id, /*include_history*/ false)
}
}
fn stored_thread_from_state(
state: &InMemoryThreadStoreState,
thread_id: ThreadId,
include_history: bool,
) -> ThreadStoreResult<StoredThread> {
let created = state
.created_threads
.get(&thread_id)
.ok_or(ThreadStoreError::ThreadNotFound { thread_id })?;
let history_items = state.histories.get(&thread_id).cloned().unwrap_or_default();
let history = include_history.then(|| StoredThreadHistory {
thread_id,
items: history_items.clone(),
});
let name = state.names.get(&thread_id).cloned().flatten();
let metadata = state.metadata_updates.get(&thread_id);
let rollout_path = state
.rollout_paths
.iter()
.find_map(|(path, mapped_thread_id)| {
(*mapped_thread_id == thread_id).then(|| path.clone())
});
Ok(StoredThread {
thread_id,
rollout_path: metadata
.and_then(|metadata| metadata.rollout_path.clone())
.or(rollout_path),
forked_from_id: created.forked_from_id,
preview: metadata
.and_then(|metadata| metadata.preview.clone())
.unwrap_or_default(),
name,
model_provider: metadata
.and_then(|metadata| metadata.model_provider.clone())
.unwrap_or_else(|| "test".to_string()),
model: metadata.and_then(|metadata| metadata.model.clone()),
reasoning_effort: metadata.and_then(|metadata| metadata.reasoning_effort),
created_at: metadata
.and_then(|metadata| metadata.created_at)
.unwrap_or_else(Utc::now),
updated_at: metadata
.and_then(|metadata| metadata.updated_at)
.unwrap_or_else(Utc::now),
archived_at: None,
cwd: metadata
.and_then(|metadata| metadata.cwd.clone())
.unwrap_or_default(),
cli_version: metadata
.and_then(|metadata| metadata.cli_version.clone())
.unwrap_or_else(|| "test".to_string()),
source: metadata
.and_then(|metadata| metadata.source.clone())
.unwrap_or_else(|| created.source.clone()),
thread_source: metadata
.and_then(|metadata| metadata.thread_source)
.unwrap_or(created.thread_source),
agent_nickname: metadata.and_then(|metadata| metadata.agent_nickname.clone().flatten()),
agent_role: metadata.and_then(|metadata| metadata.agent_role.clone().flatten()),
agent_path: metadata.and_then(|metadata| metadata.agent_path.clone().flatten()),
git_info: metadata.and_then(git_info_from_patch),
approval_mode: metadata
.and_then(|metadata| metadata.approval_mode)
.unwrap_or(AskForApproval::Never),
sandbox_policy: metadata
.and_then(|metadata| metadata.sandbox_policy.clone())
.unwrap_or_else(SandboxPolicy::new_read_only_policy),
token_usage: metadata.and_then(|metadata| metadata.token_usage.clone()),
first_user_message: metadata.and_then(|metadata| metadata.first_user_message.clone()),
history,
})
}
fn git_info_from_patch(patch: &ThreadMetadataPatch) -> Option<codex_protocol::protocol::GitInfo> {
let git_info = patch.git_info.as_ref()?;
let sha = git_info.sha.clone().flatten();
let branch = git_info.branch.clone().flatten();
let origin_url = git_info.origin_url.clone().flatten();
if sha.is_none() && branch.is_none() && origin_url.is_none() {
return None;
}
Some(codex_protocol::protocol::GitInfo {
commit_hash: sha.as_deref().map(codex_git_utils::GitSha::new),
branch,
repository_url: origin_url,
})
}