Compare commits

...

2 Commits

Author SHA1 Message Date
Max Johnson
9e506da924 app-server: thread resume behavior updates 2026-02-11 11:42:11 -08:00
Max Johnson
8148704fe3 refactor: ThreadState 2026-02-11 11:02:13 -08:00
3 changed files with 252 additions and 179 deletions

View File

@@ -1,8 +1,6 @@
use crate::codex_message_processor::ApiVersion;
use crate::codex_message_processor::PendingInterrupts;
use crate::codex_message_processor::PendingRollbacks;
use crate::codex_message_processor::ThreadState;
use crate::codex_message_processor::TurnSummary;
use crate::codex_message_processor::TurnSummaryStore;
use crate::codex_message_processor::read_rollout_items_from_rollout;
use crate::codex_message_processor::read_summary_from_rollout;
use crate::codex_message_processor::summary_to_thread;
@@ -98,6 +96,7 @@ use std::collections::HashMap;
use std::convert::TryFrom;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tracing::error;
@@ -109,9 +108,7 @@ pub(crate) async fn apply_bespoke_event_handling(
conversation_id: ThreadId,
conversation: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
pending_interrupts: PendingInterrupts,
pending_rollbacks: PendingRollbacks,
turn_summary_store: TurnSummaryStore,
thread_state: Arc<tokio::sync::Mutex<ThreadState>>,
api_version: ApiVersion,
fallback_model_provider: String,
) {
@@ -122,13 +119,7 @@ pub(crate) async fn apply_bespoke_event_handling(
match msg {
EventMsg::TurnStarted(_) => {}
EventMsg::TurnComplete(_ev) => {
handle_turn_complete(
conversation_id,
event_turn_id,
&outgoing,
&turn_summary_store,
)
.await;
handle_turn_complete(conversation_id, event_turn_id, &outgoing, &thread_state).await;
}
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
call_id,
@@ -159,9 +150,11 @@ pub(crate) async fn apply_bespoke_event_handling(
let patch_changes = convert_patch_changes(&changes);
let first_start = {
let mut map = turn_summary_store.lock().await;
let summary = map.entry(conversation_id).or_default();
summary.file_change_started.insert(item_id.clone())
let mut state = thread_state.lock().await;
state
.turn_summary
.file_change_started
.insert(item_id.clone())
};
if first_start {
let item = ThreadItem::FileChange {
@@ -198,7 +191,7 @@ pub(crate) async fn apply_bespoke_event_handling(
rx,
conversation,
outgoing,
turn_summary_store,
thread_state.clone(),
)
.await;
});
@@ -718,7 +711,7 @@ pub(crate) async fn apply_bespoke_event_handling(
return handle_thread_rollback_failed(
conversation_id,
message,
&pending_rollbacks,
&thread_state,
&outgoing,
)
.await;
@@ -729,7 +722,7 @@ pub(crate) async fn apply_bespoke_event_handling(
codex_error_info: ev.codex_error_info.map(V2CodexErrorInfo::from),
additional_details: None,
};
handle_error(conversation_id, turn_error.clone(), &turn_summary_store).await;
handle_error(conversation_id, turn_error.clone(), &thread_state).await;
outgoing
.send_server_notification(ServerNotification::Error(ErrorNotification {
error: turn_error.clone(),
@@ -867,9 +860,11 @@ pub(crate) async fn apply_bespoke_event_handling(
let item_id = patch_begin_event.call_id.clone();
let first_start = {
let mut map = turn_summary_store.lock().await;
let summary = map.entry(conversation_id).or_default();
summary.file_change_started.insert(item_id.clone())
let mut state = thread_state.lock().await;
state
.turn_summary
.file_change_started
.insert(item_id.clone())
};
if first_start {
let item = ThreadItem::FileChange {
@@ -905,7 +900,7 @@ pub(crate) async fn apply_bespoke_event_handling(
status,
event_turn_id.clone(),
outgoing.as_ref(),
&turn_summary_store,
&thread_state,
)
.await;
}
@@ -950,9 +945,8 @@ pub(crate) async fn apply_bespoke_event_handling(
// We need to detect which item type it is so we can emit the right notification.
// We already have state tracking FileChange items on item/started, so let's use that.
let is_file_change = {
let map = turn_summary_store.lock().await;
map.get(&conversation_id)
.is_some_and(|summary| summary.file_change_started.contains(&item_id))
let state = thread_state.lock().await;
state.turn_summary.file_change_started.contains(&item_id)
};
if is_file_change {
let notification = FileChangeOutputDeltaNotification {
@@ -1049,8 +1043,8 @@ pub(crate) async fn apply_bespoke_event_handling(
// If this is a TurnAborted, reply to any pending interrupt requests.
EventMsg::TurnAborted(turn_aborted_event) => {
let pending = {
let mut map = pending_interrupts.lock().await;
map.remove(&conversation_id).unwrap_or_default()
let mut state = thread_state.lock().await;
std::mem::take(&mut state.pending_interrupts)
};
if !pending.is_empty() {
for (rid, ver) in pending {
@@ -1069,18 +1063,12 @@ pub(crate) async fn apply_bespoke_event_handling(
}
}
handle_turn_interrupted(
conversation_id,
event_turn_id,
&outgoing,
&turn_summary_store,
)
.await;
handle_turn_interrupted(conversation_id, event_turn_id, &outgoing, &thread_state).await;
}
EventMsg::ThreadRolledBack(_rollback_event) => {
let pending = {
let mut map = pending_rollbacks.lock().await;
map.remove(&conversation_id)
let mut state = thread_state.lock().await;
state.pending_rollbacks.take()
};
if let Some(request_id) = pending {
@@ -1245,14 +1233,11 @@ async fn complete_file_change_item(
status: PatchApplyStatus,
turn_id: String,
outgoing: &OutgoingMessageSender,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
{
let mut map = turn_summary_store.lock().await;
if let Some(summary) = map.get_mut(&conversation_id) {
summary.file_change_started.remove(&item_id);
}
}
let mut state = thread_state.lock().await;
state.turn_summary.file_change_started.remove(&item_id);
drop(state);
let item = ThreadItem::FileChange {
id: item_id,
@@ -1324,20 +1309,20 @@ async fn maybe_emit_raw_response_item_completed(
}
async fn find_and_remove_turn_summary(
conversation_id: ThreadId,
turn_summary_store: &TurnSummaryStore,
_conversation_id: ThreadId,
thread_state: &Arc<Mutex<ThreadState>>,
) -> TurnSummary {
let mut map = turn_summary_store.lock().await;
map.remove(&conversation_id).unwrap_or_default()
let mut state = thread_state.lock().await;
std::mem::take(&mut state.turn_summary)
}
async fn handle_turn_complete(
conversation_id: ThreadId,
event_turn_id: String,
outgoing: &OutgoingMessageSender,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
let turn_summary = find_and_remove_turn_summary(conversation_id, turn_summary_store).await;
let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await;
let (status, error) = match turn_summary.last_error {
Some(error) => (TurnStatus::Failed, Some(error)),
@@ -1351,9 +1336,9 @@ async fn handle_turn_interrupted(
conversation_id: ThreadId,
event_turn_id: String,
outgoing: &OutgoingMessageSender,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
find_and_remove_turn_summary(conversation_id, turn_summary_store).await;
find_and_remove_turn_summary(conversation_id, thread_state).await;
emit_turn_completed_with_status(
conversation_id,
@@ -1366,15 +1351,12 @@ async fn handle_turn_interrupted(
}
async fn handle_thread_rollback_failed(
conversation_id: ThreadId,
_conversation_id: ThreadId,
message: String,
pending_rollbacks: &PendingRollbacks,
thread_state: &Arc<Mutex<ThreadState>>,
outgoing: &OutgoingMessageSender,
) {
let pending_rollback = {
let mut map = pending_rollbacks.lock().await;
map.remove(&conversation_id)
};
let pending_rollback = thread_state.lock().await.pending_rollbacks.take();
if let Some(request_id) = pending_rollback {
outgoing
@@ -1419,12 +1401,12 @@ async fn handle_token_count_event(
}
async fn handle_error(
conversation_id: ThreadId,
_conversation_id: ThreadId,
error: TurnError,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
let mut map = turn_summary_store.lock().await;
map.entry(conversation_id).or_default().last_error = Some(error);
let mut state = thread_state.lock().await;
state.turn_summary.last_error = Some(error);
}
async fn on_patch_approval_response(
@@ -1652,7 +1634,7 @@ async fn on_file_change_request_approval_response(
receiver: oneshot::Receiver<JsonValue>,
codex: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
turn_summary_store: TurnSummaryStore,
thread_state: Arc<Mutex<ThreadState>>,
) {
let response = receiver.await;
let (decision, completion_status) = match response {
@@ -1685,7 +1667,7 @@ async fn on_file_change_request_approval_response(
status,
event_turn_id.clone(),
outgoing.as_ref(),
&turn_summary_store,
&thread_state,
)
.await;
}
@@ -1915,13 +1897,27 @@ mod tests {
use pretty_assertions::assert_eq;
use rmcp::model::Content;
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
fn new_turn_summary_store() -> TurnSummaryStore {
Arc::new(Mutex::new(HashMap::new()))
fn new_thread_state() -> Arc<Mutex<ThreadState>> {
Arc::new(Mutex::new(ThreadState::default()))
}
async fn recv_broadcast_message(
rx: &mut mpsc::Receiver<OutgoingEnvelope>,
) -> Result<OutgoingMessage> {
let envelope = rx
.recv()
.await
.ok_or_else(|| anyhow!("should send one message"))?;
match envelope {
OutgoingEnvelope::Broadcast { message } => Ok(message),
OutgoingEnvelope::ToConnection { connection_id, .. } => {
bail!("unexpected targeted message for connection {connection_id:?}")
}
}
}
async fn recv_broadcast_message(
@@ -1999,7 +1995,7 @@ mod tests {
#[tokio::test]
async fn test_handle_error_records_message() -> Result<()> {
let conversation_id = ThreadId::new();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_error(
conversation_id,
@@ -2008,11 +2004,11 @@ mod tests {
codex_error_info: Some(V2CodexErrorInfo::InternalServerError),
additional_details: None,
},
&turn_summary_store,
&thread_state,
)
.await;
let turn_summary = find_and_remove_turn_summary(conversation_id, &turn_summary_store).await;
let turn_summary = find_and_remove_turn_summary(conversation_id, &thread_state).await;
assert_eq!(
turn_summary.last_error,
Some(TurnError {
@@ -2030,13 +2026,13 @@ mod tests {
let event_turn_id = "complete1".to_string();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_turn_complete(
conversation_id,
event_turn_id.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
@@ -2057,7 +2053,7 @@ mod tests {
async fn test_handle_turn_interrupted_emits_interrupted_with_error() -> Result<()> {
let conversation_id = ThreadId::new();
let event_turn_id = "interrupt1".to_string();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_error(
conversation_id,
TurnError {
@@ -2065,7 +2061,7 @@ mod tests {
codex_error_info: None,
additional_details: None,
},
&turn_summary_store,
&thread_state,
)
.await;
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
@@ -2075,7 +2071,7 @@ mod tests {
conversation_id,
event_turn_id.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
@@ -2096,7 +2092,7 @@ mod tests {
async fn test_handle_turn_complete_emits_failed_with_error() -> Result<()> {
let conversation_id = ThreadId::new();
let event_turn_id = "complete_err1".to_string();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_error(
conversation_id,
TurnError {
@@ -2104,7 +2100,7 @@ mod tests {
codex_error_info: Some(V2CodexErrorInfo::Other),
additional_details: None,
},
&turn_summary_store,
&thread_state,
)
.await;
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
@@ -2114,7 +2110,7 @@ mod tests {
conversation_id,
event_turn_id.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
@@ -2336,7 +2332,7 @@ mod tests {
// Conversation A will have two turns; Conversation B will have one turn.
let conversation_a = ThreadId::new();
let conversation_b = ThreadId::new();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
@@ -2350,16 +2346,10 @@ mod tests {
codex_error_info: Some(V2CodexErrorInfo::BadRequest),
additional_details: None,
},
&turn_summary_store,
)
.await;
handle_turn_complete(
conversation_a,
a_turn1.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
handle_turn_complete(conversation_a, a_turn1.clone(), &outgoing, &thread_state).await;
// Turn 1 on conversation B
let b_turn1 = "b_turn1".to_string();
@@ -2370,26 +2360,14 @@ mod tests {
codex_error_info: None,
additional_details: None,
},
&turn_summary_store,
)
.await;
handle_turn_complete(
conversation_b,
b_turn1.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
handle_turn_complete(conversation_b, b_turn1.clone(), &outgoing, &thread_state).await;
// Turn 2 on conversation A
let a_turn2 = "a_turn2".to_string();
handle_turn_complete(
conversation_a,
a_turn2.clone(),
&outgoing,
&turn_summary_store,
)
.await;
handle_turn_complete(conversation_a, a_turn2.clone(), &outgoing, &thread_state).await;
// Verify: A turn 1
let msg = recv_broadcast_message(&mut rx).await?;

View File

@@ -254,9 +254,6 @@ use crate::filters::compute_source_filters;
use crate::filters::source_kind_matches;
type PendingInterruptQueue = Vec<(ConnectionRequestId, ApiVersion)>;
pub(crate) type PendingInterrupts = Arc<Mutex<HashMap<ThreadId, PendingInterruptQueue>>>;
pub(crate) type PendingRollbacks = Arc<Mutex<HashMap<ThreadId, ConnectionRequestId>>>;
/// Per-conversation accumulation of the latest states e.g. error message while a turn runs.
#[derive(Default, Clone)]
@@ -265,7 +262,90 @@ pub(crate) struct TurnSummary {
pub(crate) last_error: Option<TurnError>,
}
pub(crate) type TurnSummaryStore = Arc<Mutex<HashMap<ThreadId, TurnSummary>>>;
#[derive(Default)]
pub(crate) struct ThreadState {
pub(crate) pending_interrupts: PendingInterruptQueue,
pub(crate) pending_rollbacks: Option<ConnectionRequestId>,
pub(crate) turn_summary: TurnSummary,
pub(crate) listener_cancel_txs: HashMap<Uuid, oneshot::Sender<()>>,
}
impl ThreadState {
fn set_listener(&mut self, subscription_id: Uuid, cancel_tx: oneshot::Sender<()>) {
if let Some(previous) = self.listener_cancel_txs.insert(subscription_id, cancel_tx) {
let _ = previous.send(());
}
}
fn clear_listener(&mut self, subscription_id: Uuid) {
if let Some(cancel_tx) = self.listener_cancel_txs.remove(&subscription_id) {
let _ = cancel_tx.send(());
}
}
fn clear_listeners(&mut self) {
for (_, cancel_tx) in self.listener_cancel_txs.drain() {
let _ = cancel_tx.send(());
}
}
}
#[derive(Default)]
struct ThreadStateManager {
thread_states: HashMap<ThreadId, Arc<Mutex<ThreadState>>>,
thread_id_by_subscription: HashMap<Uuid, ThreadId>,
}
impl ThreadStateManager {
fn new() -> Self {
Self::default()
}
fn has_listener_for_thread(&self, thread_id: ThreadId) -> bool {
self.thread_id_by_subscription
.values()
.any(|existing| *existing == thread_id)
}
fn thread_state(&mut self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
self.thread_states
.entry(thread_id)
.or_insert_with(|| Arc::new(Mutex::new(ThreadState::default())))
.clone()
}
async fn remove_listener(&mut self, subscription_id: Uuid) -> Option<ThreadId> {
let thread_id = self.thread_id_by_subscription.remove(&subscription_id)?;
if let Some(thread_state) = self.thread_states.get(&thread_id) {
thread_state.lock().await.clear_listener(subscription_id);
}
Some(thread_id)
}
async fn remove_thread_state(&mut self, thread_id: ThreadId) {
if let Some(thread_state) = self.thread_states.remove(&thread_id) {
thread_state.lock().await.clear_listeners();
}
self.thread_id_by_subscription
.retain(|_, existing_thread_id| *existing_thread_id != thread_id);
}
async fn set_listener(
&mut self,
subscription_id: Uuid,
thread_id: ThreadId,
cancel_tx: oneshot::Sender<()>,
) -> Arc<Mutex<ThreadState>> {
self.thread_id_by_subscription
.insert(subscription_id, thread_id);
let thread_state = self.thread_state(thread_id);
thread_state
.lock()
.await
.set_listener(subscription_id, cancel_tx);
thread_state
}
}
const THREAD_LIST_DEFAULT_LIMIT: usize = 25;
const THREAD_LIST_MAX_LIMIT: usize = 100;
@@ -303,21 +383,16 @@ pub(crate) struct CodexMessageProcessor {
config: Arc<Config>,
cli_overrides: Vec<(String, TomlValue)>,
cloud_requirements: Arc<RwLock<CloudRequirementsLoader>>,
conversation_listeners: HashMap<Uuid, oneshot::Sender<()>>,
listener_thread_ids_by_subscription: HashMap<Uuid, ThreadId>,
active_login: Arc<Mutex<Option<ActiveLogin>>>,
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
pending_interrupts: PendingInterrupts,
// Queue of pending rollback requests per conversation. We reply when ThreadRollback arrives.
pending_rollbacks: PendingRollbacks,
turn_summary_store: TurnSummaryStore,
thread_state_manager: ThreadStateManager,
pending_fuzzy_searches: Arc<Mutex<HashMap<String, Arc<AtomicBool>>>>,
feedback: CodexFeedback,
}
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, Default)]
pub(crate) enum ApiVersion {
V1,
#[default]
V2,
}
@@ -375,12 +450,8 @@ impl CodexMessageProcessor {
config,
cli_overrides,
cloud_requirements,
conversation_listeners: HashMap::new(),
listener_thread_ids_by_subscription: HashMap::new(),
active_login: Arc::new(Mutex::new(None)),
pending_interrupts: Arc::new(Mutex::new(HashMap::new())),
pending_rollbacks: Arc::new(Mutex::new(HashMap::new())),
turn_summary_store: Arc::new(Mutex::new(HashMap::new())),
thread_state_manager: ThreadStateManager::new(),
pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())),
feedback,
}
@@ -977,7 +1048,6 @@ impl CodexMessageProcessor {
let auth_manager = self.auth_manager.clone();
let cloud_requirements = self.cloud_requirements.clone();
let chatgpt_base_url = self.config.chatgpt_base_url.clone();
let codex_home = self.config.codex_home.clone();
let cli_overrides = self.cli_overrides.clone();
let auth_url = server.auth_url.clone();
tokio::spawn(async move {
@@ -1012,7 +1082,6 @@ impl CodexMessageProcessor {
cloud_requirements.as_ref(),
auth_manager.clone(),
chatgpt_base_url,
codex_home.clone(),
);
sync_default_client_residency_requirement(
&cli_overrides,
@@ -1085,7 +1154,6 @@ impl CodexMessageProcessor {
let auth_manager = self.auth_manager.clone();
let cloud_requirements = self.cloud_requirements.clone();
let chatgpt_base_url = self.config.chatgpt_base_url.clone();
let codex_home = self.config.codex_home.clone();
let cli_overrides = self.cli_overrides.clone();
let auth_url = server.auth_url.clone();
tokio::spawn(async move {
@@ -1120,7 +1188,6 @@ impl CodexMessageProcessor {
cloud_requirements.as_ref(),
auth_manager.clone(),
chatgpt_base_url,
codex_home.clone(),
);
sync_default_client_residency_requirement(
&cli_overrides,
@@ -1292,7 +1359,6 @@ impl CodexMessageProcessor {
self.cloud_requirements.as_ref(),
self.auth_manager.clone(),
self.config.chatgpt_base_url.clone(),
self.config.codex_home.clone(),
);
sync_default_client_residency_requirement(
&self.cli_overrides,
@@ -2321,25 +2387,32 @@ impl CodexMessageProcessor {
let request = request_id.clone();
{
let mut map = self.pending_rollbacks.lock().await;
if map.contains_key(&thread_id) {
self.send_invalid_request_error(
request.clone(),
"rollback already in progress for this thread".to_string(),
)
.await;
return;
let rollback_already_in_progress = {
let thread_state = self.thread_state_manager.thread_state(thread_id);
let mut thread_state = thread_state.lock().await;
if thread_state.pending_rollbacks.is_some() {
true
} else {
thread_state.pending_rollbacks = Some(request.clone());
false
}
map.insert(thread_id, request.clone());
};
if rollback_already_in_progress {
self.send_invalid_request_error(
request.clone(),
"rollback already in progress for this thread".to_string(),
)
.await;
return;
}
if let Err(err) = thread.submit(Op::ThreadRollback { num_turns }).await {
// No ThreadRollback event will arrive if an error occurs.
// Clean up and reply immediately.
let mut map = self.pending_rollbacks.lock().await;
map.remove(&thread_id);
let thread_state = self.thread_state_manager.thread_state(thread_id);
let mut thread_state = thread_state.lock().await;
thread_state.pending_rollbacks = None;
drop(thread_state);
self.send_internal_error(request, format!("failed to start rollback: {err}"))
.await;
@@ -2637,11 +2710,7 @@ impl CodexMessageProcessor {
/// Best-effort: attach a listener for thread_id if missing.
pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) {
if self
.listener_thread_ids_by_subscription
.values()
.any(|entry| *entry == thread_id)
{
if self.thread_state_manager.has_listener_for_thread(thread_id) {
return;
}
@@ -4273,7 +4342,8 @@ impl CodexMessageProcessor {
let mut state_db_ctx = None;
// If the thread is active, request shutdown and wait briefly.
if let Some(conversation) = self.thread_manager.remove_thread(&thread_id).await {
let removed_conversation = self.thread_manager.remove_thread(&thread_id).await;
if let Some(conversation) = removed_conversation {
if let Some(ctx) = conversation.state_db() {
state_db_ctx = Some(ctx);
}
@@ -4301,6 +4371,9 @@ impl CodexMessageProcessor {
error!("failed to submit Shutdown to thread {thread_id}: {err}");
}
}
self.thread_state_manager
.remove_thread_state(thread_id)
.await;
}
if state_db_ctx.is_none() {
@@ -4840,9 +4913,10 @@ impl CodexMessageProcessor {
// Record the pending interrupt so we can reply when TurnAborted arrives.
{
let mut map = self.pending_interrupts.lock().await;
map.entry(conversation_id)
.or_default()
let pending_interrupts = self.thread_state_manager.thread_state(conversation_id);
let mut thread_state = pending_interrupts.lock().await;
thread_state
.pending_interrupts
.push((request, ApiVersion::V1));
}
@@ -5236,9 +5310,10 @@ impl CodexMessageProcessor {
// Record the pending interrupt so we can reply when TurnAborted arrives.
{
let mut map = self.pending_interrupts.lock().await;
map.entry(thread_uuid)
.or_default()
let thread_state = self.thread_state_manager.thread_state(thread_uuid);
let mut thread_state = thread_state.lock().await;
thread_state
.pending_interrupts
.push((request, ApiVersion::V2));
}
@@ -5275,16 +5350,13 @@ impl CodexMessageProcessor {
params: RemoveConversationListenerParams,
) {
let RemoveConversationListenerParams { subscription_id } = params;
match self.conversation_listeners.remove(&subscription_id) {
Some(sender) => {
// Signal the spawned task to exit and acknowledge.
let _ = sender.send(());
if let Some(thread_id) = self
.listener_thread_ids_by_subscription
.remove(&subscription_id)
{
info!("removed listener for thread {thread_id}");
}
match self
.thread_state_manager
.remove_listener(subscription_id)
.await
{
Some(thread_id) => {
info!("removed listener for thread {thread_id}");
let response = RemoveConversationSubscriptionResponse {};
self.outgoing.send_response(request_id, response).await;
}
@@ -5302,7 +5374,7 @@ impl CodexMessageProcessor {
async fn attach_conversation_listener(
&mut self,
conversation_id: ThreadId,
experimental_raw_events: bool,
raw_events_enabled: bool,
api_version: ApiVersion,
) -> Result<Uuid, JSONRPCErrorError> {
let conversation = match self.thread_manager.get_thread(conversation_id).await {
@@ -5318,16 +5390,11 @@ impl CodexMessageProcessor {
let subscription_id = Uuid::new_v4();
let (cancel_tx, mut cancel_rx) = oneshot::channel();
self.conversation_listeners
.insert(subscription_id, cancel_tx);
self.listener_thread_ids_by_subscription
.insert(subscription_id, conversation_id);
let thread_state = self
.thread_state_manager
.set_listener(subscription_id, conversation_id, cancel_tx)
.await;
let outgoing_for_task = self.outgoing.clone();
let pending_interrupts = self.pending_interrupts.clone();
let pending_rollbacks = self.pending_rollbacks.clone();
let turn_summary_store = self.turn_summary_store.clone();
let api_version_for_task = api_version;
let fallback_model_provider = self.config.model_provider_id.clone();
tokio::spawn(async move {
loop {
@@ -5345,10 +5412,9 @@ impl CodexMessageProcessor {
}
};
if let EventMsg::RawResponseItem(_) = &event.msg
&& !experimental_raw_events {
continue;
}
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
continue;
}
// For now, we send a notification for every event,
// JSON-serializing the `Event` as-is, but these should
@@ -5387,10 +5453,8 @@ impl CodexMessageProcessor {
conversation_id,
conversation.clone(),
outgoing_for_task.clone(),
pending_interrupts.clone(),
pending_rollbacks.clone(),
turn_summary_store.clone(),
api_version_for_task,
thread_state.clone(),
api_version,
fallback_model_provider.clone(),
)
.await;
@@ -5659,9 +5723,8 @@ fn replace_cloud_requirements_loader(
cloud_requirements: &RwLock<CloudRequirementsLoader>,
auth_manager: Arc<AuthManager>,
chatgpt_base_url: String,
codex_home: std::path::PathBuf,
) {
let loader = cloud_requirements_loader(auth_manager, chatgpt_base_url, codex_home);
let loader = cloud_requirements_loader(auth_manager, chatgpt_base_url);
if let Ok(mut guard) = cloud_requirements.write() {
*guard = loader;
} else {
@@ -6276,4 +6339,30 @@ mod tests {
assert_eq!(summary, expected);
Ok(())
}
#[tokio::test]
async fn removing_one_listener_does_not_cancel_other_subscriptions_for_same_thread()
-> Result<()> {
let mut manager = ThreadStateManager::new();
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
let listener_a = Uuid::new_v4();
let listener_b = Uuid::new_v4();
let (cancel_a, cancel_rx_a) = oneshot::channel();
let (cancel_b, mut cancel_rx_b) = oneshot::channel();
manager.set_listener(listener_a, thread_id, cancel_a).await;
manager.set_listener(listener_b, thread_id, cancel_b).await;
assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id));
assert_eq!(cancel_rx_a.await, Ok(()));
assert!(
tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx_b)
.await
.is_err()
);
assert_eq!(manager.remove_listener(listener_b).await, Some(thread_id));
assert_eq!(cancel_rx_b.await, Ok(()));
Ok(())
}
}

View File

@@ -402,6 +402,12 @@ impl MessageProcessor {
.await;
}
pub(crate) async fn connection_closed(&mut self, connection_id: ConnectionId) {
self.codex_message_processor
.connection_closed(connection_id)
.await;
}
/// Handle a standalone JSON-RPC response originating from the peer.
pub(crate) async fn process_response(&mut self, response: JSONRPCResponse) {
tracing::info!("<- response: {:?}", response);