mirror of
https://github.com/openai/codex.git
synced 2026-05-14 16:22:51 +00:00
- migrate `thread/turns/list` to ThreadStore. Uses ThreadStore for most data now but merges in the in-memory state from thread manager - keep v2 `thread/list` pathless-store friendly by converting `StoredThread` directly to API `Thread` - add regression coverage for pathless store history/listing
286 lines
9.2 KiB
Rust
286 lines
9.2 KiB
Rust
use std::collections::HashMap;
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::sync::Mutex;
|
|
use std::sync::MutexGuard;
|
|
use std::sync::OnceLock;
|
|
|
|
use async_trait::async_trait;
|
|
use chrono::Utc;
|
|
use codex_protocol::ThreadId;
|
|
use codex_protocol::protocol::AskForApproval;
|
|
use codex_protocol::protocol::RolloutItem;
|
|
use codex_protocol::protocol::SandboxPolicy;
|
|
|
|
use crate::AppendThreadItemsParams;
|
|
use crate::ArchiveThreadParams;
|
|
use crate::CreateThreadParams;
|
|
use crate::ListThreadsParams;
|
|
use crate::LoadThreadHistoryParams;
|
|
use crate::ReadThreadByRolloutPathParams;
|
|
use crate::ReadThreadParams;
|
|
use crate::ResumeThreadParams;
|
|
use crate::StoredThread;
|
|
use crate::StoredThreadHistory;
|
|
use crate::ThreadPage;
|
|
use crate::ThreadStore;
|
|
use crate::ThreadStoreError;
|
|
use crate::ThreadStoreResult;
|
|
use crate::UpdateThreadMetadataParams;
|
|
|
|
static IN_MEMORY_THREAD_STORES: OnceLock<Mutex<HashMap<String, Arc<InMemoryThreadStore>>>> =
|
|
OnceLock::new();
|
|
|
|
fn stores() -> &'static Mutex<HashMap<String, Arc<InMemoryThreadStore>>> {
|
|
IN_MEMORY_THREAD_STORES.get_or_init(|| Mutex::new(HashMap::new()))
|
|
}
|
|
|
|
fn stores_guard() -> MutexGuard<'static, HashMap<String, Arc<InMemoryThreadStore>>> {
|
|
match stores().lock() {
|
|
Ok(guard) => guard,
|
|
Err(poisoned) => poisoned.into_inner(),
|
|
}
|
|
}
|
|
|
|
/// Recorded call counts for [`InMemoryThreadStore`].
|
|
#[derive(Clone, Debug, Default, PartialEq, Eq)]
|
|
pub struct InMemoryThreadStoreCalls {
|
|
pub create_thread: usize,
|
|
pub resume_thread: usize,
|
|
pub append_items: usize,
|
|
pub persist_thread: usize,
|
|
pub flush_thread: usize,
|
|
pub shutdown_thread: usize,
|
|
pub discard_thread: usize,
|
|
pub load_history: usize,
|
|
pub read_thread: usize,
|
|
pub read_thread_by_rollout_path: usize,
|
|
pub list_threads: usize,
|
|
pub update_thread_metadata: usize,
|
|
pub archive_thread: usize,
|
|
pub unarchive_thread: usize,
|
|
}
|
|
|
|
/// In-memory [`ThreadStore`] implementation for tests and debug configs.
|
|
///
|
|
/// Test and debug configs can select this store by id, letting tests exercise
|
|
/// config-driven non-local persistence without requiring the real remote gRPC
|
|
/// service.
|
|
#[derive(Default)]
|
|
pub struct InMemoryThreadStore {
|
|
state: tokio::sync::Mutex<InMemoryThreadStoreState>,
|
|
}
|
|
|
|
#[derive(Default)]
|
|
struct InMemoryThreadStoreState {
|
|
calls: InMemoryThreadStoreCalls,
|
|
created_threads: HashMap<ThreadId, CreateThreadParams>,
|
|
histories: HashMap<ThreadId, Vec<RolloutItem>>,
|
|
names: HashMap<ThreadId, Option<String>>,
|
|
rollout_paths: HashMap<PathBuf, ThreadId>,
|
|
}
|
|
|
|
impl InMemoryThreadStore {
|
|
/// Returns the store associated with `id`, creating it if needed.
|
|
pub fn for_id(id: impl Into<String>) -> Arc<Self> {
|
|
let id = id.into();
|
|
let mut stores = stores_guard();
|
|
stores
|
|
.entry(id)
|
|
.or_insert_with(|| Arc::new(Self::default()))
|
|
.clone()
|
|
}
|
|
|
|
/// Removes a shared in-memory store for `id`.
|
|
pub fn remove_id(id: &str) -> Option<Arc<Self>> {
|
|
stores_guard().remove(id)
|
|
}
|
|
|
|
/// Returns the calls observed by this store.
|
|
pub async fn calls(&self) -> InMemoryThreadStoreCalls {
|
|
self.state.lock().await.calls.clone()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl ThreadStore for InMemoryThreadStore {
|
|
fn as_any(&self) -> &dyn std::any::Any {
|
|
self
|
|
}
|
|
|
|
async fn create_thread(&self, params: CreateThreadParams) -> ThreadStoreResult<()> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.create_thread += 1;
|
|
state.histories.entry(params.thread_id).or_default();
|
|
state.created_threads.insert(params.thread_id, params);
|
|
Ok(())
|
|
}
|
|
|
|
async fn resume_thread(&self, params: ResumeThreadParams) -> ThreadStoreResult<()> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.resume_thread += 1;
|
|
state.histories.entry(params.thread_id).or_default();
|
|
if let Some(rollout_path) = params.rollout_path {
|
|
state.rollout_paths.insert(rollout_path, params.thread_id);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn append_items(&self, params: AppendThreadItemsParams) -> ThreadStoreResult<()> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.append_items += 1;
|
|
state
|
|
.histories
|
|
.entry(params.thread_id)
|
|
.or_default()
|
|
.extend(params.items);
|
|
Ok(())
|
|
}
|
|
|
|
async fn persist_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
|
|
self.state.lock().await.calls.persist_thread += 1;
|
|
Ok(())
|
|
}
|
|
|
|
async fn flush_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
|
|
self.state.lock().await.calls.flush_thread += 1;
|
|
Ok(())
|
|
}
|
|
|
|
async fn shutdown_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
|
|
self.state.lock().await.calls.shutdown_thread += 1;
|
|
Ok(())
|
|
}
|
|
|
|
async fn discard_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> {
|
|
self.state.lock().await.calls.discard_thread += 1;
|
|
Ok(())
|
|
}
|
|
|
|
async fn load_history(
|
|
&self,
|
|
params: LoadThreadHistoryParams,
|
|
) -> ThreadStoreResult<StoredThreadHistory> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.load_history += 1;
|
|
let items = state.histories.get(¶ms.thread_id).cloned().ok_or(
|
|
ThreadStoreError::ThreadNotFound {
|
|
thread_id: params.thread_id,
|
|
},
|
|
)?;
|
|
Ok(StoredThreadHistory {
|
|
thread_id: params.thread_id,
|
|
items,
|
|
})
|
|
}
|
|
|
|
async fn read_thread(&self, params: ReadThreadParams) -> ThreadStoreResult<StoredThread> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.read_thread += 1;
|
|
stored_thread_from_state(&state, params.thread_id, params.include_history)
|
|
}
|
|
|
|
async fn read_thread_by_rollout_path(
|
|
&self,
|
|
params: ReadThreadByRolloutPathParams,
|
|
) -> ThreadStoreResult<StoredThread> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.read_thread_by_rollout_path += 1;
|
|
let Some(thread_id) = state.rollout_paths.get(¶ms.rollout_path).copied() else {
|
|
return Err(ThreadStoreError::InvalidRequest {
|
|
message: format!(
|
|
"in-memory thread store does not know rollout path {}",
|
|
params.rollout_path.display()
|
|
),
|
|
});
|
|
};
|
|
stored_thread_from_state(&state, thread_id, params.include_history)
|
|
}
|
|
|
|
async fn list_threads(&self, _params: ListThreadsParams) -> ThreadStoreResult<ThreadPage> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.list_threads += 1;
|
|
let mut items = state
|
|
.created_threads
|
|
.keys()
|
|
.map(|thread_id| {
|
|
stored_thread_from_state(&state, *thread_id, /*include_history*/ false)
|
|
})
|
|
.collect::<ThreadStoreResult<Vec<_>>>()?;
|
|
items.sort_by_key(|item| item.thread_id.to_string());
|
|
Ok(ThreadPage {
|
|
items,
|
|
next_cursor: None,
|
|
})
|
|
}
|
|
|
|
async fn update_thread_metadata(
|
|
&self,
|
|
params: UpdateThreadMetadataParams,
|
|
) -> ThreadStoreResult<StoredThread> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.update_thread_metadata += 1;
|
|
if let Some(name) = params.patch.name {
|
|
state.names.insert(params.thread_id, Some(name));
|
|
}
|
|
stored_thread_from_state(&state, params.thread_id, /*include_history*/ false)
|
|
}
|
|
|
|
async fn archive_thread(&self, _params: ArchiveThreadParams) -> ThreadStoreResult<()> {
|
|
self.state.lock().await.calls.archive_thread += 1;
|
|
Ok(())
|
|
}
|
|
|
|
async fn unarchive_thread(
|
|
&self,
|
|
params: ArchiveThreadParams,
|
|
) -> ThreadStoreResult<StoredThread> {
|
|
let mut state = self.state.lock().await;
|
|
state.calls.unarchive_thread += 1;
|
|
stored_thread_from_state(&state, params.thread_id, /*include_history*/ false)
|
|
}
|
|
}
|
|
|
|
fn stored_thread_from_state(
|
|
state: &InMemoryThreadStoreState,
|
|
thread_id: ThreadId,
|
|
include_history: bool,
|
|
) -> ThreadStoreResult<StoredThread> {
|
|
let created = state
|
|
.created_threads
|
|
.get(&thread_id)
|
|
.ok_or(ThreadStoreError::ThreadNotFound { thread_id })?;
|
|
let history_items = state.histories.get(&thread_id).cloned().unwrap_or_default();
|
|
let history = include_history.then(|| StoredThreadHistory {
|
|
thread_id,
|
|
items: history_items.clone(),
|
|
});
|
|
let name = state.names.get(&thread_id).cloned().flatten();
|
|
|
|
Ok(StoredThread {
|
|
thread_id,
|
|
rollout_path: None,
|
|
forked_from_id: created.forked_from_id,
|
|
preview: String::new(),
|
|
name,
|
|
model_provider: "test".to_string(),
|
|
model: None,
|
|
reasoning_effort: None,
|
|
created_at: Utc::now(),
|
|
updated_at: Utc::now(),
|
|
archived_at: None,
|
|
cwd: PathBuf::new(),
|
|
cli_version: "test".to_string(),
|
|
source: created.source.clone(),
|
|
agent_nickname: None,
|
|
agent_role: None,
|
|
agent_path: None,
|
|
git_info: None,
|
|
approval_mode: AskForApproval::Never,
|
|
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
|
token_usage: None,
|
|
first_user_message: None,
|
|
history,
|
|
})
|
|
}
|