mirror of
https://github.com/openai/codex.git
synced 2026-05-18 02:02:30 +00:00
- Route ThreadManager rollout-path resume/fork through ThreadStore history reads. - Add in-memory store coverage proving path-addressed reads are used. This isn't strictly necessary for the ThreadStore migration, since these ThreadManager methods _only_ work for path-based lookups, but I'm trying to migrate all the rollout recorder callsites to use the ThreadStore where possible, for consistency.
293 lines
9.4 KiB
Rust
293 lines
9.4 KiB
Rust
use std::collections::HashMap;
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::sync::Mutex;
|
|
use std::sync::MutexGuard;
|
|
use std::sync::OnceLock;
|
|
|
|
use async_trait::async_trait;
|
|
use chrono::Utc;
|
|
use codex_protocol::ThreadId;
|
|
use codex_protocol::protocol::AskForApproval;
|
|
use codex_protocol::protocol::RolloutItem;
|
|
use codex_protocol::protocol::SandboxPolicy;
|
|
|
|
use crate::AppendThreadItemsParams;
|
|
use crate::ArchiveThreadParams;
|
|
use crate::CreateThreadParams;
|
|
use crate::ListThreadsParams;
|
|
use crate::LoadThreadHistoryParams;
|
|
use crate::ReadThreadByRolloutPathParams;
|
|
use crate::ReadThreadParams;
|
|
use crate::ResumeThreadParams;
|
|
use crate::StoredThread;
|
|
use crate::StoredThreadHistory;
|
|
use crate::ThreadPage;
|
|
use crate::ThreadStore;
|
|
use crate::ThreadStoreError;
|
|
use crate::ThreadStoreResult;
|
|
use crate::UpdateThreadMetadataParams;
|
|
|
|
static IN_MEMORY_THREAD_STORES: OnceLock<Mutex<HashMap<String, Arc<InMemoryThreadStore>>>> =
|
|
OnceLock::new();
|
|
|
|
fn stores() -> &'static Mutex<HashMap<String, Arc<InMemoryThreadStore>>> {
|
|
IN_MEMORY_THREAD_STORES.get_or_init(|| Mutex::new(HashMap::new()))
|
|
}
|
|
|
|
fn stores_guard() -> MutexGuard<'static, HashMap<String, Arc<InMemoryThreadStore>>> {
|
|
match stores().lock() {
|
|
Ok(guard) => guard,
|
|
Err(poisoned) => poisoned.into_inner(),
|
|
}
|
|
}
|
|
|
|
/// Recorded call counts for [`InMemoryThreadStore`].
///
/// Each field counts how many times the [`ThreadStore`] method of the same name
/// has been invoked on the store. The counter is bumped before any lookup, so
/// calls that end up returning an error are counted too. Snapshots are obtained
/// through `InMemoryThreadStore::calls`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct InMemoryThreadStoreCalls {
    pub create_thread: usize,
    pub resume_thread: usize,
    pub append_items: usize,
    pub persist_thread: usize,
    pub flush_thread: usize,
    pub shutdown_thread: usize,
    pub discard_thread: usize,
    pub load_history: usize,
    pub read_thread: usize,
    pub read_thread_by_rollout_path: usize,
    pub list_threads: usize,
    pub update_thread_metadata: usize,
    pub archive_thread: usize,
    pub unarchive_thread: usize,
}
|
|
|
|
/// In-memory [`ThreadStore`] implementation for tests and debug configs.
///
/// Test and debug configs can select this store by id, letting tests exercise
/// config-driven non-local persistence without requiring the real remote gRPC
/// service.
///
/// Stores obtained through [`InMemoryThreadStore::for_id`] live in a process-wide
/// registry, so every lookup of the same id returns the same store until it is
/// removed with [`InMemoryThreadStore::remove_id`].
#[derive(Default)]
pub struct InMemoryThreadStore {
    // All mutable data sits behind one async mutex; each trait method holds the
    // guard for the whole call.
    state: tokio::sync::Mutex<InMemoryThreadStoreState>,
}
|
|
|
|
/// Mutable contents of an [`InMemoryThreadStore`], guarded by its async mutex.
#[derive(Default)]
struct InMemoryThreadStoreState {
    /// Per-method call counters reported by `InMemoryThreadStore::calls`.
    calls: InMemoryThreadStoreCalls,
    /// Parameters of every thread registered via `create_thread`. A thread must
    /// be present here to be returned by `read_thread`,
    /// `read_thread_by_rollout_path`, `list_threads`, `update_thread_metadata`,
    /// and `unarchive_thread`.
    created_threads: HashMap<ThreadId, CreateThreadParams>,
    /// Rollout items per thread: seeded empty by `create_thread`/`resume_thread`
    /// and extended by `append_items`.
    histories: HashMap<ThreadId, Vec<RolloutItem>>,
    /// Thread names set through `update_thread_metadata`.
    names: HashMap<ThreadId, Option<String>>,
    /// Rollout path -> thread id, registered by `resume_thread` when it is given a path.
    rollout_paths: HashMap<PathBuf, ThreadId>,
}
|
|
|
|
impl InMemoryThreadStore {
|
|
/// Returns the store associated with `id`, creating it if needed.
|
|
pub fn for_id(id: impl Into<String>) -> Arc<Self> {
|
|
let id = id.into();
|
|
let mut stores = stores_guard();
|
|
stores
|
|
.entry(id)
|
|
.or_insert_with(|| Arc::new(Self::default()))
|
|
.clone()
|
|
}
|
|
|
|
/// Removes a shared in-memory store for `id`.
|
|
pub fn remove_id(id: &str) -> Option<Arc<Self>> {
|
|
stores_guard().remove(id)
|
|
}
|
|
|
|
/// Returns the calls observed by this store.
|
|
pub async fn calls(&self) -> InMemoryThreadStoreCalls {
|
|
self.state.lock().await.calls.clone()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl ThreadStore for InMemoryThreadStore {
|
|
fn as_any(&self) -> &dyn std::any::Any {
|
|
self
|
|
}
|
|
|
|
async fn create_thread(&self, params: CreateThreadParams) -> ThreadStoreResult<()> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.create_thread += 1;
|
|
state.histories.entry(params.thread_id).or_default();
|
|
state.created_threads.insert(params.thread_id, params);
|
|
Ok(())
|
|
}
|
|
|
|
async fn resume_thread(&self, params: ResumeThreadParams) -> ThreadStoreResult<()> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.resume_thread += 1;
|
|
state.histories.entry(params.thread_id).or_default();
|
|
if let Some(rollout_path) = params.rollout_path {
|
|
state.rollout_paths.insert(rollout_path, params.thread_id);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn append_items(&self, params: AppendThreadItemsParams) -> ThreadStoreResult<()> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.append_items += 1;
|
|
state
|
|
.histories
|
|
.entry(params.thread_id)
|
|
.or_default()
|
|
.extend(params.items);
|
|
Ok(())
|
|
}
|
|
|
|
async fn persist_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
|
|
self.state.lock().await.calls.persist_thread += 1;
|
|
Ok(())
|
|
}
|
|
|
|
async fn flush_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
|
|
self.state.lock().await.calls.flush_thread += 1;
|
|
Ok(())
|
|
}
|
|
|
|
async fn shutdown_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
|
|
self.state.lock().await.calls.shutdown_thread += 1;
|
|
Ok(())
|
|
}
|
|
|
|
async fn discard_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
|
|
self.state.lock().await.calls.discard_thread += 1;
|
|
Ok(())
|
|
}
|
|
|
|
async fn load_history(
|
|
&self,
|
|
params: LoadThreadHistoryParams,
|
|
) -> ThreadStoreResult<StoredThreadHistory> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.load_history += 1;
|
|
let items = state.histories.get(¶ms.thread_id).cloned().ok_or(
|
|
ThreadStoreError::ThreadNotFound {
|
|
thread_id: params.thread_id,
|
|
},
|
|
)?;
|
|
Ok(StoredThreadHistory {
|
|
thread_id: params.thread_id,
|
|
items,
|
|
})
|
|
}
|
|
|
|
async fn read_thread(&self, params: ReadThreadParams) -> ThreadStoreResult<StoredThread> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.read_thread += 1;
|
|
stored_thread_from_state(&state, params.thread_id, params.include_history)
|
|
}
|
|
|
|
async fn read_thread_by_rollout_path(
|
|
&self,
|
|
params: ReadThreadByRolloutPathParams,
|
|
) -> ThreadStoreResult<StoredThread> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.read_thread_by_rollout_path += 1;
|
|
let Some(thread_id) = state.rollout_paths.get(¶ms.rollout_path).copied() else {
|
|
return Err(ThreadStoreError::InvalidRequest {
|
|
message: format!(
|
|
"in-memory thread store does not know rollout path {}",
|
|
params.rollout_path.display()
|
|
),
|
|
});
|
|
};
|
|
stored_thread_from_state(&state, thread_id, params.include_history)
|
|
}
|
|
|
|
async fn list_threads(&self, _params: ListThreadsParams) -> ThreadStoreResult<ThreadPage> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.list_threads += 1;
|
|
let mut items = state
|
|
.created_threads
|
|
.keys()
|
|
.map(|thread_id| {
|
|
stored_thread_from_state(&state, *thread_id, /*include_history*/ false)
|
|
})
|
|
.collect::<ThreadStoreResult<Vec<_>>>()?;
|
|
items.sort_by_key(|item| item.thread_id.to_string());
|
|
Ok(ThreadPage {
|
|
items,
|
|
next_cursor: None,
|
|
})
|
|
}
|
|
|
|
async fn update_thread_metadata(
|
|
&self,
|
|
params: UpdateThreadMetadataParams,
|
|
) -> ThreadStoreResult<StoredThread> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.update_thread_metadata += 1;
|
|
if let Some(name) = params.patch.name {
|
|
state.names.insert(params.thread_id, Some(name));
|
|
}
|
|
stored_thread_from_state(&state, params.thread_id, /*include_history*/ false)
|
|
}
|
|
|
|
async fn archive_thread(&self, _params: ArchiveThreadParams) -> ThreadStoreResult<()> {
|
|
self.state.lock().await.calls.archive_thread += 1;
|
|
Ok(())
|
|
}
|
|
|
|
async fn unarchive_thread(
|
|
&self,
|
|
params: ArchiveThreadParams,
|
|
) -> ThreadStoreResult<StoredThread> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.unarchive_thread += 1;
|
|
stored_thread_from_state(&state, params.thread_id, /*include_history*/ false)
|
|
}
|
|
}
|
|
|
|
fn stored_thread_from_state(
|
|
state: &InMemoryThreadStoreState,
|
|
thread_id: ThreadId,
|
|
include_history: bool,
|
|
) -> ThreadStoreResult<StoredThread> {
|
|
let created = state
|
|
.created_threads
|
|
.get(&thread_id)
|
|
.ok_or(ThreadStoreError::ThreadNotFound { thread_id })?;
|
|
let history_items = state.histories.get(&thread_id).cloned().unwrap_or_default();
|
|
let history = include_history.then(|| StoredThreadHistory {
|
|
thread_id,
|
|
items: history_items.clone(),
|
|
});
|
|
let name = state.names.get(&thread_id).cloned().flatten();
|
|
let rollout_path = state
|
|
.rollout_paths
|
|
.iter()
|
|
.find_map(|(path, mapped_thread_id)| {
|
|
(*mapped_thread_id == thread_id).then(|| path.clone())
|
|
});
|
|
|
|
Ok(StoredThread {
|
|
thread_id,
|
|
rollout_path,
|
|
forked_from_id: created.forked_from_id,
|
|
preview: String::new(),
|
|
name,
|
|
model_provider: "test".to_string(),
|
|
model: None,
|
|
reasoning_effort: None,
|
|
created_at: Utc::now(),
|
|
updated_at: Utc::now(),
|
|
archived_at: None,
|
|
cwd: PathBuf::new(),
|
|
cli_version: "test".to_string(),
|
|
source: created.source.clone(),
|
|
thread_source: created.thread_source,
|
|
agent_nickname: None,
|
|
agent_role: None,
|
|
agent_path: None,
|
|
git_info: None,
|
|
approval_mode: AskForApproval::Never,
|
|
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
|
token_usage: None,
|
|
first_user_message: None,
|
|
history,
|
|
})
|
|
}
|