refactor: codex app-server ThreadState (#11419)

this is a no-op functionality wise. consolidates thread-specific message
processor / event handling state in ThreadState
This commit is contained in:
Max Johnson
2026-02-11 12:20:54 -08:00
committed by GitHub
parent 42e22f3bde
commit b5339a591d
4 changed files with 258 additions and 187 deletions

View File

@@ -137,7 +137,6 @@ use codex_app_server_protocol::ThreadStartedNotification;
use codex_app_server_protocol::ThreadUnarchiveParams;
use codex_app_server_protocol::ThreadUnarchiveResponse;
use codex_app_server_protocol::Turn;
use codex_app_server_protocol::TurnError;
use codex_app_server_protocol::TurnInterruptParams;
use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStartResponse;
@@ -252,20 +251,7 @@ use uuid::Uuid;
use crate::filters::compute_source_filters;
use crate::filters::source_kind_matches;
type PendingInterruptQueue = Vec<(ConnectionRequestId, ApiVersion)>;
pub(crate) type PendingInterrupts = Arc<Mutex<HashMap<ThreadId, PendingInterruptQueue>>>;
pub(crate) type PendingRollbacks = Arc<Mutex<HashMap<ThreadId, ConnectionRequestId>>>;
/// Per-conversation accumulation of the latest states e.g. error message while a turn runs.
#[derive(Default, Clone)]
pub(crate) struct TurnSummary {
pub(crate) file_change_started: HashSet<String>,
pub(crate) last_error: Option<TurnError>,
}
pub(crate) type TurnSummaryStore = Arc<Mutex<HashMap<ThreadId, TurnSummary>>>;
use crate::thread_state::ThreadStateManager;
const THREAD_LIST_DEFAULT_LIMIT: usize = 25;
const THREAD_LIST_MAX_LIMIT: usize = 100;
@@ -303,21 +289,16 @@ pub(crate) struct CodexMessageProcessor {
config: Arc<Config>,
cli_overrides: Vec<(String, TomlValue)>,
cloud_requirements: Arc<RwLock<CloudRequirementsLoader>>,
conversation_listeners: HashMap<Uuid, oneshot::Sender<()>>,
listener_thread_ids_by_subscription: HashMap<Uuid, ThreadId>,
active_login: Arc<Mutex<Option<ActiveLogin>>>,
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
pending_interrupts: PendingInterrupts,
// Queue of pending rollback requests per conversation. We reply when ThreadRollback arrives.
pending_rollbacks: PendingRollbacks,
turn_summary_store: TurnSummaryStore,
thread_state_manager: ThreadStateManager,
pending_fuzzy_searches: Arc<Mutex<HashMap<String, Arc<AtomicBool>>>>,
feedback: CodexFeedback,
}
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, Default)]
pub(crate) enum ApiVersion {
V1,
#[default]
V2,
}
@@ -375,12 +356,8 @@ impl CodexMessageProcessor {
config,
cli_overrides,
cloud_requirements,
conversation_listeners: HashMap::new(),
listener_thread_ids_by_subscription: HashMap::new(),
active_login: Arc::new(Mutex::new(None)),
pending_interrupts: Arc::new(Mutex::new(HashMap::new())),
pending_rollbacks: Arc::new(Mutex::new(HashMap::new())),
turn_summary_store: Arc::new(Mutex::new(HashMap::new())),
thread_state_manager: ThreadStateManager::new(),
pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())),
feedback,
}
@@ -1012,7 +989,7 @@ impl CodexMessageProcessor {
cloud_requirements.as_ref(),
auth_manager.clone(),
chatgpt_base_url,
codex_home.clone(),
codex_home,
);
sync_default_client_residency_requirement(
&cli_overrides,
@@ -1120,7 +1097,7 @@ impl CodexMessageProcessor {
cloud_requirements.as_ref(),
auth_manager.clone(),
chatgpt_base_url,
codex_home.clone(),
codex_home,
);
sync_default_client_residency_requirement(
&cli_overrides,
@@ -2330,25 +2307,32 @@ impl CodexMessageProcessor {
let request = request_id.clone();
{
let mut map = self.pending_rollbacks.lock().await;
if map.contains_key(&thread_id) {
self.send_invalid_request_error(
request.clone(),
"rollback already in progress for this thread".to_string(),
)
.await;
return;
let rollback_already_in_progress = {
let thread_state = self.thread_state_manager.thread_state(thread_id);
let mut thread_state = thread_state.lock().await;
if thread_state.pending_rollbacks.is_some() {
true
} else {
thread_state.pending_rollbacks = Some(request.clone());
false
}
map.insert(thread_id, request.clone());
};
if rollback_already_in_progress {
self.send_invalid_request_error(
request.clone(),
"rollback already in progress for this thread".to_string(),
)
.await;
return;
}
if let Err(err) = thread.submit(Op::ThreadRollback { num_turns }).await {
// No ThreadRollback event will arrive if an error occurs.
// Clean up and reply immediately.
let mut map = self.pending_rollbacks.lock().await;
map.remove(&thread_id);
let thread_state = self.thread_state_manager.thread_state(thread_id);
let mut thread_state = thread_state.lock().await;
thread_state.pending_rollbacks = None;
drop(thread_state);
self.send_internal_error(request, format!("failed to start rollback: {err}"))
.await;
@@ -2646,11 +2630,7 @@ impl CodexMessageProcessor {
/// Best-effort: attach a listener for thread_id if missing.
pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) {
if self
.listener_thread_ids_by_subscription
.values()
.any(|entry| *entry == thread_id)
{
if self.thread_state_manager.has_listener_for_thread(thread_id) {
return;
}
@@ -4313,7 +4293,8 @@ impl CodexMessageProcessor {
let mut state_db_ctx = None;
// If the thread is active, request shutdown and wait briefly.
if let Some(conversation) = self.thread_manager.remove_thread(&thread_id).await {
let removed_conversation = self.thread_manager.remove_thread(&thread_id).await;
if let Some(conversation) = removed_conversation {
if let Some(ctx) = conversation.state_db() {
state_db_ctx = Some(ctx);
}
@@ -4341,6 +4322,9 @@ impl CodexMessageProcessor {
error!("failed to submit Shutdown to thread {thread_id}: {err}");
}
}
self.thread_state_manager
.remove_thread_state(thread_id)
.await;
}
if state_db_ctx.is_none() {
@@ -4880,9 +4864,10 @@ impl CodexMessageProcessor {
// Record the pending interrupt so we can reply when TurnAborted arrives.
{
let mut map = self.pending_interrupts.lock().await;
map.entry(conversation_id)
.or_default()
let pending_interrupts = self.thread_state_manager.thread_state(conversation_id);
let mut thread_state = pending_interrupts.lock().await;
thread_state
.pending_interrupts
.push((request, ApiVersion::V1));
}
@@ -5276,9 +5261,10 @@ impl CodexMessageProcessor {
// Record the pending interrupt so we can reply when TurnAborted arrives.
{
let mut map = self.pending_interrupts.lock().await;
map.entry(thread_uuid)
.or_default()
let thread_state = self.thread_state_manager.thread_state(thread_uuid);
let mut thread_state = thread_state.lock().await;
thread_state
.pending_interrupts
.push((request, ApiVersion::V2));
}
@@ -5315,16 +5301,13 @@ impl CodexMessageProcessor {
params: RemoveConversationListenerParams,
) {
let RemoveConversationListenerParams { subscription_id } = params;
match self.conversation_listeners.remove(&subscription_id) {
Some(sender) => {
// Signal the spawned task to exit and acknowledge.
let _ = sender.send(());
if let Some(thread_id) = self
.listener_thread_ids_by_subscription
.remove(&subscription_id)
{
info!("removed listener for thread {thread_id}");
}
match self
.thread_state_manager
.remove_listener(subscription_id)
.await
{
Some(thread_id) => {
info!("removed listener for thread {thread_id}");
let response = RemoveConversationSubscriptionResponse {};
self.outgoing.send_response(request_id, response).await;
}
@@ -5342,7 +5325,7 @@ impl CodexMessageProcessor {
async fn attach_conversation_listener(
&mut self,
conversation_id: ThreadId,
experimental_raw_events: bool,
raw_events_enabled: bool,
api_version: ApiVersion,
) -> Result<Uuid, JSONRPCErrorError> {
let conversation = match self.thread_manager.get_thread(conversation_id).await {
@@ -5358,16 +5341,11 @@ impl CodexMessageProcessor {
let subscription_id = Uuid::new_v4();
let (cancel_tx, mut cancel_rx) = oneshot::channel();
self.conversation_listeners
.insert(subscription_id, cancel_tx);
self.listener_thread_ids_by_subscription
.insert(subscription_id, conversation_id);
let thread_state = self
.thread_state_manager
.set_listener(subscription_id, conversation_id, cancel_tx)
.await;
let outgoing_for_task = self.outgoing.clone();
let pending_interrupts = self.pending_interrupts.clone();
let pending_rollbacks = self.pending_rollbacks.clone();
let turn_summary_store = self.turn_summary_store.clone();
let api_version_for_task = api_version;
let fallback_model_provider = self.config.model_provider_id.clone();
tokio::spawn(async move {
loop {
@@ -5385,10 +5363,9 @@ impl CodexMessageProcessor {
}
};
if let EventMsg::RawResponseItem(_) = &event.msg
&& !experimental_raw_events {
continue;
}
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
continue;
}
// For now, we send a notification for every event,
// JSON-serializing the `Event` as-is, but these should
@@ -5427,10 +5404,8 @@ impl CodexMessageProcessor {
conversation_id,
conversation.clone(),
outgoing_for_task.clone(),
pending_interrupts.clone(),
pending_rollbacks.clone(),
turn_summary_store.clone(),
api_version_for_task,
thread_state.clone(),
api_version,
fallback_model_provider.clone(),
)
.await;
@@ -5699,7 +5674,7 @@ fn replace_cloud_requirements_loader(
cloud_requirements: &RwLock<CloudRequirementsLoader>,
auth_manager: Arc<AuthManager>,
chatgpt_base_url: String,
codex_home: std::path::PathBuf,
codex_home: PathBuf,
) {
let loader = cloud_requirements_loader(auth_manager, chatgpt_base_url, codex_home);
if let Ok(mut guard) = cloud_requirements.write() {
@@ -6316,4 +6291,30 @@ mod tests {
assert_eq!(summary, expected);
Ok(())
}
#[tokio::test]
async fn removing_one_listener_does_not_cancel_other_subscriptions_for_same_thread()
-> Result<()> {
let mut manager = ThreadStateManager::new();
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
let listener_a = Uuid::new_v4();
let listener_b = Uuid::new_v4();
let (cancel_a, cancel_rx_a) = oneshot::channel();
let (cancel_b, mut cancel_rx_b) = oneshot::channel();
manager.set_listener(listener_a, thread_id, cancel_a).await;
manager.set_listener(listener_b, thread_id, cancel_b).await;
assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id));
assert_eq!(cancel_rx_a.await, Ok(()));
assert!(
tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx_b)
.await
.is_err()
);
assert_eq!(manager.remove_listener(listener_b).await, Some(thread_id));
assert_eq!(cancel_rx_b.await, Ok(()));
Ok(())
}
}