mirror of
https://github.com/openai/codex.git
synced 2026-03-04 05:33:19 +00:00
Compare commits
2 Commits
fix/notify
...
maxj/threa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e506da924 | ||
|
|
8148704fe3 |
@@ -1,8 +1,6 @@
|
||||
use crate::codex_message_processor::ApiVersion;
|
||||
use crate::codex_message_processor::PendingInterrupts;
|
||||
use crate::codex_message_processor::PendingRollbacks;
|
||||
use crate::codex_message_processor::ThreadState;
|
||||
use crate::codex_message_processor::TurnSummary;
|
||||
use crate::codex_message_processor::TurnSummaryStore;
|
||||
use crate::codex_message_processor::read_rollout_items_from_rollout;
|
||||
use crate::codex_message_processor::read_summary_from_rollout;
|
||||
use crate::codex_message_processor::summary_to_thread;
|
||||
@@ -98,6 +96,7 @@ use std::collections::HashMap;
|
||||
use std::convert::TryFrom;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::error;
|
||||
|
||||
@@ -109,9 +108,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
conversation_id: ThreadId,
|
||||
conversation: Arc<CodexThread>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
pending_interrupts: PendingInterrupts,
|
||||
pending_rollbacks: PendingRollbacks,
|
||||
turn_summary_store: TurnSummaryStore,
|
||||
thread_state: Arc<tokio::sync::Mutex<ThreadState>>,
|
||||
api_version: ApiVersion,
|
||||
fallback_model_provider: String,
|
||||
) {
|
||||
@@ -122,13 +119,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
match msg {
|
||||
EventMsg::TurnStarted(_) => {}
|
||||
EventMsg::TurnComplete(_ev) => {
|
||||
handle_turn_complete(
|
||||
conversation_id,
|
||||
event_turn_id,
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(conversation_id, event_turn_id, &outgoing, &thread_state).await;
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id,
|
||||
@@ -159,9 +150,11 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
let patch_changes = convert_patch_changes(&changes);
|
||||
|
||||
let first_start = {
|
||||
let mut map = turn_summary_store.lock().await;
|
||||
let summary = map.entry(conversation_id).or_default();
|
||||
summary.file_change_started.insert(item_id.clone())
|
||||
let mut state = thread_state.lock().await;
|
||||
state
|
||||
.turn_summary
|
||||
.file_change_started
|
||||
.insert(item_id.clone())
|
||||
};
|
||||
if first_start {
|
||||
let item = ThreadItem::FileChange {
|
||||
@@ -198,7 +191,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
rx,
|
||||
conversation,
|
||||
outgoing,
|
||||
turn_summary_store,
|
||||
thread_state.clone(),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
@@ -718,7 +711,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
return handle_thread_rollback_failed(
|
||||
conversation_id,
|
||||
message,
|
||||
&pending_rollbacks,
|
||||
&thread_state,
|
||||
&outgoing,
|
||||
)
|
||||
.await;
|
||||
@@ -729,7 +722,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
codex_error_info: ev.codex_error_info.map(V2CodexErrorInfo::from),
|
||||
additional_details: None,
|
||||
};
|
||||
handle_error(conversation_id, turn_error.clone(), &turn_summary_store).await;
|
||||
handle_error(conversation_id, turn_error.clone(), &thread_state).await;
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::Error(ErrorNotification {
|
||||
error: turn_error.clone(),
|
||||
@@ -867,9 +860,11 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
let item_id = patch_begin_event.call_id.clone();
|
||||
|
||||
let first_start = {
|
||||
let mut map = turn_summary_store.lock().await;
|
||||
let summary = map.entry(conversation_id).or_default();
|
||||
summary.file_change_started.insert(item_id.clone())
|
||||
let mut state = thread_state.lock().await;
|
||||
state
|
||||
.turn_summary
|
||||
.file_change_started
|
||||
.insert(item_id.clone())
|
||||
};
|
||||
if first_start {
|
||||
let item = ThreadItem::FileChange {
|
||||
@@ -905,7 +900,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
status,
|
||||
event_turn_id.clone(),
|
||||
outgoing.as_ref(),
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -950,9 +945,8 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
// We need to detect which item type it is so we can emit the right notification.
|
||||
// We already have state tracking FileChange items on item/started, so let's use that.
|
||||
let is_file_change = {
|
||||
let map = turn_summary_store.lock().await;
|
||||
map.get(&conversation_id)
|
||||
.is_some_and(|summary| summary.file_change_started.contains(&item_id))
|
||||
let state = thread_state.lock().await;
|
||||
state.turn_summary.file_change_started.contains(&item_id)
|
||||
};
|
||||
if is_file_change {
|
||||
let notification = FileChangeOutputDeltaNotification {
|
||||
@@ -1049,8 +1043,8 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
// If this is a TurnAborted, reply to any pending interrupt requests.
|
||||
EventMsg::TurnAborted(turn_aborted_event) => {
|
||||
let pending = {
|
||||
let mut map = pending_interrupts.lock().await;
|
||||
map.remove(&conversation_id).unwrap_or_default()
|
||||
let mut state = thread_state.lock().await;
|
||||
std::mem::take(&mut state.pending_interrupts)
|
||||
};
|
||||
if !pending.is_empty() {
|
||||
for (rid, ver) in pending {
|
||||
@@ -1069,18 +1063,12 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
}
|
||||
}
|
||||
|
||||
handle_turn_interrupted(
|
||||
conversation_id,
|
||||
event_turn_id,
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
handle_turn_interrupted(conversation_id, event_turn_id, &outgoing, &thread_state).await;
|
||||
}
|
||||
EventMsg::ThreadRolledBack(_rollback_event) => {
|
||||
let pending = {
|
||||
let mut map = pending_rollbacks.lock().await;
|
||||
map.remove(&conversation_id)
|
||||
let mut state = thread_state.lock().await;
|
||||
state.pending_rollbacks.take()
|
||||
};
|
||||
|
||||
if let Some(request_id) = pending {
|
||||
@@ -1245,14 +1233,11 @@ async fn complete_file_change_item(
|
||||
status: PatchApplyStatus,
|
||||
turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
{
|
||||
let mut map = turn_summary_store.lock().await;
|
||||
if let Some(summary) = map.get_mut(&conversation_id) {
|
||||
summary.file_change_started.remove(&item_id);
|
||||
}
|
||||
}
|
||||
let mut state = thread_state.lock().await;
|
||||
state.turn_summary.file_change_started.remove(&item_id);
|
||||
drop(state);
|
||||
|
||||
let item = ThreadItem::FileChange {
|
||||
id: item_id,
|
||||
@@ -1324,20 +1309,20 @@ async fn maybe_emit_raw_response_item_completed(
|
||||
}
|
||||
|
||||
async fn find_and_remove_turn_summary(
|
||||
conversation_id: ThreadId,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
_conversation_id: ThreadId,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) -> TurnSummary {
|
||||
let mut map = turn_summary_store.lock().await;
|
||||
map.remove(&conversation_id).unwrap_or_default()
|
||||
let mut state = thread_state.lock().await;
|
||||
std::mem::take(&mut state.turn_summary)
|
||||
}
|
||||
|
||||
async fn handle_turn_complete(
|
||||
conversation_id: ThreadId,
|
||||
event_turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let turn_summary = find_and_remove_turn_summary(conversation_id, turn_summary_store).await;
|
||||
let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await;
|
||||
|
||||
let (status, error) = match turn_summary.last_error {
|
||||
Some(error) => (TurnStatus::Failed, Some(error)),
|
||||
@@ -1351,9 +1336,9 @@ async fn handle_turn_interrupted(
|
||||
conversation_id: ThreadId,
|
||||
event_turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
find_and_remove_turn_summary(conversation_id, turn_summary_store).await;
|
||||
find_and_remove_turn_summary(conversation_id, thread_state).await;
|
||||
|
||||
emit_turn_completed_with_status(
|
||||
conversation_id,
|
||||
@@ -1366,15 +1351,12 @@ async fn handle_turn_interrupted(
|
||||
}
|
||||
|
||||
async fn handle_thread_rollback_failed(
|
||||
conversation_id: ThreadId,
|
||||
_conversation_id: ThreadId,
|
||||
message: String,
|
||||
pending_rollbacks: &PendingRollbacks,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
) {
|
||||
let pending_rollback = {
|
||||
let mut map = pending_rollbacks.lock().await;
|
||||
map.remove(&conversation_id)
|
||||
};
|
||||
let pending_rollback = thread_state.lock().await.pending_rollbacks.take();
|
||||
|
||||
if let Some(request_id) = pending_rollback {
|
||||
outgoing
|
||||
@@ -1419,12 +1401,12 @@ async fn handle_token_count_event(
|
||||
}
|
||||
|
||||
async fn handle_error(
|
||||
conversation_id: ThreadId,
|
||||
_conversation_id: ThreadId,
|
||||
error: TurnError,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let mut map = turn_summary_store.lock().await;
|
||||
map.entry(conversation_id).or_default().last_error = Some(error);
|
||||
let mut state = thread_state.lock().await;
|
||||
state.turn_summary.last_error = Some(error);
|
||||
}
|
||||
|
||||
async fn on_patch_approval_response(
|
||||
@@ -1652,7 +1634,7 @@ async fn on_file_change_request_approval_response(
|
||||
receiver: oneshot::Receiver<JsonValue>,
|
||||
codex: Arc<CodexThread>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
turn_summary_store: TurnSummaryStore,
|
||||
thread_state: Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let (decision, completion_status) = match response {
|
||||
@@ -1685,7 +1667,7 @@ async fn on_file_change_request_approval_response(
|
||||
status,
|
||||
event_turn_id.clone(),
|
||||
outgoing.as_ref(),
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -1915,13 +1897,27 @@ mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
use rmcp::model::Content;
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
fn new_turn_summary_store() -> TurnSummaryStore {
|
||||
Arc::new(Mutex::new(HashMap::new()))
|
||||
fn new_thread_state() -> Arc<Mutex<ThreadState>> {
|
||||
Arc::new(Mutex::new(ThreadState::default()))
|
||||
}
|
||||
|
||||
async fn recv_broadcast_message(
|
||||
rx: &mut mpsc::Receiver<OutgoingEnvelope>,
|
||||
) -> Result<OutgoingMessage> {
|
||||
let envelope = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("should send one message"))?;
|
||||
match envelope {
|
||||
OutgoingEnvelope::Broadcast { message } => Ok(message),
|
||||
OutgoingEnvelope::ToConnection { connection_id, .. } => {
|
||||
bail!("unexpected targeted message for connection {connection_id:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv_broadcast_message(
|
||||
@@ -1999,7 +1995,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_handle_error_records_message() -> Result<()> {
|
||||
let conversation_id = ThreadId::new();
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
let thread_state = new_thread_state();
|
||||
|
||||
handle_error(
|
||||
conversation_id,
|
||||
@@ -2008,11 +2004,11 @@ mod tests {
|
||||
codex_error_info: Some(V2CodexErrorInfo::InternalServerError),
|
||||
additional_details: None,
|
||||
},
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
|
||||
let turn_summary = find_and_remove_turn_summary(conversation_id, &turn_summary_store).await;
|
||||
let turn_summary = find_and_remove_turn_summary(conversation_id, &thread_state).await;
|
||||
assert_eq!(
|
||||
turn_summary.last_error,
|
||||
Some(TurnError {
|
||||
@@ -2030,13 +2026,13 @@ mod tests {
|
||||
let event_turn_id = "complete1".to_string();
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
let thread_state = new_thread_state();
|
||||
|
||||
handle_turn_complete(
|
||||
conversation_id,
|
||||
event_turn_id.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -2057,7 +2053,7 @@ mod tests {
|
||||
async fn test_handle_turn_interrupted_emits_interrupted_with_error() -> Result<()> {
|
||||
let conversation_id = ThreadId::new();
|
||||
let event_turn_id = "interrupt1".to_string();
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
let thread_state = new_thread_state();
|
||||
handle_error(
|
||||
conversation_id,
|
||||
TurnError {
|
||||
@@ -2065,7 +2061,7 @@ mod tests {
|
||||
codex_error_info: None,
|
||||
additional_details: None,
|
||||
},
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
@@ -2075,7 +2071,7 @@ mod tests {
|
||||
conversation_id,
|
||||
event_turn_id.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -2096,7 +2092,7 @@ mod tests {
|
||||
async fn test_handle_turn_complete_emits_failed_with_error() -> Result<()> {
|
||||
let conversation_id = ThreadId::new();
|
||||
let event_turn_id = "complete_err1".to_string();
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
let thread_state = new_thread_state();
|
||||
handle_error(
|
||||
conversation_id,
|
||||
TurnError {
|
||||
@@ -2104,7 +2100,7 @@ mod tests {
|
||||
codex_error_info: Some(V2CodexErrorInfo::Other),
|
||||
additional_details: None,
|
||||
},
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
@@ -2114,7 +2110,7 @@ mod tests {
|
||||
conversation_id,
|
||||
event_turn_id.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -2336,7 +2332,7 @@ mod tests {
|
||||
// Conversation A will have two turns; Conversation B will have one turn.
|
||||
let conversation_a = ThreadId::new();
|
||||
let conversation_b = ThreadId::new();
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
let thread_state = new_thread_state();
|
||||
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
@@ -2350,16 +2346,10 @@ mod tests {
|
||||
codex_error_info: Some(V2CodexErrorInfo::BadRequest),
|
||||
additional_details: None,
|
||||
},
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(
|
||||
conversation_a,
|
||||
a_turn1.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(conversation_a, a_turn1.clone(), &outgoing, &thread_state).await;
|
||||
|
||||
// Turn 1 on conversation B
|
||||
let b_turn1 = "b_turn1".to_string();
|
||||
@@ -2370,26 +2360,14 @@ mod tests {
|
||||
codex_error_info: None,
|
||||
additional_details: None,
|
||||
},
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(
|
||||
conversation_b,
|
||||
b_turn1.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(conversation_b, b_turn1.clone(), &outgoing, &thread_state).await;
|
||||
|
||||
// Turn 2 on conversation A
|
||||
let a_turn2 = "a_turn2".to_string();
|
||||
handle_turn_complete(
|
||||
conversation_a,
|
||||
a_turn2.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(conversation_a, a_turn2.clone(), &outgoing, &thread_state).await;
|
||||
|
||||
// Verify: A turn 1
|
||||
let msg = recv_broadcast_message(&mut rx).await?;
|
||||
|
||||
@@ -254,9 +254,6 @@ use crate::filters::compute_source_filters;
|
||||
use crate::filters::source_kind_matches;
|
||||
|
||||
type PendingInterruptQueue = Vec<(ConnectionRequestId, ApiVersion)>;
|
||||
pub(crate) type PendingInterrupts = Arc<Mutex<HashMap<ThreadId, PendingInterruptQueue>>>;
|
||||
|
||||
pub(crate) type PendingRollbacks = Arc<Mutex<HashMap<ThreadId, ConnectionRequestId>>>;
|
||||
|
||||
/// Per-conversation accumulation of the latest states e.g. error message while a turn runs.
|
||||
#[derive(Default, Clone)]
|
||||
@@ -265,7 +262,90 @@ pub(crate) struct TurnSummary {
|
||||
pub(crate) last_error: Option<TurnError>,
|
||||
}
|
||||
|
||||
pub(crate) type TurnSummaryStore = Arc<Mutex<HashMap<ThreadId, TurnSummary>>>;
|
||||
#[derive(Default)]
|
||||
pub(crate) struct ThreadState {
|
||||
pub(crate) pending_interrupts: PendingInterruptQueue,
|
||||
pub(crate) pending_rollbacks: Option<ConnectionRequestId>,
|
||||
pub(crate) turn_summary: TurnSummary,
|
||||
pub(crate) listener_cancel_txs: HashMap<Uuid, oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
impl ThreadState {
|
||||
fn set_listener(&mut self, subscription_id: Uuid, cancel_tx: oneshot::Sender<()>) {
|
||||
if let Some(previous) = self.listener_cancel_txs.insert(subscription_id, cancel_tx) {
|
||||
let _ = previous.send(());
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_listener(&mut self, subscription_id: Uuid) {
|
||||
if let Some(cancel_tx) = self.listener_cancel_txs.remove(&subscription_id) {
|
||||
let _ = cancel_tx.send(());
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_listeners(&mut self) {
|
||||
for (_, cancel_tx) in self.listener_cancel_txs.drain() {
|
||||
let _ = cancel_tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct ThreadStateManager {
|
||||
thread_states: HashMap<ThreadId, Arc<Mutex<ThreadState>>>,
|
||||
thread_id_by_subscription: HashMap<Uuid, ThreadId>,
|
||||
}
|
||||
|
||||
impl ThreadStateManager {
|
||||
fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn has_listener_for_thread(&self, thread_id: ThreadId) -> bool {
|
||||
self.thread_id_by_subscription
|
||||
.values()
|
||||
.any(|existing| *existing == thread_id)
|
||||
}
|
||||
|
||||
fn thread_state(&mut self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
|
||||
self.thread_states
|
||||
.entry(thread_id)
|
||||
.or_insert_with(|| Arc::new(Mutex::new(ThreadState::default())))
|
||||
.clone()
|
||||
}
|
||||
|
||||
async fn remove_listener(&mut self, subscription_id: Uuid) -> Option<ThreadId> {
|
||||
let thread_id = self.thread_id_by_subscription.remove(&subscription_id)?;
|
||||
if let Some(thread_state) = self.thread_states.get(&thread_id) {
|
||||
thread_state.lock().await.clear_listener(subscription_id);
|
||||
}
|
||||
Some(thread_id)
|
||||
}
|
||||
|
||||
async fn remove_thread_state(&mut self, thread_id: ThreadId) {
|
||||
if let Some(thread_state) = self.thread_states.remove(&thread_id) {
|
||||
thread_state.lock().await.clear_listeners();
|
||||
}
|
||||
self.thread_id_by_subscription
|
||||
.retain(|_, existing_thread_id| *existing_thread_id != thread_id);
|
||||
}
|
||||
|
||||
async fn set_listener(
|
||||
&mut self,
|
||||
subscription_id: Uuid,
|
||||
thread_id: ThreadId,
|
||||
cancel_tx: oneshot::Sender<()>,
|
||||
) -> Arc<Mutex<ThreadState>> {
|
||||
self.thread_id_by_subscription
|
||||
.insert(subscription_id, thread_id);
|
||||
let thread_state = self.thread_state(thread_id);
|
||||
thread_state
|
||||
.lock()
|
||||
.await
|
||||
.set_listener(subscription_id, cancel_tx);
|
||||
thread_state
|
||||
}
|
||||
}
|
||||
|
||||
const THREAD_LIST_DEFAULT_LIMIT: usize = 25;
|
||||
const THREAD_LIST_MAX_LIMIT: usize = 100;
|
||||
@@ -303,21 +383,16 @@ pub(crate) struct CodexMessageProcessor {
|
||||
config: Arc<Config>,
|
||||
cli_overrides: Vec<(String, TomlValue)>,
|
||||
cloud_requirements: Arc<RwLock<CloudRequirementsLoader>>,
|
||||
conversation_listeners: HashMap<Uuid, oneshot::Sender<()>>,
|
||||
listener_thread_ids_by_subscription: HashMap<Uuid, ThreadId>,
|
||||
active_login: Arc<Mutex<Option<ActiveLogin>>>,
|
||||
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
|
||||
pending_interrupts: PendingInterrupts,
|
||||
// Queue of pending rollback requests per conversation. We reply when ThreadRollback arrives.
|
||||
pending_rollbacks: PendingRollbacks,
|
||||
turn_summary_store: TurnSummaryStore,
|
||||
thread_state_manager: ThreadStateManager,
|
||||
pending_fuzzy_searches: Arc<Mutex<HashMap<String, Arc<AtomicBool>>>>,
|
||||
feedback: CodexFeedback,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub(crate) enum ApiVersion {
|
||||
V1,
|
||||
#[default]
|
||||
V2,
|
||||
}
|
||||
|
||||
@@ -375,12 +450,8 @@ impl CodexMessageProcessor {
|
||||
config,
|
||||
cli_overrides,
|
||||
cloud_requirements,
|
||||
conversation_listeners: HashMap::new(),
|
||||
listener_thread_ids_by_subscription: HashMap::new(),
|
||||
active_login: Arc::new(Mutex::new(None)),
|
||||
pending_interrupts: Arc::new(Mutex::new(HashMap::new())),
|
||||
pending_rollbacks: Arc::new(Mutex::new(HashMap::new())),
|
||||
turn_summary_store: Arc::new(Mutex::new(HashMap::new())),
|
||||
thread_state_manager: ThreadStateManager::new(),
|
||||
pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())),
|
||||
feedback,
|
||||
}
|
||||
@@ -977,7 +1048,6 @@ impl CodexMessageProcessor {
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let cloud_requirements = self.cloud_requirements.clone();
|
||||
let chatgpt_base_url = self.config.chatgpt_base_url.clone();
|
||||
let codex_home = self.config.codex_home.clone();
|
||||
let cli_overrides = self.cli_overrides.clone();
|
||||
let auth_url = server.auth_url.clone();
|
||||
tokio::spawn(async move {
|
||||
@@ -1012,7 +1082,6 @@ impl CodexMessageProcessor {
|
||||
cloud_requirements.as_ref(),
|
||||
auth_manager.clone(),
|
||||
chatgpt_base_url,
|
||||
codex_home.clone(),
|
||||
);
|
||||
sync_default_client_residency_requirement(
|
||||
&cli_overrides,
|
||||
@@ -1085,7 +1154,6 @@ impl CodexMessageProcessor {
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let cloud_requirements = self.cloud_requirements.clone();
|
||||
let chatgpt_base_url = self.config.chatgpt_base_url.clone();
|
||||
let codex_home = self.config.codex_home.clone();
|
||||
let cli_overrides = self.cli_overrides.clone();
|
||||
let auth_url = server.auth_url.clone();
|
||||
tokio::spawn(async move {
|
||||
@@ -1120,7 +1188,6 @@ impl CodexMessageProcessor {
|
||||
cloud_requirements.as_ref(),
|
||||
auth_manager.clone(),
|
||||
chatgpt_base_url,
|
||||
codex_home.clone(),
|
||||
);
|
||||
sync_default_client_residency_requirement(
|
||||
&cli_overrides,
|
||||
@@ -1292,7 +1359,6 @@ impl CodexMessageProcessor {
|
||||
self.cloud_requirements.as_ref(),
|
||||
self.auth_manager.clone(),
|
||||
self.config.chatgpt_base_url.clone(),
|
||||
self.config.codex_home.clone(),
|
||||
);
|
||||
sync_default_client_residency_requirement(
|
||||
&self.cli_overrides,
|
||||
@@ -2321,25 +2387,32 @@ impl CodexMessageProcessor {
|
||||
|
||||
let request = request_id.clone();
|
||||
|
||||
{
|
||||
let mut map = self.pending_rollbacks.lock().await;
|
||||
if map.contains_key(&thread_id) {
|
||||
self.send_invalid_request_error(
|
||||
request.clone(),
|
||||
"rollback already in progress for this thread".to_string(),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
let rollback_already_in_progress = {
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id);
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
if thread_state.pending_rollbacks.is_some() {
|
||||
true
|
||||
} else {
|
||||
thread_state.pending_rollbacks = Some(request.clone());
|
||||
false
|
||||
}
|
||||
|
||||
map.insert(thread_id, request.clone());
|
||||
};
|
||||
if rollback_already_in_progress {
|
||||
self.send_invalid_request_error(
|
||||
request.clone(),
|
||||
"rollback already in progress for this thread".to_string(),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
if let Err(err) = thread.submit(Op::ThreadRollback { num_turns }).await {
|
||||
// No ThreadRollback event will arrive if an error occurs.
|
||||
// Clean up and reply immediately.
|
||||
let mut map = self.pending_rollbacks.lock().await;
|
||||
map.remove(&thread_id);
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id);
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.pending_rollbacks = None;
|
||||
drop(thread_state);
|
||||
|
||||
self.send_internal_error(request, format!("failed to start rollback: {err}"))
|
||||
.await;
|
||||
@@ -2637,11 +2710,7 @@ impl CodexMessageProcessor {
|
||||
|
||||
/// Best-effort: attach a listener for thread_id if missing.
|
||||
pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) {
|
||||
if self
|
||||
.listener_thread_ids_by_subscription
|
||||
.values()
|
||||
.any(|entry| *entry == thread_id)
|
||||
{
|
||||
if self.thread_state_manager.has_listener_for_thread(thread_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4273,7 +4342,8 @@ impl CodexMessageProcessor {
|
||||
let mut state_db_ctx = None;
|
||||
|
||||
// If the thread is active, request shutdown and wait briefly.
|
||||
if let Some(conversation) = self.thread_manager.remove_thread(&thread_id).await {
|
||||
let removed_conversation = self.thread_manager.remove_thread(&thread_id).await;
|
||||
if let Some(conversation) = removed_conversation {
|
||||
if let Some(ctx) = conversation.state_db() {
|
||||
state_db_ctx = Some(ctx);
|
||||
}
|
||||
@@ -4301,6 +4371,9 @@ impl CodexMessageProcessor {
|
||||
error!("failed to submit Shutdown to thread {thread_id}: {err}");
|
||||
}
|
||||
}
|
||||
self.thread_state_manager
|
||||
.remove_thread_state(thread_id)
|
||||
.await;
|
||||
}
|
||||
|
||||
if state_db_ctx.is_none() {
|
||||
@@ -4840,9 +4913,10 @@ impl CodexMessageProcessor {
|
||||
|
||||
// Record the pending interrupt so we can reply when TurnAborted arrives.
|
||||
{
|
||||
let mut map = self.pending_interrupts.lock().await;
|
||||
map.entry(conversation_id)
|
||||
.or_default()
|
||||
let pending_interrupts = self.thread_state_manager.thread_state(conversation_id);
|
||||
let mut thread_state = pending_interrupts.lock().await;
|
||||
thread_state
|
||||
.pending_interrupts
|
||||
.push((request, ApiVersion::V1));
|
||||
}
|
||||
|
||||
@@ -5236,9 +5310,10 @@ impl CodexMessageProcessor {
|
||||
|
||||
// Record the pending interrupt so we can reply when TurnAborted arrives.
|
||||
{
|
||||
let mut map = self.pending_interrupts.lock().await;
|
||||
map.entry(thread_uuid)
|
||||
.or_default()
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_uuid);
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state
|
||||
.pending_interrupts
|
||||
.push((request, ApiVersion::V2));
|
||||
}
|
||||
|
||||
@@ -5275,16 +5350,13 @@ impl CodexMessageProcessor {
|
||||
params: RemoveConversationListenerParams,
|
||||
) {
|
||||
let RemoveConversationListenerParams { subscription_id } = params;
|
||||
match self.conversation_listeners.remove(&subscription_id) {
|
||||
Some(sender) => {
|
||||
// Signal the spawned task to exit and acknowledge.
|
||||
let _ = sender.send(());
|
||||
if let Some(thread_id) = self
|
||||
.listener_thread_ids_by_subscription
|
||||
.remove(&subscription_id)
|
||||
{
|
||||
info!("removed listener for thread {thread_id}");
|
||||
}
|
||||
match self
|
||||
.thread_state_manager
|
||||
.remove_listener(subscription_id)
|
||||
.await
|
||||
{
|
||||
Some(thread_id) => {
|
||||
info!("removed listener for thread {thread_id}");
|
||||
let response = RemoveConversationSubscriptionResponse {};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
@@ -5302,7 +5374,7 @@ impl CodexMessageProcessor {
|
||||
async fn attach_conversation_listener(
|
||||
&mut self,
|
||||
conversation_id: ThreadId,
|
||||
experimental_raw_events: bool,
|
||||
raw_events_enabled: bool,
|
||||
api_version: ApiVersion,
|
||||
) -> Result<Uuid, JSONRPCErrorError> {
|
||||
let conversation = match self.thread_manager.get_thread(conversation_id).await {
|
||||
@@ -5318,16 +5390,11 @@ impl CodexMessageProcessor {
|
||||
|
||||
let subscription_id = Uuid::new_v4();
|
||||
let (cancel_tx, mut cancel_rx) = oneshot::channel();
|
||||
self.conversation_listeners
|
||||
.insert(subscription_id, cancel_tx);
|
||||
self.listener_thread_ids_by_subscription
|
||||
.insert(subscription_id, conversation_id);
|
||||
|
||||
let thread_state = self
|
||||
.thread_state_manager
|
||||
.set_listener(subscription_id, conversation_id, cancel_tx)
|
||||
.await;
|
||||
let outgoing_for_task = self.outgoing.clone();
|
||||
let pending_interrupts = self.pending_interrupts.clone();
|
||||
let pending_rollbacks = self.pending_rollbacks.clone();
|
||||
let turn_summary_store = self.turn_summary_store.clone();
|
||||
let api_version_for_task = api_version;
|
||||
let fallback_model_provider = self.config.model_provider_id.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
@@ -5345,10 +5412,9 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
};
|
||||
|
||||
if let EventMsg::RawResponseItem(_) = &event.msg
|
||||
&& !experimental_raw_events {
|
||||
continue;
|
||||
}
|
||||
if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
// For now, we send a notification for every event,
|
||||
// JSON-serializing the `Event` as-is, but these should
|
||||
@@ -5387,10 +5453,8 @@ impl CodexMessageProcessor {
|
||||
conversation_id,
|
||||
conversation.clone(),
|
||||
outgoing_for_task.clone(),
|
||||
pending_interrupts.clone(),
|
||||
pending_rollbacks.clone(),
|
||||
turn_summary_store.clone(),
|
||||
api_version_for_task,
|
||||
thread_state.clone(),
|
||||
api_version,
|
||||
fallback_model_provider.clone(),
|
||||
)
|
||||
.await;
|
||||
@@ -5659,9 +5723,8 @@ fn replace_cloud_requirements_loader(
|
||||
cloud_requirements: &RwLock<CloudRequirementsLoader>,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
chatgpt_base_url: String,
|
||||
codex_home: std::path::PathBuf,
|
||||
) {
|
||||
let loader = cloud_requirements_loader(auth_manager, chatgpt_base_url, codex_home);
|
||||
let loader = cloud_requirements_loader(auth_manager, chatgpt_base_url);
|
||||
if let Ok(mut guard) = cloud_requirements.write() {
|
||||
*guard = loader;
|
||||
} else {
|
||||
@@ -6276,4 +6339,30 @@ mod tests {
|
||||
assert_eq!(summary, expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn removing_one_listener_does_not_cancel_other_subscriptions_for_same_thread()
|
||||
-> Result<()> {
|
||||
let mut manager = ThreadStateManager::new();
|
||||
let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?;
|
||||
let listener_a = Uuid::new_v4();
|
||||
let listener_b = Uuid::new_v4();
|
||||
let (cancel_a, cancel_rx_a) = oneshot::channel();
|
||||
let (cancel_b, mut cancel_rx_b) = oneshot::channel();
|
||||
|
||||
manager.set_listener(listener_a, thread_id, cancel_a).await;
|
||||
manager.set_listener(listener_b, thread_id, cancel_b).await;
|
||||
|
||||
assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id));
|
||||
assert_eq!(cancel_rx_a.await, Ok(()));
|
||||
assert!(
|
||||
tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx_b)
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
assert_eq!(manager.remove_listener(listener_b).await, Some(thread_id));
|
||||
assert_eq!(cancel_rx_b.await, Ok(()));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -402,6 +402,12 @@ impl MessageProcessor {
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn connection_closed(&mut self, connection_id: ConnectionId) {
|
||||
self.codex_message_processor
|
||||
.connection_closed(connection_id)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Handle a standalone JSON-RPC response originating from the peer.
|
||||
pub(crate) async fn process_response(&mut self, response: JSONRPCResponse) {
|
||||
tracing::info!("<- response: {:?}", response);
|
||||
|
||||
Reference in New Issue
Block a user