mirror of
https://github.com/openai/codex.git
synced 2026-04-30 01:16:54 +00:00
refactor: codex app-server ThreadState (#11419)
this is a no-op functionality wise. consolidates thread-specific message processor / event handling state in ThreadState
This commit is contained in:
@@ -137,7 +137,6 @@ use codex_app_server_protocol::ThreadStartedNotification;
|
||||
use codex_app_server_protocol::ThreadUnarchiveParams;
|
||||
use codex_app_server_protocol::ThreadUnarchiveResponse;
|
||||
use codex_app_server_protocol::Turn;
|
||||
use codex_app_server_protocol::TurnError;
|
||||
use codex_app_server_protocol::TurnInterruptParams;
|
||||
use codex_app_server_protocol::TurnStartParams;
|
||||
use codex_app_server_protocol::TurnStartResponse;
|
||||
@@ -252,20 +251,7 @@ use uuid::Uuid;
|
||||
|
||||
use crate::filters::compute_source_filters;
|
||||
use crate::filters::source_kind_matches;
|
||||
|
||||
type PendingInterruptQueue = Vec<(ConnectionRequestId, ApiVersion)>;
|
||||
pub(crate) type PendingInterrupts = Arc<Mutex<HashMap<ThreadId, PendingInterruptQueue>>>;
|
||||
|
||||
pub(crate) type PendingRollbacks = Arc<Mutex<HashMap<ThreadId, ConnectionRequestId>>>;
|
||||
|
||||
/// Per-conversation accumulation of the latest states e.g. error message while a turn runs.
|
||||
#[derive(Default, Clone)]
|
||||
pub(crate) struct TurnSummary {
|
||||
pub(crate) file_change_started: HashSet<String>,
|
||||
pub(crate) last_error: Option<TurnError>,
|
||||
}
|
||||
|
||||
pub(crate) type TurnSummaryStore = Arc<Mutex<HashMap<ThreadId, TurnSummary>>>;
|
||||
use crate::thread_state::ThreadStateManager;
|
||||
|
||||
const THREAD_LIST_DEFAULT_LIMIT: usize = 25;
|
||||
const THREAD_LIST_MAX_LIMIT: usize = 100;
|
||||
@@ -303,21 +289,16 @@ pub(crate) struct CodexMessageProcessor {
|
||||
config: Arc<Config>,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
cloud_requirements: Arc<RwLock<CloudRequirementsLoader>>,
|
||||
conversation_listeners: HashMap<Uuid, oneshot::Sender<()>>,
|
||||
listener_thread_ids_by_subscription: HashMap<Uuid, ThreadId>,
|
||||
active_login: Arc<Mutex<Option<ActiveLogin>>>,
|
||||
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
|
||||
pending_interrupts: PendingInterrupts,
|
||||
// Queue of pending rollback requests per conversation. We reply when ThreadRollback arrives.
|
||||
pending_rollbacks: PendingRollbacks,
|
||||
turn_summary_store: TurnSummaryStore,
|
||||
thread_state_manager: ThreadStateManager,
|
||||
pending_fuzzy_searches: Arc<Mutex<HashMap<String, Arc<AtomicBool>>>>,
|
||||
feedback: CodexFeedback,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub(crate) enum ApiVersion {
|
||||
V1,
|
||||
#[default]
|
||||
V2,
|
||||
}
|
||||
|
||||
@@ -375,12 +356,8 @@ impl CodexMessageProcessor {
|
||||
config,
|
||||
cli_overrides,
|
||||
cloud_requirements,
|
||||
conversation_listeners: HashMap::new(),
|
||||
listener_thread_ids_by_subscription: HashMap::new(),
|
||||
active_login: Arc::new(Mutex::new(None)),
|
||||
pending_interrupts: Arc::new(Mutex::new(HashMap::new())),
|
||||
pending_rollbacks: Arc::new(Mutex::new(HashMap::new())),
|
||||
turn_summary_store: Arc::new(Mutex::new(HashMap::new())),
|
||||
thread_state_manager: ThreadStateManager::new(),
|
||||
pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())),
|
||||
feedback,
|
||||
}
|
||||
@@ -1012,7 +989,7 @@ impl CodexMessageProcessor {
|
||||
cloud_requirements.as_ref(),
|
||||
auth_manager.clone(),
|
||||
chatgpt_base_url,
|
||||
codex_home.clone(),
|
||||
codex_home,
|
||||
);
|
||||
sync_default_client_residency_requirement(
|
||||
&cli_overrides,
|
||||
@@ -1120,7 +1097,7 @@ impl CodexMessageProcessor {
|
||||
cloud_requirements.as_ref(),
|
||||
auth_manager.clone(),
|
||||
chatgpt_base_url,
|
||||
codex_home.clone(),
|
||||
codex_home,
|
||||
);
|
||||
sync_default_client_residency_requirement(
|
||||
&cli_overrides,
|
||||
@@ -2330,25 +2307,32 @@ impl CodexMessageProcessor {
|
||||
|
||||
let request = request_id.clone();
|
||||
|
||||
{
|
||||
let mut map = self.pending_rollbacks.lock().await;
|
||||
if map.contains_key(&thread_id) {
|
||||
self.send_invalid_request_error(
|
||||
request.clone(),
|
||||
"rollback already in progress for this thread".to_string(),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
let rollback_already_in_progress = {
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id);
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
if thread_state.pending_rollbacks.is_some() {
|
||||
true
|
||||
} else {
|
||||
thread_state.pending_rollbacks = Some(request.clone());
|
||||
false
|
||||
}
|
||||
|
||||
map.insert(thread_id, request.clone());
|
||||
};
|
||||
if rollback_already_in_progress {
|
||||
self.send_invalid_request_error(
|
||||
request.clone(),
|
||||
"rollback already in progress for this thread".to_string(),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
if let Err(err) = thread.submit(Op::ThreadRollback { num_turns }).await {
|
||||
// No ThreadRollback event will arrive if an error occurs.
|
||||
// Clean up and reply immediately.
|
||||
let mut map = self.pending_rollbacks.lock().await;
|
||||
map.remove(&thread_id);
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id);
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.pending_rollbacks = None;
|
||||
drop(thread_state);
|
||||
|
||||
self.send_internal_error(request, format!("failed to start rollback: {err}"))
|
||||
.await;
|
||||
@@ -2646,11 +2630,7 @@ impl CodexMessageProcessor {
|
||||
|
||||
/// Best-effort: attach a listener for thread_id if missing.
|
||||
pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) {
|
||||
if self
|
||||
.listener_thread_ids_by_subscription
|
||||
.values()
|
||||
.any(|entry| *entry == thread_id)
|
||||
{
|
||||
if self.thread_state_manager.has_listener_for_thread(thread_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4313,7 +4293,8 @@ impl CodexMessageProcessor {
|
||||
let mut state_db_ctx = None;
|
||||
|
||||
// If the thread is active, request shutdown and wait briefly.
|
||||
if let Some(conversation) = self.thread_manager.remove_thread(&thread_id).await {
|
||||
let removed_conversation = self.thread_manager.remove_thread(&thread_id).await;
|
||||
if let Some(conversation) = removed_conversation {
|
||||
if let Some(ctx) = conversation.state_db() {
|
||||
state_db_ctx = Some(ctx);
|
||||
}
|
||||
@@ -4341,6 +4322,9 @@ impl CodexMessageProcessor {
|
||||
error!("failed to submit Shutdown to thread {thread_id}: {err}");
|
||||
}
|
||||
}
|
||||
self.thread_state_manager
|
||||
.remove_thread_state(thread_id)
|
||||
.await;
|
||||
}
|
||||
|
||||
if state_db_ctx.is_none() {
|
||||
@@ -4880,9 +4864,10 @@ impl CodexMessageProcessor {
|
||||
|
||||
// Record the pending interrupt so we can reply when TurnAborted arrives.
|
||||
{
|
||||
let mut map = self.pending_interrupts.lock().await;
|
||||
map.entry(conversation_id)
|
||||
.or_default()
|
||||
let pending_interrupts = self.thread_state_manager.thread_state(conversation_id);
|
||||
let mut thread_state = pending_interrupts.lock().await;
|
||||
thread_state
|
||||
.pending_interrupts
|
||||
.push((request, ApiVersion::V1));
|
||||
}
|
||||
|
||||
@@ -5276,9 +5261,10 @@ impl CodexMessageProcessor {
|
||||
|
||||
// Record the pending interrupt so we can reply when TurnAborted arrives.
|
||||
{
|
||||
let mut map = self.pending_interrupts.lock().await;
|
||||
map.entry(thread_uuid)
|
||||
.or_default()
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_uuid);
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state
|
||||
.pending_interrupts
|
||||
.push((request, ApiVersion::V2));
|
||||
}
|
||||
|
||||
@@ -5315,16 +5301,13 @@ impl CodexMessageProcessor {
|
||||
params: RemoveConversationListenerParams,
|
||||
) {
|
||||
let RemoveConversationListenerParams { subscription_id } = params;
|
||||
match self.conversation_listeners.remove(&subscription_id) {
|
||||
Some(sender) => {
|
||||
// Signal the spawned task to exit and acknowledge.
|
||||
let _ = sender.send(());
|
||||
if let Some(thread_id) = self
|
||||
.listener_thread_ids_by_subscription
|
||||
.remove(&subscription_id)
|
||||
{
|
||||
info!("removed listener for thread {thread_id}");
|
||||
}
|
||||
match self
|
||||
.thread_state_manager
|
||||
.remove_listener(subscription_id)
|
||||
.await
|
||||
{
|
||||
Some(thread_id) => {
|
||||
info!("removed listener for thread {thread_id}");
|
||||
let response = RemoveConversationSubscriptionResponse {};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
@@ -5342,7 +5325,7 @@ impl CodexMessageProcessor {
|
||||
async fn attach_conversation_listener(
|
||||
&mut self,
|
||||
conversation_id: ThreadId,
|
||||
experimental_raw_events: bool,
|
||||
raw_events_enabled: bool,
|
||||
api_version: ApiVersion,
|
||||
) -> Result<Uuid, JSONRPCErrorError> {
|
||||
let conversation = match self.thread_manager.get_thread(conversation_id).await {
|
||||
@@ -5358,16 +5341,11 @@ impl CodexMessageProcessor {
|
||||
|
||||
let subscription_id = Uuid::new_v4();
|
||||
let (cancel_tx, mut cancel_rx) = oneshot::channel();
|
||||
self.conversation_listeners
|
||||
.insert(subscription_id, cancel_tx);
|
||||
self.listener_thread_ids_by_subscription
|
||||
.insert(subscription_id, conversation_id);
|
||||
|
||||
let thread_state = self
|
||||
.thread_state_manager
|
||||
.set_listener(subscription_id, conversation_id, cancel_tx)
|
||||
.await;
|
||||
let outgoing_for_task = self.outgoing.clone();
|
||||
let pending_interrupts = self.pending_interrupts.clone();
|
||||
let pending_rollbacks = self.pending_rollbacks.clone();
|
||||
let turn_summary_store = self.turn_summary_store.clone();
|
||||
let api_version_for_task = api_version;
|
||||
let fallback_model_provider = self.config.model_provider_id.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
@@ -5385,10 +5363,9 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
};
|
||||
|
||||
if let EventMsg::RawResponseItem(_) = &event.msg
|
||||
&& !experimental_raw_events {
|
||||
continue;
|
||||
}
|
||||
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
// For now, we send a notification for every event,
|
||||
// JSON-serializing the `Event` as-is, but these should
|
||||
@@ -5427,10 +5404,8 @@ impl CodexMessageProcessor {
|
||||
conversation_id,
|
||||
conversation.clone(),
|
||||
outgoing_for_task.clone(),
|
||||
pending_interrupts.clone(),
|
||||
pending_rollbacks.clone(),
|
||||
turn_summary_store.clone(),
|
||||
api_version_for_task,
|
||||
thread_state.clone(),
|
||||
api_version,
|
||||
fallback_model_provider.clone(),
|
||||
)
|
||||
.await;
|
||||
@@ -5699,7 +5674,7 @@ fn replace_cloud_requirements_loader(
|
||||
cloud_requirements: &RwLock<CloudRequirementsLoader>,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
chatgpt_base_url: String,
|
||||
codex_home: std::path::PathBuf,
|
||||
codex_home: PathBuf,
|
||||
) {
|
||||
let loader = cloud_requirements_loader(auth_manager, chatgpt_base_url, codex_home);
|
||||
if let Ok(mut guard) = cloud_requirements.write() {
|
||||
@@ -6316,4 +6291,30 @@ mod tests {
|
||||
assert_eq!(summary, expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn removing_one_listener_does_not_cancel_other_subscriptions_for_same_thread()
|
||||
-> Result<()> {
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let listener_a = Uuid::new_v4();
|
||||
let listener_b = Uuid::new_v4();
|
||||
let (cancel_a, cancel_rx_a) = oneshot::channel();
|
||||
let (cancel_b, mut cancel_rx_b) = oneshot::channel();
|
||||
|
||||
manager.set_listener(listener_a, thread_id, cancel_a).await;
|
||||
manager.set_listener(listener_b, thread_id, cancel_b).await;
|
||||
|
||||
assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id));
|
||||
assert_eq!(cancel_rx_a.await, Ok(()));
|
||||
assert!(
|
||||
tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx_b)
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
assert_eq!(manager.remove_listener(listener_b).await, Some(thread_id));
|
||||
assert_eq!(cancel_rx_b.await, Ok(()));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user